Do reshape only if input shapes will be changed (#2632)

* Added private reshape

* Removed incorrect check
This commit is contained in:
Ilya Churaev 2020-10-14 09:42:39 +03:00 committed by GitHub
parent 8c48a6044d
commit 72d387c702
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 78 deletions

View File

@ -301,96 +301,29 @@ void CNNNetworkNGraphImpl::reshape() {
// Disable reshape for generic nodes
::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
StatusCode ret = reshape({}, &desc);
if (ret != OK)
THROW_IE_EXCEPTION << desc.msg;
reshape({});
}
StatusCode
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
ResponseDesc* responseDesc) noexcept {
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
ResponseDesc* responseDesc) noexcept {
if (cnnNetwork)
return cnnNetwork->reshape(inputShapes, responseDesc);
try {
auto params = _ngraph_function->get_parameters();
for (size_t i = 0; i < params.size(); i++) {
// Check that we need to do reshape only if input shapes will be changed
bool needReshape = false;
for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) {
const auto& param = params[i];
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
auto it = inputShapes.find(param->get_friendly_name());
if (it == inputShapes.end())
continue;
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
newParam->set_friendly_name(param->get_friendly_name());
_ngraph_function->replace_parameter(i, newParam);
}
_ngraph_function->validate_nodes_and_infer_types();
{
auto specialized_ngraph_function = cloneFunction(true);
// Call this transformation because OneHot IE and nGraph have different output precisions
{
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
::ngraph::pass::Manager manager;
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
specialized_ngraph_function);
manager.run_passes(specialized_ngraph_function);
}
specialized_ngraph_function->validate_nodes_and_infer_types();
#if 0
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
cout << " Inputs: ";
for (const auto &in : op->inputs()) {
cout << "[" << in.get_element_type().get_type_name() << "]";
if (in.get_partial_shape().is_dynamic()) {
cout << "dyn_shape";
} else {
cout << "{";
bool first = true;
for (auto i : in.get_shape()) {
if (!first) cout << ",";
cout << i;
first = false;
}
cout << "} ";
}
}
cout << endl << " Outputs: ";
for (const auto &in : op->outputs()) {
cout << "[" << in.get_element_type().get_type_name() << "]";
if (in.get_partial_shape().is_dynamic()) {
cout << "dyn_shape";
} else {
cout << "{";
bool first = true;
for (auto i : in.get_shape()) {
if (!first) cout << ",";
cout << i;
first = false;
}
cout << "} ";
}
}
cout << endl;
}
#endif
std::unordered_set<std::string> opName;
for (const auto &result : specialized_ngraph_function->get_results()) {
addOutput(result->input_value(0));
}
for (const auto &parameter : specialized_ngraph_function->get_parameters()) {
const auto &outName = parameter->get_friendly_name();
if (opName.find(outName) != opName.end()) {
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
}
opName.insert(outName);
createDataForResult(parameter, outName, _data[outName]);
}
if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second)
needReshape = true;
}
if (needReshape)
reshape(inputShapes);
} catch (std::exception& ex) {
return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
}
@ -398,6 +331,89 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
return OK;
}
void
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes) {
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
auto params = _ngraph_function->get_parameters();
for (size_t i = 0; i < params.size(); i++) {
const auto& param = params[i];
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
continue;
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
newParam->set_friendly_name(param->get_friendly_name());
_ngraph_function->replace_parameter(i, newParam);
}
_ngraph_function->validate_nodes_and_infer_types();
{
auto specialized_ngraph_function = cloneFunction(true);
// Call this transformation because OneHot IE and nGraph have different output precisions
{
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
::ngraph::pass::Manager manager;
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
specialized_ngraph_function);
manager.run_passes(specialized_ngraph_function);
}
specialized_ngraph_function->validate_nodes_and_infer_types();
#if 0
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
cout << " Inputs: ";
for (const auto &in : op->inputs()) {
cout << "[" << in.get_element_type().get_type_name() << "]";
if (in.get_partial_shape().is_dynamic()) {
cout << "dyn_shape";
} else {
cout << "{";
bool first = true;
for (auto i : in.get_shape()) {
if (!first) cout << ",";
cout << i;
first = false;
}
cout << "} ";
}
}
cout << endl << " Outputs: ";
for (const auto &in : op->outputs()) {
cout << "[" << in.get_element_type().get_type_name() << "]";
if (in.get_partial_shape().is_dynamic()) {
cout << "dyn_shape";
} else {
cout << "{";
bool first = true;
for (auto i : in.get_shape()) {
if (!first) cout << ",";
cout << i;
first = false;
}
cout << "} ";
}
}
cout << endl;
}
#endif
std::unordered_set<std::string> opName;
for (const auto &result : specialized_ngraph_function->get_results()) {
addOutput(result->input_value(0));
}
for (const auto &parameter : specialized_ngraph_function->get_parameters()) {
const auto &outName = parameter->get_friendly_name();
if (opName.find(outName) != opName.end()) {
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
}
opName.insert(outName);
createDataForResult(parameter, outName, _data[outName]);
}
}
}
StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,
ResponseDesc* resp) const noexcept {
auto network = cnnNetwork;

View File

@ -118,6 +118,7 @@ private:
* @brief Reshape on the same shape
*/
void reshape();
void reshape(const std::map<std::string, std::vector<size_t>>& inputShapes);
};
class TINGraphBody : public CNNNetworkNGraphImpl {