[IE][VPU][nGraph]: Fixes DTS transformations to properly keep outputs names (#734)
* NonZero, Broadcast * Concat * Gather * [IE][VPU][nGraph]: Fixes DTS transformations to correctly keep outputs names * [IE][VPU][nGraph]: Fixes dynamic to static shape nonzero tests Co-authored-by: Roman Vyunov <roman.vyunov@intel.com>
This commit is contained in:
@@ -39,6 +39,7 @@ void dynamicToStaticShapeBroadcast(std::shared_ptr<ngraph::Node> target) {
|
||||
|
||||
auto dsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
staticShapeBroadcast->output(0), broadcast->input_value(1));
|
||||
dsr->set_friendly_name(broadcast->get_friendly_name());
|
||||
|
||||
ngraph::replace_node(std::move(target), std::move(dsr));
|
||||
}
|
||||
|
||||
@@ -107,8 +107,9 @@ void dynamicToStaticShapeConcat(std::shared_ptr<ngraph::Node> target) {
|
||||
}
|
||||
|
||||
const auto copied = target->clone_with_new_inputs(target->input_values());
|
||||
const auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
copied, accumulatedShape);
|
||||
outDsr->set_friendly_name(target->get_friendly_name());
|
||||
|
||||
ngraph::replace_node(std::move(target), outDsr);
|
||||
}
|
||||
|
||||
@@ -42,9 +42,8 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
|
||||
|
||||
const auto copied = target->clone_with_new_inputs(target->input_values());
|
||||
|
||||
|
||||
const auto & data_rank = data_shape.get_partial_shape();
|
||||
const auto & indices_rank = indices_shape.get_partial_shape();
|
||||
const auto& data_rank = data_shape.get_partial_shape();
|
||||
const auto& indices_rank = indices_shape.get_partial_shape();
|
||||
VPU_THROW_UNLESS(data_rank.is_static() && indices_rank.is_static(),
|
||||
"DynamicToStaticShape transformation for {} doesn't support dynamic rank", gather);
|
||||
|
||||
@@ -72,7 +71,9 @@ void dynamicToStaticShapeGather(std::shared_ptr<ngraph::Node> target) {
|
||||
output_dims.push_back(second_data_shape_part);
|
||||
}
|
||||
const auto output_shape = std::make_shared<ngraph::opset3::Concat>(output_dims, 0);
|
||||
ngraph::replace_node(target, std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, output_shape));
|
||||
auto outDsr = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(copied, output_shape);
|
||||
outDsr->set_friendly_name(target->get_friendly_name());
|
||||
ngraph::replace_node(target, outDsr);
|
||||
}
|
||||
|
||||
} // namespace vpu
|
||||
|
||||
@@ -21,11 +21,10 @@ void dynamicToStaticShapeNonZero(std::shared_ptr<ngraph::Node> node) {
|
||||
node->get_friendly_name(), node->get_type_info(), ngraph::op::v3::NonZero::type_info);
|
||||
|
||||
auto staticShapeNonZero = std::make_shared<ngraph::vpu::op::StaticShapeNonZero>(nonZero->input(0).get_source_output(), nonZero->get_output_type());
|
||||
staticShapeNonZero->set_friendly_name(nonZero->get_friendly_name() + "/static_shape");
|
||||
|
||||
auto dynamicShapeResolver = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
staticShapeNonZero->output(0), staticShapeNonZero->output(1));
|
||||
dynamicShapeResolver->set_friendly_name(nonZero->get_friendly_name() + "/resolve_shape");
|
||||
dynamicShapeResolver->set_friendly_name(nonZero->get_friendly_name());
|
||||
|
||||
ngraph::replace_node(std::move(nonZero), std::move(dynamicShapeResolver));
|
||||
}
|
||||
|
||||
@@ -57,10 +57,9 @@ public:
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputType, inputShape);
|
||||
|
||||
const auto staticShapeNonZero = std::make_shared<ngraph::vpu::op::StaticShapeNonZero>(input, resultType);
|
||||
staticShapeNonZero->set_friendly_name(std::string(s_FriendlyName) + "/static_shape");
|
||||
const auto dynamicShapeResolver = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
|
||||
staticShapeNonZero->output(0), staticShapeNonZero->output(1));
|
||||
dynamicShapeResolver->set_friendly_name(std::string(s_FriendlyName) + "/resolve_shape");
|
||||
dynamicShapeResolver->set_friendly_name(std::string(s_FriendlyName));
|
||||
|
||||
expected = std::make_shared<ngraph::Function>(ngraph::NodeVector{dynamicShapeResolver}, ngraph::ParameterVector{input});
|
||||
}
|
||||
@@ -69,16 +68,13 @@ public:
|
||||
void compareFunctions() {
|
||||
ASSERT_NO_THROW(ngraph::helpers::CompareFunctions(*actual, *expected));
|
||||
|
||||
auto actualResultNode = actual->get_output_op(0);
|
||||
auto actualResolverNode = actualResultNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
auto actualNonZeroNode = actualResolverNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
const auto actualResultNode = actual->get_output_op(0);
|
||||
const auto actualResolverNode = actualResultNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
|
||||
auto expectedResultNode = expected->get_output_op(0);
|
||||
auto expectedResolverNode = expectedResultNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
auto expectedNonZeroNode = expectedResolverNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
const auto expectedResultNode = expected->get_output_op(0);
|
||||
const auto expectedResolverNode = expectedResultNode->input(0).get_source_output().get_node_shared_ptr();
|
||||
|
||||
EXPECT_EQ(actualResolverNode->get_friendly_name(), expectedResolverNode->get_friendly_name());
|
||||
EXPECT_EQ(actualNonZeroNode->get_friendly_name(), expectedNonZeroNode->get_friendly_name());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
Reference in New Issue
Block a user