ShapeOf Sub-Graphs Simplification (#6308)

* ShapeOf Sub-Graphs Simplification

* Removed graph visualization

* Review comment

* Comments resolved
Evgenya Stepyreva 2021-07-07 18:12:14 +03:00 committed by GitHub
parent 86b97e0a74
commit e9011a9536
10 changed files with 314 additions and 115 deletions

@@ -14,6 +14,7 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API EliminateUnsqueezeGather;
class TRANSFORMATIONS_API EliminateGatherUnsqueeze;
} // namespace pass
} // namespace ngraph
@@ -29,3 +30,15 @@ public:
NGRAPH_RTTI_DECLARATION;
EliminateUnsqueezeGather();
};
/**
* @ingroup ie_transformation_common_api
* @brief Removes a Gather -> Unsqueeze pair if the Gather produces a scalar and
* the Unsqueeze turns it into a 1D tensor
*/
class ngraph::pass::EliminateGatherUnsqueeze : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EliminateGatherUnsqueeze();
};
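A hypothetical sketch of the pattern this pass removes, mirroring the opset7 helpers used by the tests in this PR (the function name and shapes are assumptions, not part of the codebase):

#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>

// Before: ShapeOf -> Gather(scalar index) -> Unsqueeze(axes {0})
// After:  ShapeOf -> Gather(1D indices), with the Unsqueeze removed.
std::shared_ptr<ngraph::Function> gather_unsqueeze_example() {
    using namespace ngraph;
    auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, 3});
    auto shape = std::make_shared<opset7::ShapeOf>(data);
    auto index = opset7::Constant::create(element::i64, Shape{}, {1});   // scalar index
    auto axis = opset7::Constant::create(element::i64, Shape{}, {0});
    auto gather = std::make_shared<opset7::Gather>(shape, index, axis);  // rank-0 output
    auto unsqueeze = std::make_shared<opset7::Unsqueeze>(
        gather, opset7::Constant::create(element::i64, Shape{1}, {0}));  // rank-1 output
    auto f = std::make_shared<Function>(NodeVector{unsqueeze}, ParameterVector{data});
    pass::Manager m;
    m.register_pass<pass::EliminateGatherUnsqueeze>();  // folds the pair into one Gather
    m.run_passes(f);
    return f;
}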

@@ -0,0 +1,60 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/util.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SimplifyShapeOfSubGraph;
class TRANSFORMATIONS_API SharedShapeOf;
class TRANSFORMATIONS_API GroupedGatherElimination;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief SharedShapeOf transformation replaces a group of ShapeOf
* operations with the first ShapeOf in the group. All ShapeOf operations in the
* group must produce the same element type and consume the same output port.
*/
class ngraph::pass::SharedShapeOf: public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
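A hypothetical sketch of the duplication SharedShapeOf collapses: two ShapeOf nodes reading the same output port (helper name and shapes are assumptions):

#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>

std::shared_ptr<ngraph::Function> shared_shape_of_example() {
    using namespace ngraph;
    auto data = std::make_shared<opset7::Parameter>(element::f32, Shape{1, 2, 3});
    auto shape_a = std::make_shared<opset7::ShapeOf>(data);  // reads data->output(0)
    auto shape_b = std::make_shared<opset7::ShapeOf>(data);  // duplicate of shape_a
    auto sum = std::make_shared<opset7::Add>(shape_a, shape_b);
    auto f = std::make_shared<Function>(NodeVector{sum}, ParameterVector{data});
    pass::Manager m;
    m.register_pass<pass::SharedShapeOf>();  // rewires consumers of shape_b to shape_a
    m.run_passes(f);
    return f;  // the function now holds a single ShapeOf
}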
/**
* @ingroup ie_transformation_common_api
* @brief GroupedGatherElimination transformation replaces a group of Gather
* operations with the first Gather in the group, updating its indices input,
* in case all Gathers in the group are consumed by the same Concat in incremental order.
*/
class ngraph::pass::GroupedGatherElimination: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
GroupedGatherElimination();
};
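An illustrative before/after of this rewrite, with hypothetical indices (S is the shared data source):

// Before: Concat{ Gather(S, {1}, 0), Gather(S, {2}, 0), rest... } along axis 0
// After:  Concat{ Gather(S, {1, 2}, 0), rest... }
// Adjacent Gathers over the same source fold into one Gather whose indices are
// the (constant-folded) concatenation of the original indices.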
/**
* @ingroup ie_transformation_common_api
* @brief SimplifyShapeOfSubGraph transformation runs a pipeline of optimizations specific to ShapeOf sub-graphs
*/
class ngraph::pass::SimplifyShapeOfSubGraph: public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
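A minimal usage sketch; CommonOptimizations below registers the pass the same way:

#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>

void run_simplify(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
    manager.run_passes(f);
}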

@@ -76,6 +76,7 @@
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
@@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// This pass must be called first in pipeline
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed

@@ -7,6 +7,7 @@
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/utils/utils.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateUnsqueezeGather, "EliminateUnsqueezeGather", 0);
@@ -58,3 +59,36 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
auto m = std::make_shared<ngraph::pattern::Matcher>(gather, "EliminateUnsqueezeGather");
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateGatherUnsqueeze, "EliminateGatherUnsqueeze", 0);
ngraph::pass::EliminateGatherUnsqueeze::EliminateGatherUnsqueeze() {
MATCHER_SCOPE(EliminateGatherUnsqueeze);
const auto gather_indices_label = ngraph::pattern::wrap_type<ngraph::op::Constant>(pattern::rank_equals(0));
const auto gather_axis_label = ngraph::pattern::wrap_type<ngraph::op::Constant>();
const auto gather_label = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>(
{ngraph::pattern::any_input(), gather_indices_label, gather_axis_label}, pattern::rank_equals(0));
const auto unsqueeze_label = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>(
{gather_label, ngraph::pattern::any_input()}, pattern::rank_equals(1));
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto pattern_nodes = m.get_pattern_map();
auto& gather_indices = pattern_nodes.at(gather_indices_label);
auto& gather = pattern_nodes.at(gather_label);
auto& unsqueeze = pattern_nodes.at(unsqueeze_label);
auto new_indices = ngraph::op::util::make_try_fold<ngraph::opset6::Reshape>(gather_indices, opset6::Constant::create(element::i32, {1}, {1}), false);
auto new_gather = gather->clone_with_new_inputs({gather->input_value(0), new_indices, gather->input_value(2)});
new_gather->set_friendly_name(gather->get_friendly_name());
ngraph::copy_runtime_info({unsqueeze, gather}, {new_gather, new_indices});
ngraph::replace_node(unsqueeze, new_gather);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze_label, "EliminateGatherUnsqueeze");
register_matcher(m, callback);
}

@@ -0,0 +1,101 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <vector>
#include "itt.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
#include <transformations/utils/utils.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);
bool ngraph::pass::SharedShapeOf::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SharedShapeOf);
bool graph_rewritten = false;
std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::Node>>> source_to_shape_of;
for (const auto & node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
if (auto sub_graph = sub_graph_node->get_function())
graph_rewritten |= run_on_function(sub_graph);
if (is_type<ngraph::opset1::ShapeOf>(node) || is_type<ngraph::opset3::ShapeOf>(node))
source_to_shape_of[node->input_value(0)].push_back(node);
}
for (const auto& pair : source_to_shape_of) {
if (pair.second.size() < 2)
continue;
const auto& root_ss = pair.second[0];
for (const auto& child_ss : pair.second)
if (root_ss->get_instance_id() != child_ss->get_instance_id() && root_ss->get_output_element_type(0) == child_ss->get_output_element_type(0))
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
}
return graph_rewritten;
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::GroupedGatherElimination, "GroupedGatherElimination", 0);
ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
MATCHER_SCOPE(GroupedGatherElimination);
auto concat_label = ngraph::pattern::wrap_type<ngraph::opset1::Concat>(pattern::rank_equals(1));
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto concat = m.get_match_root();
OutputVector inputs = concat->input_values();
NodeVector new_ops;
size_t i = 0, original_inputs_size = inputs.size();
while (inputs.size() > i + 1) {
auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
if (curr->get_type_info() != next->get_type_info() ||
(!is_type<opset1::Gather>(curr) && !is_type<opset7::Gather>(curr)) ||
(curr->input_value(0) != next->input_value(0))) {
++i;
continue;
} // curr and next are Gathers of the same type that read data from the same source
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
auto new_gather = curr->clone_with_new_inputs(
{curr->input_value(0), joint_indices, ngraph::opset1::Constant::create(element::i64, {}, {0})});
new_ops.push_back(joint_indices);
new_ops.push_back(new_gather);
inputs.erase(inputs.begin() + i);
inputs[i] = new_gather->output(0);
}
if (original_inputs_size > inputs.size()) {
auto new_concat = std::make_shared<opset1::Concat>(inputs, 0);
new_ops.push_back(new_concat);
new_concat->set_friendly_name(concat->get_friendly_name());
ngraph::copy_runtime_info(concat, new_ops);
ngraph::replace_node(concat, new_concat);
return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_label, matcher_name);
this->register_matcher(m, callback);
}
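A worked trace of the loop above with hypothetical inputs (G = Gather over a shared source S, C = an unrelated input):

// inputs = { G(S, {1}), G(S, {2}), G(S, {3}), C }
// i = 0: G(S, {1}) and G(S, {2}) match    -> inputs = { G(S, {1, 2}), G(S, {3}), C }
// i = 0: G(S, {1, 2}) and G(S, {3}) match -> inputs = { G(S, {1, 2, 3}), C }
// i = 0: G(S, {1, 2, 3}) and C differ     -> ++i, and i + 1 == inputs.size() ends the loop
// original_inputs_size (4) > inputs.size() (2), so a new Concat over the merged
// inputs replaces the original one.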
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);
bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SimplifyShapeOfSubGraph);
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>();
manager.register_pass<ngraph::pass::SharedShapeOf>();
manager.register_pass<ngraph::pass::GroupedGatherElimination>();
manager.register_pass<ngraph::pass::Validate>();
manager.run_passes(f);
return false;
}

@@ -0,0 +1,81 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
std::shared_ptr<Node> indices_node;
if (scalar)
indices_node = opset7::Constant::create(element::i64, {}, indices);
else
indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
return std::make_shared<ngraph::opset7::Gather>(
input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
};
TEST(TransformationTests, ShapeSubGraphTest) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
Shape data_shape{1, 2, 3, 4};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
auto gather_1 = gather(shape_op_1, {1}, true);
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
auto gather_2 = gather(shape_op_2, {2}, true);
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifyShapeOfSubGraph>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({2, 3, 2, 2}));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
auto gather_1 = gather(shape_op_1, {1, 2});
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, const_1, const_2}, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}

@@ -74,7 +74,6 @@ protected:
const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset3::Parameter>(tensorType, targetShape);
const auto shapeOfNode = std::make_shared<ngraph::opset3::ShapeOf>(tensorWithTargetShapeParam, shapeType);
shapeOfNode->set_is_foldable(false);
ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};
@@ -197,7 +196,6 @@ protected:
const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset5::Parameter>(shapeType, targetShape);
const auto shapeOfNode = std::make_shared<ngraph::opset5::ShapeOf>(tensorWithTargetShapeParam, shapeType);
shapeOfNode->set_is_foldable(false);
ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};

@@ -33,14 +33,6 @@ namespace ngraph
// Overload collision with method on Node
using Node::set_output_type;
// FOR CONSTANT FOLDING INTERNAL USAGE ONLY
// Constant folding for cases with static rank but dynamic shape creates a sub-graph
// which contains a ShapeOf.
// In this case we need to prevent constant folding from endlessly creating these
// sub-graphs.
// These methods should be removed once a better solution is designed.
void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
bool get_is_foldable() const { return m_is_foldable; }
bool evaluate(const HostTensorVector& output_values,
const HostTensorVector& input_values) const override;
bool has_evaluate() const override;
@@ -50,7 +42,6 @@
const OutputVector& input_values) override;
private:
bool m_is_foldable = true;
element::Type m_output_type;
};
} // namespace v3
@@ -72,14 +63,6 @@ namespace ngraph
void validate_and_infer_types() override;
// FOR CONSTANT FOLDING INTERNAL USAGE ONLY
// Constant folding for cases with static rank but dynamic shape creates a sub-graph
// which contains a ShapeOf.
// In this case we need to prevent constant folding from endlessly creating these
// sub-graphs.
// These methods should be removed once a better solution is designed.
void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
bool get_is_foldable() const { return m_is_foldable; }
bool evaluate(const HostTensorVector& output_values,
const HostTensorVector& input_values) const override;
bool has_evaluate() const override;
@@ -87,9 +70,6 @@
bool evaluate_upper(const HostTensorVector& output_values) const override;
bool constant_fold(OutputVector& output_values,
const OutputVector& input_values) override;
private:
bool m_is_foldable = true;
};
} // namespace v0
using v0::ShapeOf;

@@ -51,7 +51,6 @@ shared_ptr<Node> op::v3::ShapeOf::clone_with_new_inputs(const OutputVector& new_
NGRAPH_OP_SCOPE(v3_ShapeOf_clone_with_new_inputs);
check_new_args_count(this, new_args);
auto new_shape_of = make_shared<op::v3::ShapeOf>(new_args.at(0), m_output_type);
new_shape_of->set_is_foldable(m_is_foldable);
return new_shape_of;
}
@@ -82,8 +81,7 @@ namespace shape_of
bool constant_fold_shape_of(Node* shape_of_node,
Output<Node>& replacement,
const Output<Node>& shape_of_input,
bool is_foldable)
const Output<Node>& shape_of_input)
{
auto partial_shape = shape_of_input.get_partial_shape();
auto output_type = shape_of_node->get_output_element_type(0);
@@ -100,46 +98,6 @@
}
return false;
}
else if (partial_shape.rank().is_static() && is_foldable)
{
auto shape_of = shape_of_node->copy_with_new_inputs({shape_of_input});
// Ugly
if (auto ps = as_type_ptr<op::v0::ShapeOf>(shape_of))
{
ps->set_is_foldable(false);
}
else if (auto ps = as_type_ptr<op::v3::ShapeOf>(shape_of))
{
ps->set_is_foldable(false);
}
auto dimensions = OutputVector{};
auto output_dimensions = vector<Dimension>(partial_shape);
for (size_t i = 0; i < output_dimensions.size(); ++i)
{
if (output_dimensions[i].is_static())
{
auto temp = std::make_shared<op::v0::Constant>(
output_type,
Shape{1},
std::vector<int64_t>{output_dimensions[i].get_length()});
temp->set_friendly_name("ConstDim/" + temp->get_name());
dimensions.emplace_back(temp);
}
else
{
auto index = std::make_shared<op::v0::Constant>(
output_type, Shape{1}, std::vector<int64_t>{static_cast<int64_t>(i)});
auto axis = std::make_shared<op::v0::Constant>(
element::i64, Shape{}, std::vector<int64_t>{0});
auto temp = make_shared<op::v1::Gather>(shape_of, index, axis);
temp->set_friendly_name("DynDim/" + temp->get_name());
dimensions.emplace_back(temp);
}
}
replacement = std::make_shared<op::Concat>(dimensions, 0);
return true;
}
return false;
}
@@ -250,7 +208,7 @@ bool op::v3::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v3::ShapeOf::constant_fold");
if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
return false;
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
}
// op::v0::ShapeOf
@@ -286,7 +244,6 @@ shared_ptr<Node> op::v0::ShapeOf::clone_with_new_inputs(const OutputVector& new_
description(),
" operation with name ",
get_friendly_name());
new_shape_of->set_is_foldable(m_is_foldable);
return new_shape_of;
}
@@ -318,7 +275,7 @@ bool op::v0::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v0::ShapeOf::constant_fold");
if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
return false;
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
}
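With m_is_foldable gone, folding is suppressed through runtime info instead. A hedged sketch of setting that key; the stored value is an assumption, since constant_fold above only checks the key's presence:

#include <memory>
#include <ngraph/node.hpp>

// Hypothetical helper: mark a node so its constant_fold bails out early.
void disable_constant_folding_on(const std::shared_ptr<ngraph::Node>& node) {
    // Only count("DISABLED_CONSTANT_FOLDING") is tested above, so any value works;
    // prefer the library's own helper for this key if one exists.
    node->get_rt_info()["DISABLED_CONSTANT_FOLDING"] = nullptr;
}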
bool op::v0::ShapeOf::evaluate_lower(const HostTensorVector& output_values) const

@@ -567,16 +567,11 @@ TEST(constant_folding, shape_of_dynamic_v0)
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
ASSERT_EQ(f->get_ops().size(), 3);
auto result_as_concat =
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(result_as_concat);
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
ASSERT_EQ(result_shape_of, shape_of);
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
}
TEST(constant_folding, shape_of_dynamic_v3)
@@ -592,17 +587,11 @@ TEST(constant_folding, shape_of_dynamic_v3)
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
ASSERT_EQ(f->get_ops().size(), 3);
auto result_as_concat =
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(result_as_concat);
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
ASSERT_EQ(result_shape_of, shape_of);
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
}
TEST(constant_folding, shape_of_dynamic_i32_v3)
@@ -618,17 +607,11 @@ TEST(constant_folding, shape_of_dynamic_i32_v3)
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
ASSERT_EQ(f->get_ops().size(), 3);
auto result_as_concat =
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(result_as_concat);
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
ASSERT_EQ(result_shape_of, shape_of);
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
}
// We need to be sure that constant folding won't be applied endlessly.
@@ -646,16 +629,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v0)
pass_manager.run_passes(f);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
ASSERT_EQ(f->get_ops().size(), 3);
auto result_as_concat =
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(result_as_concat);
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
ASSERT_EQ(result_shape_of, shape_of);
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
}
TEST(constant_folding, shape_of_dynamic_double_folding_v3)
@@ -672,16 +650,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v3)
pass_manager.run_passes(f);
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
ASSERT_EQ(f->get_ops().size(), 3);
auto result_as_concat =
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(result_as_concat);
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
ASSERT_EQ(result_shape_of, shape_of);
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
}
// Constant folding will not succeed on ShapeOf if the argument rank is dynamic.