Do reshape only if input shapes will be changed (#2632)
* Added private reshape * Removed incorrect check
This commit is contained in:
parent
8c48a6044d
commit
72d387c702
@ -301,96 +301,29 @@ void CNNNetworkNGraphImpl::reshape() {
|
|||||||
|
|
||||||
// Disable reshape for generic nodes
|
// Disable reshape for generic nodes
|
||||||
::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
|
::ngraph::op::GenericIE::DisableReshape noReshape(_ngraph_function);
|
||||||
StatusCode ret = reshape({}, &desc);
|
reshape({});
|
||||||
if (ret != OK)
|
|
||||||
THROW_IE_EXCEPTION << desc.msg;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusCode
|
StatusCode
|
||||||
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
|
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes,
|
||||||
ResponseDesc* responseDesc) noexcept {
|
ResponseDesc* responseDesc) noexcept {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
|
|
||||||
|
|
||||||
if (cnnNetwork)
|
if (cnnNetwork)
|
||||||
return cnnNetwork->reshape(inputShapes, responseDesc);
|
return cnnNetwork->reshape(inputShapes, responseDesc);
|
||||||
try {
|
try {
|
||||||
auto params = _ngraph_function->get_parameters();
|
auto params = _ngraph_function->get_parameters();
|
||||||
|
|
||||||
for (size_t i = 0; i < params.size(); i++) {
|
// Check that we need to do reshape only if input shapes will be changed
|
||||||
|
bool needReshape = false;
|
||||||
|
for (size_t i = 0; i < params.size() && !inputShapes.empty(); i++) {
|
||||||
const auto& param = params[i];
|
const auto& param = params[i];
|
||||||
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
|
auto it = inputShapes.find(param->get_friendly_name());
|
||||||
|
if (it == inputShapes.end())
|
||||||
continue;
|
continue;
|
||||||
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
|
if (param->get_partial_shape().is_dynamic() || param->get_shape() != it->second)
|
||||||
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
|
needReshape = true;
|
||||||
newParam->set_friendly_name(param->get_friendly_name());
|
|
||||||
_ngraph_function->replace_parameter(i, newParam);
|
|
||||||
}
|
|
||||||
_ngraph_function->validate_nodes_and_infer_types();
|
|
||||||
|
|
||||||
{
|
|
||||||
auto specialized_ngraph_function = cloneFunction(true);
|
|
||||||
// Call this transformation because OneHot IE and nGraph have different output precisions
|
|
||||||
{
|
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
|
|
||||||
::ngraph::pass::Manager manager;
|
|
||||||
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
|
|
||||||
specialized_ngraph_function);
|
|
||||||
manager.run_passes(specialized_ngraph_function);
|
|
||||||
}
|
|
||||||
specialized_ngraph_function->validate_nodes_and_infer_types();
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
|
|
||||||
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
|
|
||||||
cout << " Inputs: ";
|
|
||||||
for (const auto &in : op->inputs()) {
|
|
||||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
|
||||||
if (in.get_partial_shape().is_dynamic()) {
|
|
||||||
cout << "dyn_shape";
|
|
||||||
} else {
|
|
||||||
cout << "{";
|
|
||||||
bool first = true;
|
|
||||||
for (auto i : in.get_shape()) {
|
|
||||||
if (!first) cout << ",";
|
|
||||||
cout << i;
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
cout << "} ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << endl << " Outputs: ";
|
|
||||||
for (const auto &in : op->outputs()) {
|
|
||||||
cout << "[" << in.get_element_type().get_type_name() << "]";
|
|
||||||
if (in.get_partial_shape().is_dynamic()) {
|
|
||||||
cout << "dyn_shape";
|
|
||||||
} else {
|
|
||||||
cout << "{";
|
|
||||||
bool first = true;
|
|
||||||
for (auto i : in.get_shape()) {
|
|
||||||
if (!first) cout << ",";
|
|
||||||
cout << i;
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
cout << "} ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cout << endl;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
std::unordered_set<std::string> opName;
|
|
||||||
for (const auto &result : specialized_ngraph_function->get_results()) {
|
|
||||||
addOutput(result->input_value(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
|
|
||||||
const auto &outName = parameter->get_friendly_name();
|
|
||||||
if (opName.find(outName) != opName.end()) {
|
|
||||||
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
|
|
||||||
}
|
|
||||||
opName.insert(outName);
|
|
||||||
createDataForResult(parameter, outName, _data[outName]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if (needReshape)
|
||||||
|
reshape(inputShapes);
|
||||||
} catch (std::exception& ex) {
|
} catch (std::exception& ex) {
|
||||||
return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
|
return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
|
||||||
}
|
}
|
||||||
@ -398,6 +331,89 @@ CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>&
|
|||||||
return OK;
|
return OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::reshape");
|
||||||
|
|
||||||
|
auto params = _ngraph_function->get_parameters();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < params.size(); i++) {
|
||||||
|
const auto& param = params[i];
|
||||||
|
if (inputShapes.find(param->get_friendly_name()) == inputShapes.end())
|
||||||
|
continue;
|
||||||
|
::ngraph::PartialShape shape(inputShapes.at(param->get_friendly_name()));
|
||||||
|
auto newParam = std::make_shared<::ngraph::op::Parameter>(param->get_element_type(), shape);
|
||||||
|
newParam->set_friendly_name(param->get_friendly_name());
|
||||||
|
_ngraph_function->replace_parameter(i, newParam);
|
||||||
|
}
|
||||||
|
_ngraph_function->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
|
{
|
||||||
|
auto specialized_ngraph_function = cloneFunction(true);
|
||||||
|
// Call this transformation because OneHot IE and nGraph have different output precisions
|
||||||
|
{
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "CNNNetworkNGraphImpl::ConvertOneHot");
|
||||||
|
::ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<::ngraph::pass::ConvertOneHotToOneHotIEMatcher>()->detect_output_type(
|
||||||
|
specialized_ngraph_function);
|
||||||
|
manager.run_passes(specialized_ngraph_function);
|
||||||
|
}
|
||||||
|
specialized_ngraph_function->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
for (const auto &op : specialized_ngraph_function->get_ordered_ops()) {
|
||||||
|
cout << "[ " << op->description() << " ] " << op->get_friendly_name() << endl;
|
||||||
|
cout << " Inputs: ";
|
||||||
|
for (const auto &in : op->inputs()) {
|
||||||
|
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||||
|
if (in.get_partial_shape().is_dynamic()) {
|
||||||
|
cout << "dyn_shape";
|
||||||
|
} else {
|
||||||
|
cout << "{";
|
||||||
|
bool first = true;
|
||||||
|
for (auto i : in.get_shape()) {
|
||||||
|
if (!first) cout << ",";
|
||||||
|
cout << i;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
cout << "} ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << endl << " Outputs: ";
|
||||||
|
for (const auto &in : op->outputs()) {
|
||||||
|
cout << "[" << in.get_element_type().get_type_name() << "]";
|
||||||
|
if (in.get_partial_shape().is_dynamic()) {
|
||||||
|
cout << "dyn_shape";
|
||||||
|
} else {
|
||||||
|
cout << "{";
|
||||||
|
bool first = true;
|
||||||
|
for (auto i : in.get_shape()) {
|
||||||
|
if (!first) cout << ",";
|
||||||
|
cout << i;
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
cout << "} ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
std::unordered_set<std::string> opName;
|
||||||
|
for (const auto &result : specialized_ngraph_function->get_results()) {
|
||||||
|
addOutput(result->input_value(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto ¶meter : specialized_ngraph_function->get_parameters()) {
|
||||||
|
const auto &outName = parameter->get_friendly_name();
|
||||||
|
if (opName.find(outName) != opName.end()) {
|
||||||
|
THROW_IE_EXCEPTION << "All operations in nGraph function should have unique friendly names!";
|
||||||
|
}
|
||||||
|
opName.insert(outName);
|
||||||
|
createDataForResult(parameter, outName, _data[outName]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,
|
StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath, const std::string& binPath,
|
||||||
ResponseDesc* resp) const noexcept {
|
ResponseDesc* resp) const noexcept {
|
||||||
auto network = cnnNetwork;
|
auto network = cnnNetwork;
|
||||||
|
@ -118,6 +118,7 @@ private:
|
|||||||
* @brief Reshape on the same shape
|
* @brief Reshape on the same shape
|
||||||
*/
|
*/
|
||||||
void reshape();
|
void reshape();
|
||||||
|
void reshape(const std::map<std::string, std::vector<size_t>>& inputShapes);
|
||||||
};
|
};
|
||||||
|
|
||||||
class TINGraphBody : public CNNNetworkNGraphImpl {
|
class TINGraphBody : public CNNNetworkNGraphImpl {
|
||||||
|
Loading…
Reference in New Issue
Block a user