From e9011a9536533d0a1d3c6bca92e4cf7d559816fc Mon Sep 17 00:00:00 2001
From: Evgenya Stepyreva
Date: Wed, 7 Jul 2021 18:12:14 +0300
Subject: [PATCH] ShapeOf Sub-Graphs Simplification (#6308)

* ShapeOf Sub-Graphs Simplification

* Removed graph visualization

* Review comment

* comments resolved
---
 .../eliminate_unsqueeze_gather.hpp            |  13 +++
 .../simplify_shape_of_sub_graph.hpp           |  60 +++++++++++
 .../common_optimizations.cpp                  |   2 +
 .../eliminate_unsqueeze_gather.cpp            |  34 ++++++
 .../simplify_shape_of_sub_graph.cpp           | 101 ++++++++++++++++++
 .../simplify_shape_of_sub_graph.cpp           |  81 ++++++++++++++
 .../dynamic_to_static_shape_broadcast.cpp     |   2 -
 ngraph/core/include/ngraph/op/shape_of.hpp    |  20 ----
 ngraph/core/src/op/shape_of.cpp               |  49 +--------
 ngraph/test/constant_folding.cpp              |  67 ++++--------
 10 files changed, 314 insertions(+), 115 deletions(-)
 create mode 100644 inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp
 create mode 100644 inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp
 create mode 100644 inference-engine/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp

diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/eliminate_unsqueeze_gather.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/eliminate_unsqueeze_gather.hpp
index 141ef7d774d..012d78a4774 100644
--- a/inference-engine/src/transformations/include/transformations/common_optimizations/eliminate_unsqueeze_gather.hpp
+++ b/inference-engine/src/transformations/include/transformations/common_optimizations/eliminate_unsqueeze_gather.hpp
@@ -14,6 +14,7 @@ namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API EliminateUnsqueezeGather;
+class TRANSFORMATIONS_API EliminateGatherUnsqueeze;
 
 } // namespace pass
 } // namespace ngraph
@@ -29,3 +30,15 @@ public:
     NGRAPH_RTTI_DECLARATION;
     EliminateUnsqueezeGather();
 };
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Removes a Gather -> Unsqueeze pair if Gather takes a scalar and
+ * Unsqueeze makes it a 1D tensor.
+ */
+
+class ngraph::pass::EliminateGatherUnsqueeze : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    EliminateGatherUnsqueeze();
+};
diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp
new file mode 100644
index 00000000000..85d8aa662da
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/common_optimizations/simplify_shape_of_sub_graph.hpp
@@ -0,0 +1,60 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/ngraph.hpp>
+#include <ngraph/pass/graph_rewrite.hpp>
+#include <ngraph/pattern/matcher.hpp>
+#include <ngraph/util.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API SimplifyShapeOfSubGraph;
+class TRANSFORMATIONS_API SharedShapeOf;
+class TRANSFORMATIONS_API GroupedGatherElimination;
+
+} // namespace pass
+} // namespace ngraph
+
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SharedShapeOf transformation replaces a group of ShapeOf
+ * operations with the first ShapeOf in the group. All ShapeOf operations
+ * in the group must be equal and consume the same output port.
+ */
+class ngraph::pass::SharedShapeOf: public ngraph::pass::FunctionPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief GroupedGatherElimination transformation replaces a group of Gather
+ * operations with the first Gather in the group, with an updated indices input,
+ * in case all Gathers in the group are consumed by the same Concat in incremental order.
+ */
+class ngraph::pass::GroupedGatherElimination: public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    GroupedGatherElimination();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief SimplifyShapeOfSubGraph transformation runs specific optimizations of shape sub-graphs.
+ */
+class ngraph::pass::SimplifyShapeOfSubGraph: public ngraph::pass::FunctionPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
+};
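
For illustration, a minimal sketch of the duplication SharedShapeOf collapses (not part of the patch; opset7 names and the concrete shapes are assumptions). Both ShapeOf nodes below are equal and consume the same output port, so every consumer of the second one can be rewired to the first:

    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset7.hpp>

    using namespace ngraph;

    std::shared_ptr<Function> make_duplicated_shape_of_graph() {
        auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, 3, 4});
        // Two equal ShapeOf operations reading data->output(0); SharedShapeOf
        // keeps shape_of_1 and re-routes the users of shape_of_2 to it.
        auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data);
        auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data);
        auto concat = std::make_shared<opset7::Concat>(OutputVector{shape_of_1, shape_of_2}, 0);
        return std::make_shared<Function>(NodeVector{concat}, ParameterVector{data});
    }
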
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index ffe80ec9639..956904c9a34 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -76,6 +76,7 @@
 #include
 #include
 #include
+#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
 
@@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
 
     manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
     manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
 
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp
index fae3b71ac1e..ec3fafdea39 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp
@@ -7,6 +7,7 @@
 #include <ngraph/opsets/opset6.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
 #include <ngraph/rt_info.hpp>
+#include <transformations/utils/utils.hpp>
 #include "itt.hpp"
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateUnsqueezeGather, "EliminateUnsqueezeGather", 0);
@@ -58,3 +59,36 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
     auto m = std::make_shared<ngraph::pattern::Matcher>(gather, "EliminateUnsqueezeGather");
     register_matcher(m, callback);
 }
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateGatherUnsqueeze, "EliminateGatherUnsqueeze", 0);
+
+ngraph::pass::EliminateGatherUnsqueeze::EliminateGatherUnsqueeze() {
+    MATCHER_SCOPE(EliminateGatherUnsqueeze);
+
+    const auto gather_indices_label = ngraph::pattern::wrap_type<ngraph::opset6::Constant>(pattern::rank_equals(0));
+    const auto gather_axis_label = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
+    const auto gather_label = ngraph::pattern::wrap_type<ngraph::opset6::Gather>(
+        {ngraph::pattern::any_input(), gather_indices_label, gather_axis_label}, pattern::rank_equals(0));
+
+    const auto unsqueeze_label = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>(
+        {gather_label, ngraph::pattern::any_input()}, pattern::rank_equals(1));
+
+    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
+        auto pattern_nodes = m.get_pattern_map();
+
+        auto& gather_indices = pattern_nodes.at(gather_indices_label);
+        auto& gather = pattern_nodes.at(gather_label);
+        auto& unsqueeze = pattern_nodes.at(unsqueeze_label);
+
+        auto new_indices = ngraph::op::util::make_try_fold<ngraph::opset6::Reshape>(
+            gather_indices, opset6::Constant::create(element::i32, {1}, {1}), false);
+        auto new_gather = gather->clone_with_new_inputs({gather->input_value(0), new_indices, gather->input_value(2)});
+
+        new_gather->set_friendly_name(gather->get_friendly_name());
+        ngraph::copy_runtime_info({unsqueeze, gather}, {new_gather, new_indices});
+        ngraph::replace_node(unsqueeze, new_gather);
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze_label, "EliminateGatherUnsqueeze");
+    register_matcher(m, callback);
+}
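
The before/after of EliminateGatherUnsqueeze on a concrete sub-graph, as a sketch (not part of the patch; opset7 names, an i64 index and axis 0 are assumptions):

    using namespace ngraph;

    void eliminate_gather_unsqueeze_example() {
        auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, 3, 4});
        auto shape = std::make_shared<opset7::ShapeOf>(data);
        auto axis = opset7::Constant::create(element::i64, Shape{}, {0});

        // Before: a scalar Gather followed by an Unsqueeze that restores rank 1.
        auto scalar_index = opset7::Constant::create(element::i64, Shape{}, {1});   // rank-0 indices
        auto dim = std::make_shared<opset7::Gather>(shape, scalar_index, axis);     // rank-0 output
        auto unsqueeze_axes = opset7::Constant::create(element::i64, Shape{1}, {0});
        auto dim_1d = std::make_shared<opset7::Unsqueeze>(dim, unsqueeze_axes);     // rank-1 output

        // After: the indices are reshaped to Shape{1} and the Gather is cloned,
        // so a single Gather produces the rank-1 output directly.
        auto vector_index = opset7::Constant::create(element::i64, Shape{1}, {1});  // rank-1 indices
        auto fused_dim = std::make_shared<opset7::Gather>(shape, vector_index, axis);
    }
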
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp
new file mode 100644
index 00000000000..4aeae1d8f14
--- /dev/null
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/simplify_shape_of_sub_graph.cpp
@@ -0,0 +1,101 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <memory>
+#include <vector>
+
+#include "itt.hpp"
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <ngraph/opsets/opset7.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
+#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
+#include <transformations/utils/utils.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);
+
+bool ngraph::pass::SharedShapeOf::run_on_function(std::shared_ptr<ngraph::Function> f) {
+    RUN_ON_FUNCTION_SCOPE(SharedShapeOf);
+    bool graph_rewritten = false;
+
+    std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::Node>>> source_to_shape_of;
+    for (const auto & node : f->get_ordered_ops()) {
+        // Recursively apply transformation for sub-graph based operations
+        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
+            if (auto sub_graph = sub_graph_node->get_function())
+                graph_rewritten |= run_on_function(sub_graph);
+
+        if (is_type<opset1::ShapeOf>(node) || is_type<opset3::ShapeOf>(node))
+            source_to_shape_of[node->input_value(0)].push_back(node);
+    }
+
+    for (const auto& pair : source_to_shape_of) {
+        if (pair.second.size() < 2)
+            continue;
+        const auto& root_ss = pair.second[0];
+        for (const auto& child_ss : pair.second)
+            if (root_ss->get_instance_id() != child_ss->get_instance_id() &&
+                root_ss->get_output_element_type(0) == child_ss->get_output_element_type(0))
+                graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
+    }
+    return graph_rewritten;
+}
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::GroupedGatherElimination, "GroupedGatherElimination", 0);
+
+ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
+    MATCHER_SCOPE(GroupedGatherElimination);
+    auto concat_label = ngraph::pattern::wrap_type<ngraph::opset1::Concat>(pattern::rank_equals(1));
+
+    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
+        auto concat = m.get_match_root();
+        OutputVector inputs = concat->input_values();
+        NodeVector new_ops;
+        size_t i = 0, original_inputs_size = inputs.size();
+        while (inputs.size() > i + 1) {
+            auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
+            if (curr->get_type_info() != next->get_type_info() ||
+                (!is_type<opset1::Gather>(curr) && !is_type<opset7::Gather>(curr)) ||
+                (curr->input_value(0) != next->input_value(0))) {
+                ++i;
+                continue;
+            } // curr and next are Gathers of the same type which take data from the same source
+            auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(
+                OutputVector{curr->input_value(1), next->input_value(1)}, 0);
+            auto new_gather = curr->clone_with_new_inputs(
+                {curr->input_value(0), joint_indices, ngraph::opset1::Constant::create(element::i64, {}, {0})});
+            new_ops.push_back(joint_indices);
+            new_ops.push_back(new_gather);
+            inputs.erase(inputs.begin() + i);
+            inputs[i] = new_gather->output(0);
+        }
+        if (original_inputs_size > inputs.size()) {
+            auto new_concat = std::make_shared<opset1::Concat>(inputs, 0);
+            new_ops.push_back(new_concat);
+            new_concat->set_friendly_name(concat->get_friendly_name());
+            ngraph::copy_runtime_info(concat, new_ops);
+            ngraph::replace_node(concat, new_concat);
+            return true;
+        }
+        return false;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(concat_label, matcher_name);
+    this->register_matcher(m, callback);
+}
+
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);
+
+bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
+    RUN_ON_FUNCTION_SCOPE(SimplifyShapeOfSubGraph);
+    ngraph::pass::Manager manager;
+    manager.set_per_pass_validation(false);
+    manager.register_pass<ngraph::pass::SharedShapeOf>();
+    manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>();
+    manager.register_pass<ngraph::pass::GroupedGatherElimination>();
+    manager.register_pass<ngraph::pass::Validate>();
+    manager.run_passes(f);
+    return false;
+}
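
A worked example of the grouping loop above, as a sketch (not part of the patch; opset7 names are an assumption). Two neighbouring Gathers take data from the same ShapeOf and feed the same Concat, so their indices are concatenated (folded by make_try_fold when both are constant) and a single Gather survives:

    using namespace ngraph;

    void grouped_gather_example() {
        auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, 3, 4});
        auto shape = std::make_shared<opset7::ShapeOf>(data);
        auto axis = opset7::Constant::create(element::i64, Shape{}, {0});

        // Before: Concat(Gather(shape, {1}), Gather(shape, {2})) along axis 0.
        auto gather_1 = std::make_shared<opset7::Gather>(
            shape, opset7::Constant::create(element::i64, Shape{1}, {1}), axis);
        auto gather_2 = std::make_shared<opset7::Gather>(
            shape, opset7::Constant::create(element::i64, Shape{1}, {2}), axis);
        auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, gather_2}, 0);

        // After: the indices {1} and {2} are joined and one Gather remains.
        auto joint_gather = std::make_shared<opset7::Gather>(
            shape, opset7::Constant::create(element::i64, Shape{2}, {1, 2}), axis);
    }
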
diff --git a/inference-engine/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp b/inference-engine/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp
new file mode 100644
index 00000000000..d4b0e166573
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/simplify_shape_of_sub_graph.cpp
@@ -0,0 +1,81 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <queue>
+#include <string>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset7.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+
+using namespace testing;
+using namespace ngraph;
+
+auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
+    std::shared_ptr<Node> indices_node;
+    if (scalar)
+        indices_node = opset7::Constant::create(element::i64, {}, indices);
+    else
+        indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
+    return std::make_shared<opset7::Gather>(
+            input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
+};
+
+TEST(TransformationTests, ShapeSubGraphTest) {
+    std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
+
+    Shape data_shape{1, 2, 3, 4};
+    {
+        auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
+
+        auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
+        auto gather_1 = gather(shape_op_1, {1}, true);
+        auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
+            gather_1, opset7::Constant::create(element::i64, {1}, {0}));
+
+        auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
+        auto gather_2 = gather(shape_op_2, {2}, true);
+        auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
+            gather_2, opset7::Constant::create(element::i64, {1}, {0}));
+
+        auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
+        auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
+
+        auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0);
+
+        auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
+        f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
+        pass::Manager m;
+        m.register_pass<pass::InitNodeInfo>();
+        m.register_pass<pass::SimplifyShapeOfSubGraph>();
+        m.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+        ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({2, 3, 2, 2}));
+    }
+    {
+        auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
+
+        auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
+        auto gather_1 = gather(shape_op_1, {1, 2});
+
+        auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
+        auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
+
+        auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, const_1, const_2}, 0);
+
+        auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
+        f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
+    }
+
+    auto res = compare_functions(f, f_ref, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
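
The test drives the pass through a pass::Manager, and an application would apply it the same way. A minimal sketch, assuming an existing ngraph::Function f (CommonOptimizations already registers the pass, so this is only needed when running it in isolation):

    #include <ngraph/pass/manager.hpp>
    #include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>

    void simplify_shape_sub_graphs(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
        manager.run_passes(f);
    }
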
diff --git a/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_broadcast.cpp b/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_broadcast.cpp
index 78de9260c2d..c58f283c6f7 100644
--- a/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_broadcast.cpp
+++ b/inference-engine/tests/functional/plugin/myriad/ngraph/transformations/dynamic_to_static_shape_broadcast.cpp
@@ -74,7 +74,6 @@ protected:
         const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset3::Parameter>(tensorType, targetShape);
 
         const auto shapeOfNode = std::make_shared<ngraph::opset3::ShapeOf>(tensorWithTargetShapeParam, shapeType);
-        shapeOfNode->set_is_foldable(false);
 
         ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};
 
@@ -197,7 +196,6 @@ protected:
        const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset3::Parameter>(shapeType, targetShape);
 
         const auto shapeOfNode = std::make_shared<ngraph::opset3::ShapeOf>(tensorWithTargetShapeParam, shapeType);
-        shapeOfNode->set_is_foldable(false);
 
         ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};
 
diff --git a/ngraph/core/include/ngraph/op/shape_of.hpp b/ngraph/core/include/ngraph/op/shape_of.hpp
index 1fb26548008..932ea4c56e0 100644
--- a/ngraph/core/include/ngraph/op/shape_of.hpp
+++ b/ngraph/core/include/ngraph/op/shape_of.hpp
@@ -33,14 +33,6 @@ namespace ngraph
                 // Overload collision with method on Node
                 using Node::set_output_type;
 
-                // FOR CONSTANT FOLDING INTERNAL USAGE ONLY
-                // Constant folding for cases with static rank but dynamic shape create a subgraph
-                // which contains a Shape of.
-                // In this case we need to prevent constant folding from endless creation of these
-                // subgraphs.
-                // These metods should be removed if better solution will be designed.
-                void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
-                bool get_is_foldable() const { return m_is_foldable; }
                 bool evaluate(const HostTensorVector& output_values,
                               const HostTensorVector& input_values) const override;
                 bool has_evaluate() const override;
@@ -50,7 +42,6 @@
                                    const OutputVector& input_values) override;
 
             private:
-                bool m_is_foldable = true;
                 element::Type m_output_type;
             };
         } // namespace v3
@@ -72,14 +63,6 @@ namespace ngraph
 
                 void validate_and_infer_types() override;
 
-                // FOR CONSTANT FOLDING INTERNAL USAGE ONLY
-                // Constant folding for cases with static rank but dynamic shape create a subgraph
-                // which contains a Shape of.
-                // In this case we need to prevent constant folding from endless creation of these
-                // subgraphs.
-                // These metods should be removed if better solution will be designed.
-                void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
-                bool get_is_foldable() const { return m_is_foldable; }
                 bool evaluate(const HostTensorVector& output_values,
                               const HostTensorVector& input_values) const override;
                 bool has_evaluate() const override;
@@ -90,9 +72,6 @@
                 bool evaluate_lower(const HostTensorVector& output_values) const override;
                 bool evaluate_upper(const HostTensorVector& output_values) const override;
                 bool constant_fold(OutputVector& output_values,
                                    const OutputVector& input_values) override;
-
-            private:
-                bool m_is_foldable = true;
             };
         } // namespace v0
         using v0::ShapeOf;
diff --git a/ngraph/core/src/op/shape_of.cpp b/ngraph/core/src/op/shape_of.cpp
index d929eb0c1c0..792bad19bae 100644
--- a/ngraph/core/src/op/shape_of.cpp
+++ b/ngraph/core/src/op/shape_of.cpp
@@ -51,7 +51,6 @@ shared_ptr<Node> op::v3::ShapeOf::clone_with_new_inputs(const OutputVector& new_
     NGRAPH_OP_SCOPE(v3_ShapeOf_clone_with_new_inputs);
     check_new_args_count(this, new_args);
     auto new_shape_of = make_shared<v3::ShapeOf>(new_args.at(0), m_output_type);
-    new_shape_of->set_is_foldable(m_is_foldable);
     return new_shape_of;
 }
 
@@ -82,8 +81,7 @@ namespace shape_of
 {
     bool constant_fold_shape_of(Node* shape_of_node,
                                 Output<Node>& replacement,
-                                const Output<Node>& shape_of_input,
-                                bool is_foldable)
+                                const Output<Node>& shape_of_input)
     {
         auto partial_shape = shape_of_input.get_partial_shape();
         auto output_type = shape_of_node->get_output_element_type(0);
@@ -100,46 +98,6 @@ namespace shape_of
             }
             return false;
         }
-        else if (partial_shape.rank().is_static() && is_foldable)
-        {
-            auto shape_of = shape_of_node->copy_with_new_inputs({shape_of_input});
-            // Ugly
-            if (auto ps = as_type_ptr<op::v3::ShapeOf>(shape_of))
-            {
-                ps->set_is_foldable(false);
-            }
-            else if (auto ps = as_type_ptr<op::v0::ShapeOf>(shape_of))
-            {
-                ps->set_is_foldable(false);
-            }
-            auto dimensions = OutputVector{};
-            auto output_dimensions = vector<Dimension>(partial_shape);
-            for (size_t i = 0; i < output_dimensions.size(); ++i)
-            {
-                if (output_dimensions[i].is_static())
-                {
-                    auto temp = std::make_shared<op::Constant>(
-                        output_type,
-                        Shape{1},
-                        std::vector<int64_t>{output_dimensions[i].get_length()});
-                    temp->set_friendly_name("ConstDim/" + temp->get_name());
-                    dimensions.emplace_back(temp);
-                }
-                else
-                {
-                    auto index = std::make_shared<op::Constant>(
-                        output_type, Shape{1}, std::vector<int64_t>{static_cast<int64_t>(i)});
-                    auto axis = std::make_shared<op::Constant>(
-                        element::i64, Shape{}, std::vector<int64_t>{0});
-                    auto temp = make_shared<op::v1::Gather>(shape_of, index, axis);
-                    temp->set_friendly_name("DynDim/" + temp->get_name());
-                    dimensions.emplace_back(temp);
-                }
-            }
-
-            replacement = std::make_shared<op::Concat>(dimensions, 0);
-            return true;
-        }
         return false;
     }
 
@@ -250,7 +208,7 @@ bool op::v3::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
     OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v3::ShapeOf::constant_fold");
     if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
         return false;
-    return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
+    return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
 }
 
 // op::v0::ShapeOf
@@ -286,7 +244,6 @@ shared_ptr<Node> op::v0::ShapeOf::clone_with_new_inputs(const OutputVector& new_
                           description(),
                           " operation with name ",
                           get_friendly_name());
-    new_shape_of->set_is_foldable(m_is_foldable);
     return new_shape_of;
 }
 
@@ -318,7 +275,7 @@ bool op::v0::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
     OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v0::ShapeOf::constant_fold");
     if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
         return false;
-    return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
+    return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
 }
 
 bool op::v0::ShapeOf::evaluate_lower(const HostTensorVector& output_values) const
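
With m_is_foldable removed, suppressing constant folding of a ShapeOf relies on the rt_info key that both constant_fold overloads already check. A sketch (not part of the patch; it assumes, as the code above suggests, that constant_fold() only tests for the key's presence, so the mapped value is irrelevant):

    #include <ngraph/opsets/opset7.hpp>

    void disable_shape_of_folding(const std::shared_ptr<ngraph::opset7::ShapeOf>& shape_of) {
        // op::v0/v3::ShapeOf::constant_fold() bails out when this key exists.
        shape_of->get_rt_info()["DISABLED_CONSTANT_FOLDING"];
    }
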
diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp
index 0f5ce320531..d7efe623708 100644
--- a/ngraph/test/constant_folding.cpp
+++ b/ngraph/test/constant_folding.cpp
@@ -567,16 +567,11 @@ TEST(constant_folding, shape_of_dynamic_v0)
     pass_manager.register_pass<pass::ConstantFolding>();
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Constant>(f), 8);
+    ASSERT_EQ(f->get_ops().size(), 3);
 
-    auto result_as_concat =
-        as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
-    ASSERT_TRUE(result_as_concat);
-    ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
-    ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
+    auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
+    ASSERT_EQ(result_shape_of, shape_of);
+    ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
 }
 
 TEST(constant_folding, shape_of_dynamic_v3)
@@ -592,17 +587,11 @@ TEST(constant_folding, shape_of_dynamic_v3)
     pass_manager.register_pass<pass::ConstantFolding>();
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Constant>(f), 8);
+    ASSERT_EQ(f->get_ops().size(), 3);
 
-    auto result_as_concat =
-        as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
-    ASSERT_TRUE(result_as_concat);
-    ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
-    ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
-    ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
+    auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
+    ASSERT_EQ(result_shape_of, shape_of);
+    ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
 }
 
 TEST(constant_folding, shape_of_dynamic_i32_v3)
@@ -618,17 +607,11 @@ TEST(constant_folding, shape_of_dynamic_i32_v3)
     pass_manager.register_pass<pass::ConstantFolding>();
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Constant>(f), 8);
+    ASSERT_EQ(f->get_ops().size(), 3);
 
-    auto result_as_concat =
-        as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
-    ASSERT_TRUE(result_as_concat);
-    ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
-    ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
-    ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
+    auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
+    ASSERT_EQ(result_shape_of, shape_of);
+    ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
 }
 
 // We need to be sure that constant folding won't be calculated endlessly.
@@ -646,16 +629,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v0)
     pass_manager.run_passes(f);
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Constant>(f), 8);
+    ASSERT_EQ(f->get_ops().size(), 3);
 
-    auto result_as_concat =
-        as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
-    ASSERT_TRUE(result_as_concat);
-    ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
-    ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
+    auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
+    ASSERT_EQ(result_shape_of, shape_of);
+    ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
 }
 
 TEST(constant_folding, shape_of_dynamic_double_folding_v3)
@@ -672,16 +650,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v3)
     pass_manager.run_passes(f);
     pass_manager.run_passes(f);
 
-    ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Concat>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
-    ASSERT_EQ(count_ops_of_type<op::v0::Constant>(f), 8);
+    ASSERT_EQ(f->get_ops().size(), 3);
 
-    auto result_as_concat =
-        as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
-    ASSERT_TRUE(result_as_concat);
-    ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
-    ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
+    auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
+    ASSERT_EQ(result_shape_of, shape_of);
+    ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
 }
 
 // Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
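
The behaviour these tests pin down, shown as a sketch (not part of the patch): for an input with static rank but a dynamic dimension, ConstantFolding used to expand ShapeOf into a Concat of per-dimension Constants and Gathers (the sub-graph deleted from shape_of.cpp above); it now leaves the ShapeOf untouched, so repeated folding runs cannot grow the graph:

    #include <ngraph/ngraph.hpp>
    #include <ngraph/pass/constant_folding.hpp>
    #include <ngraph/pass/manager.hpp>

    using namespace ngraph;

    void dynamic_shape_of_stays_intact() {
        auto param = std::make_shared<op::Parameter>(
            element::f32, PartialShape{3, Dimension::dynamic(), 5});
        auto shape_of = std::make_shared<op::v3::ShapeOf>(param);
        auto f = std::make_shared<Function>(NodeVector{shape_of}, ParameterVector{param});

        pass::Manager pass_manager;
        pass_manager.register_pass<pass::ConstantFolding>();
        pass_manager.run_passes(f);
        // f still holds exactly Parameter -> ShapeOf -> Result,
        // matching ASSERT_EQ(f->get_ops().size(), 3) in the updated tests.
    }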