Do reshape only if input shapes will be changed (#2632)
* Added private reshape * Removed incorrect check
This commit is contained in:
parent
8c48a6044d
commit
72d387c702
@ -301,96 +301,29 @@ void CNNNetworkNGraphImpl::reshape() {
|
||||
|
||||
// Disable reshape for generic nodes
|
||||
::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
|
||||
StatusCode ret = reshape({}, &desc);
|
||||
if (ret != OK)
|
||||
THROW_IE_EXCEPTION << desc.msg;
|
||||
reshape({});
|
||||
}
|
||||
|
||||
StatusCode
|
||||
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
|
||||
ResponseDesc* responseDesc) noexcept {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
|
||||
|
||||
ResponseDesc* responseDesc) noexcept {
|
||||
if (cnnNetwork)
|
||||
return cnnNetwork->reshape(inputShapes, responseDesc);
|
||||
try {
|
||||
auto params = _ngraph_function->get_parameters();
|
||||
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
// Check that we need to do reshape only if input shapes will be changed
|
||||
bool needReshape = false;
|
||||
for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) {
|
||||
const auto& param = params[i];
|
||||
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
|
||||
auto it = inputShapes.find(param->get_friendly_name());
|
||||
if (it == inputShapes.end())
|
||||
continue;
|
||||
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
|
||||
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
|
||||
newParam->set_friendly_name(param->get_friendly_name());
|
||||
_ngraph_function->replace_parameter(i, newParam);
|
||||
}
|
||||
_ngraph_function->validate_nodes_and_infer_types();
|
||||
|
||||
{
|
||||
auto specialized_ngraph_function = cloneFunction(true);
|
||||
// Call this transformation because OneHot IE and nGraph have different output precisions
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
|
||||
::ngraph::pass::Manager manager;
|
||||
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
|
||||
specialized_ngraph_function);
|
||||
manager.run_passes(specialized_ngraph_function);
|
||||
}
|
||||
specialized_ngraph_function->validate_nodes_and_infer_types();
|
||||
|
||||
#if 0
|
||||
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
|
||||
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
|
||||
cout << " Inputs: ";
|
||||
for (const auto &in : op->inputs()) {
|
||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||
if (in.get_partial_shape().is_dynamic()) {
|
||||
cout << "dyn_shape";
|
||||
} else {
|
||||
cout << "{";
|
||||
bool first = true;
|
||||
for (auto i : in.get_shape()) {
|
||||
if (!first) cout << ",";
|
||||
cout << i;
|
||||
first = false;
|
||||
}
|
||||
cout << "} ";
|
||||
}
|
||||
}
|
||||
cout << endl << " Outputs: ";
|
||||
for (const auto &in : op->outputs()) {
|
||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||
if (in.get_partial_shape().is_dynamic()) {
|
||||
cout << "dyn_shape";
|
||||
} else {
|
||||
cout << "{";
|
||||
bool first = true;
|
||||
for (auto i : in.get_shape()) {
|
||||
if (!first) cout << ",";
|
||||
cout << i;
|
||||
first = false;
|
||||
}
|
||||
cout << "} ";
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
#endif
|
||||
std::unordered_set<std::string> opName;
|
||||
for (const auto &result : specialized_ngraph_function->get_results()) {
|
||||
addOutput(result->input_value(0));
|
||||
}
|
||||
|
||||
for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
|
||||
const auto &outName = parameter->get_friendly_name();
|
||||
if (opName.find(outName) != opName.end()) {
|
||||
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
|
||||
}
|
||||
opName.insert(outName);
|
||||
createDataForResult(parameter, outName, _data[outName]);
|
||||
}
|
||||
if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second)
|
||||
needReshape = true;
|
||||
}
|
||||
if (needReshape)
|
||||
reshape(inputShapes);
|
||||
} catch (std::exception& ex) {
|
||||
return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
|
||||
}
|
||||
@ -398,6 +331,89 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
|
||||
return OK;
|
||||
}
|
||||
|
||||
void
|
||||
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
|
||||
|
||||
auto params = _ngraph_function->get_parameters();
|
||||
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
const auto& param = params[i];
|
||||
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
|
||||
continue;
|
||||
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
|
||||
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
|
||||
newParam->set_friendly_name(param->get_friendly_name());
|
||||
_ngraph_function->replace_parameter(i, newParam);
|
||||
}
|
||||
_ngraph_function->validate_nodes_and_infer_types();
|
||||
|
||||
{
|
||||
auto specialized_ngraph_function = cloneFunction(true);
|
||||
// Call this transformation because OneHot IE and nGraph have different output precisions
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
|
||||
::ngraph::pass::Manager manager;
|
||||
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
|
||||
specialized_ngraph_function);
|
||||
manager.run_passes(specialized_ngraph_function);
|
||||
}
|
||||
specialized_ngraph_function->validate_nodes_and_infer_types();
|
||||
|
||||
#if 0
|
||||
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
|
||||
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
|
||||
cout << " Inputs: ";
|
||||
for (const auto &in : op->inputs()) {
|
||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||
if (in.get_partial_shape().is_dynamic()) {
|
||||
cout << "dyn_shape";
|
||||
} else {
|
||||
cout << "{";
|
||||
bool first = true;
|
||||
for (auto i : in.get_shape()) {
|
||||
if (!first) cout << ",";
|
||||
cout << i;
|
||||
first = false;
|
||||
}
|
||||
cout << "} ";
|
||||
}
|
||||
}
|
||||
cout << endl << " Outputs: ";
|
||||
for (const auto &in : op->outputs()) {
|
||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||
if (in.get_partial_shape().is_dynamic()) {
|
||||
cout << "dyn_shape";
|
||||
} else {
|
||||
cout << "{";
|
||||
bool first = true;
|
||||
for (auto i : in.get_shape()) {
|
||||
if (!first) cout << ",";
|
||||
cout << i;
|
||||
first = false;
|
||||
}
|
||||
cout << "} ";
|
||||
}
|
||||
}
|
||||
cout << endl;
|
||||
}
|
||||
#endif
|
||||
std::unordered_set<std::string> opName;
|
||||
for (const auto &result : specialized_ngraph_function->get_results()) {
|
||||
addOutput(result->input_value(0));
|
||||
}
|
||||
|
||||
for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
|
||||
const auto &outName = parameter->get_friendly_name();
|
||||
if (opName.find(outName) != opName.end()) {
|
||||
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
|
||||
}
|
||||
opName.insert(outName);
|
||||
createDataForResult(parameter, outName, _data[outName]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,
|
||||
ResponseDesc* resp) const noexcept {
|
||||
auto network = cnnNetwork;
|
||||
|
@ -118,6 +118,7 @@ private:
|
||||
* @brief Reshape on the same shape
|
||||
*/
|
||||
void reshape();
|
||||
void reshape(const std::map<std::string, std::vector<size_t>>& inputShapes);
|
||||
};
|
||||
|
||||
class TINGraphBody : public CNNNetworkNGraphImpl {
|
||||
|
Loading…
Reference in New Issue
Block a user