[IE][VPU][nGraph]: Fixes DTS transformations to properly keep outputs names (#734)

* NonZero, Broadcast

* Concat

* Gather

* [IE][VPU][nGraph]: Fixes DTS transformations to correctly keep outputs names

* [IE][VPU][nGraph]: Fixes dynamic to static shape nonzero tests

Co-authored-by: Roman Vyunov <roman.vyunov@intel.com>
This commit is contained in:
Gladilov, Gleb
2020-06-05 11:16:52 +03:00
committed by GitHub
parent f9ac555857
commit f80bd537bf
5 changed files with 14 additions and 16 deletions

View File

@@ -39,6 +39,7 @@ void dynamicToStaticShapeBroadcast(std::shared_ptr<ngraph::Node> target) {
auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
staticShapeBroadcast->output(0), broadcast->input_value(1));
dsr->set_friendly_name(broadcast->get_friendly_name());
ngraph::replace_node(std::move(target), std::move(dsr));
}

View File

@@ -107,8 +107,9 @@ void dynamicToStaticShapeConcat(std::shared_ptr<ngraph::Node> target) {
}
const auto copied = target->clone_with_new_inputs(target->input_values());
const auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
copied, accumulatedShape);
outDsr->set_friendly_name(target->get_friendly_name());
ngraph::replace_node(std::move(target), outDsr);
}

View File

@@ -42,9 +42,8 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
const auto copied = target->clone_with_new_inputs(target->input_values());
const auto & data_rank = data_shape.get_partial_shape();
const auto & indices_rank = indices_shape.get_partial_shape();
const auto& data_rank = data_shape.get_partial_shape();
const auto& indices_rank = indices_shape.get_partial_shape();
VPU_THROW_UNLESS(data_rank.is_static() && indices_rank.is_static(),
"DynamicToStaticShape transformation for {} doesn't support dynamic rank", gather);
@@ -72,7 +71,9 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
output_dims.push_back(second_data_shape_part);
}
const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
ngraph::replace_node(target, std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, output_shape));
auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, output_shape);
outDsr->set_friendly_name(target->get_friendly_name());
ngraph::replace_node(target, outDsr);
}
} // namespace vpu

View File

@@ -21,11 +21,10 @@ void dynamicToStaticShapeNonZero(std::shared_ptr<ngraph::Node> node) {
node->get_friendly_name(), node->get_type_info(), ngraph::op::v3::NonZero::type_info);
auto staticShapeNonZero = std::make_shared<ngraph::vpu::op::StaticShapeNonZero>(nonZero->input(0).get_source_output(), nonZero->get_output_type());
staticShapeNonZero->set_friendly_name(nonZero->get_friendly_name() + "/static_shape");
auto dynamicShapeResolver = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
staticShapeNonZero->output(0), staticShapeNonZero->output(1));
dynamicShapeResolver->set_friendly_name(nonZero->get_friendly_name() + "/resolve_shape");
dynamicShapeResolver->set_friendly_name(nonZero->get_friendly_name());
ngraph::replace_node(std::move(nonZero), std::move(dynamicShapeResolver));
}

View File

@@ -57,10 +57,9 @@ public:
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputType, inputShape);
const auto staticShapeNonZero = std::make_shared<ngraph::vpu::op::StaticShapeNonZero>(input, resultType);
staticShapeNonZero->set_friendly_name(std::string(s_FriendlyName) + "/static_shape");
const auto dynamicShapeResolver = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
staticShapeNonZero->output(0), staticShapeNonZero->output(1));
dynamicShapeResolver->set_friendly_name(std::string(s_FriendlyName) + "/resolve_shape");
dynamicShapeResolver->set_friendly_name(std::string(s_FriendlyName));
expected = std::make_shared<ngraph::Function>(ngraph::NodeVector{dynamicShapeResolver}, ngraph::ParameterVector{input});
}
@@ -69,16 +68,13 @@ public:
void compareFunctions() {
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*actual, *expected));
auto actualResultNode = actual->get_output_op(0);
auto actualResolverNode = actualResultNode->input(0).get_source_output().get_node_shared_ptr();
auto actualNonZeroNode = actualResolverNode->input(0).get_source_output().get_node_shared_ptr();
const auto actualResultNode = actual->get_output_op(0);
const auto actualResolverNode = actualResultNode->input(0).get_source_output().get_node_shared_ptr();
auto expectedResultNode = expected->get_output_op(0);
auto expectedResolverNode = expectedResultNode->input(0).get_source_output().get_node_shared_ptr();
auto expectedNonZeroNode = expectedResolverNode->input(0).get_source_output().get_node_shared_ptr();
const auto expectedResultNode = expected->get_output_op(0);
const auto expectedResolverNode = expectedResultNode->input(0).get_source_output().get_node_shared_ptr();
EXPECT_EQ(actualResolverNode->get_friendly_name(), expectedResolverNode->get_friendly_name());
EXPECT_EQ(actualNonZeroNode->get_friendly_name(), expectedNonZeroNode->get_friendly_name());
}
protected: