ShapeOf Sub-Graphs Simplification (#6308)
* ShapeOf Sub-Graphs Simplification * Removed graph visualization * Review comment * comments resolved
This commit is contained in:
parent
86b97e0a74
commit
e9011a9536
@ -14,6 +14,7 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API EliminateUnsqueezeGather;
|
||||
class TRANSFORMATIONS_API EliminateGatherUnsqueeze;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -29,3 +30,15 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
EliminateUnsqueezeGather();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Remove Gather -> Unsqueeze pair, if Gather takes a scalar and
|
||||
* Unsqueeze makes it a 1D tensor
|
||||
*/
|
||||
|
||||
class ngraph::pass::EliminateGatherUnsqueeze : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
EliminateGatherUnsqueeze();
|
||||
};
|
||||
|
@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <ngraph/util.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SimplifyShapeOfSubGraph;
|
||||
class TRANSFORMATIONS_API SharedShapeOf;
|
||||
class TRANSFORMATIONS_API GroupedGatherElimination;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SharedShapeOf transformation replaces group of ShapeOf
|
||||
* operations with the first ShapeOf in this group. All ShapeOfs in this group
|
||||
* must be equal and consume the same output port.
|
||||
*/
|
||||
class ngraph::pass::SharedShapeOf: public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief GroupedGatherElimination transformation replaces group of Gather
|
||||
* operations with the first Gather in this group and updated indices input
|
||||
* in case all Gathers in the group are consumed by the same Concat in incremental order.
|
||||
*/
|
||||
class ngraph::pass::GroupedGatherElimination: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GroupedGatherElimination();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SimplifyShapeOfSubGraph transformation runs specific optimizations of shape sub-graphs
|
||||
*/
|
||||
class ngraph::pass::SimplifyShapeOfSubGraph: public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
@ -76,6 +76,7 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
|
||||
|
||||
@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
|
||||
// This pass must be called first in pipeline
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateUnsqueezeGather, "EliminateUnsqueezeGather", 0);
|
||||
@ -58,3 +59,36 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gather, "EliminateUnsqueezeGather");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::EliminateGatherUnsqueeze, "EliminateGatherUnsqueeze", 0);
|
||||
|
||||
ngraph::pass::EliminateGatherUnsqueeze::EliminateGatherUnsqueeze() {
|
||||
MATCHER_SCOPE(EliminateGatherUnsqueeze);
|
||||
|
||||
const auto gather_indices_label = ngraph::pattern::wrap_type<ngraph::op::Constant>(pattern::rank_equals(0));
|
||||
const auto gather_axis_label = ngraph::pattern::wrap_type<ngraph::op::Constant>();
|
||||
const auto gather_label = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>(
|
||||
{ngraph::pattern::any_input(), gather_indices_label, gather_axis_label}, pattern::rank_equals(0));
|
||||
|
||||
const auto unsqueeze_label = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>(
|
||||
{gather_label, ngraph::pattern::any_input()}, pattern::rank_equals(1));
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto pattern_nodes = m.get_pattern_map();
|
||||
|
||||
auto& gather_indices = pattern_nodes.at(gather_indices_label);
|
||||
auto& gather = pattern_nodes.at(gather_label);
|
||||
auto& unsqueeze = pattern_nodes.at(unsqueeze_label);
|
||||
|
||||
auto new_indices = ngraph::op::util::make_try_fold<ngraph::opset6::Reshape>(gather_indices, opset6::Constant::create(element::i32, {1}, {1}), false);
|
||||
auto new_gather = gather->clone_with_new_inputs({gather->input_value(0), new_indices, gather->input_value(2)});
|
||||
|
||||
new_gather->set_friendly_name(gather->get_friendly_name());
|
||||
ngraph::copy_runtime_info({unsqueeze, gather}, {new_gather, new_indices});
|
||||
ngraph::replace_node(unsqueeze, new_gather);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze_label, "EliminateGatherUnsqueeze");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -0,0 +1,101 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
#include <transformations/common_optimizations/eliminate_unsqueeze_gather.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedShapeOf, "SharedShapeOf", 0);
|
||||
|
||||
bool ngraph::pass::SharedShapeOf::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
RUN_ON_FUNCTION_SCOPE(SharedShapeOf);
|
||||
bool graph_rewritten = false;
|
||||
|
||||
std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::Node>>> source_to_shape_of;
|
||||
for (const auto & node : f->get_ordered_ops()) {
|
||||
// Recursively apply transformation for sub-graph based operations
|
||||
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node))
|
||||
if (auto sub_graph = sub_graph_node->get_function())
|
||||
graph_rewritten |= run_on_function(sub_graph);
|
||||
|
||||
if (is_type<ngraph::opset1::ShapeOf>(node) || is_type<ngraph::opset3::ShapeOf>(node))
|
||||
source_to_shape_of[node->input_value(0)].push_back(node);
|
||||
}
|
||||
|
||||
for (const auto& pair : source_to_shape_of) {
|
||||
if (pair.second.size() < 2)
|
||||
continue;
|
||||
const auto& root_ss = pair.second[0];
|
||||
for (const auto& child_ss : pair.second)
|
||||
if (root_ss->get_instance_id() != child_ss->get_instance_id() && root_ss->get_output_element_type(0) == root_ss->get_output_element_type(0))
|
||||
graph_rewritten |= replace_output_update_name(child_ss->output(0), root_ss->output(0));
|
||||
}
|
||||
return graph_rewritten;
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::GroupedGatherElimination, "GroupedGatherElimination", 0);
|
||||
|
||||
ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
|
||||
MATCHER_SCOPE(GroupedGatherElimination);
|
||||
auto concat_label = ngraph::pattern::wrap_type<ngraph::opset1::Concat>(pattern::rank_equals(1));
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto concat = m.get_match_root();
|
||||
OutputVector inputs = concat->input_values();
|
||||
NodeVector new_ops;
|
||||
size_t i = 0, original_inputs_size = inputs.size();
|
||||
while (inputs.size() > i + 1) {
|
||||
auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
|
||||
if (curr->get_type_info() != next->get_type_info() ||
|
||||
(!is_type<opset1::Gather>(curr) && !is_type<opset7::Gather>(curr)) ||
|
||||
(curr->input_value(0) != next->input_value(0))) {
|
||||
++i;
|
||||
continue;
|
||||
} // curr and next are the same type of gather which takes data from the same source
|
||||
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
|
||||
auto new_gather = curr->clone_with_new_inputs(
|
||||
{curr->input_value(0), joint_indices, ngraph::opset1::Constant::create(element::i64, {}, {0})});
|
||||
new_ops.push_back(joint_indices);
|
||||
new_ops.push_back(new_gather);
|
||||
inputs.erase(inputs.begin() + i);
|
||||
inputs[i] = new_gather->output(0);
|
||||
}
|
||||
if (original_inputs_size > inputs.size()) {
|
||||
auto new_concat = std::make_shared<opset1::Concat>(inputs, 0);
|
||||
new_ops.push_back(new_concat);
|
||||
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||
ngraph::copy_runtime_info(concat, new_ops);
|
||||
ngraph::replace_node(concat, new_concat);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(concat_label, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);
|
||||
|
||||
bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
RUN_ON_FUNCTION_SCOPE(GroupedGatherElimination);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.set_per_pass_validation(false);
|
||||
manager.register_pass<ngraph::pass::EliminateGatherUnsqueeze>();
|
||||
manager.register_pass<ngraph::pass::SharedShapeOf>();
|
||||
manager.register_pass<ngraph::pass::GroupedGatherElimination>();
|
||||
manager.register_pass<ngraph::pass::Validate>();
|
||||
manager.run_passes(f);
|
||||
return false;
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||
std::shared_ptr<Node> indices_node;
|
||||
if (scalar)
|
||||
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
||||
else
|
||||
indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
|
||||
return std::make_shared<ngraph::opset7::Gather>(
|
||||
input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
|
||||
};
|
||||
|
||||
TEST(TransformationTests, ShapeSubGraphTest) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
Shape data_shape{1, 2, 3, 4};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gather(shape_op_1, {1}, true);
|
||||
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_2 = gather(shape_op_2, {2}, true);
|
||||
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifyShapeOfSubGraph>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({2, 3, 2, 2}));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gather(shape_op_1, {1, 2});
|
||||
|
||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, const_1, const_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -74,7 +74,6 @@ protected:
|
||||
const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset3::Parameter>(tensorType, targetShape);
|
||||
|
||||
const auto shapeOfNode = std::make_shared<ngraph::opset3::ShapeOf>(tensorWithTargetShapeParam, shapeType);
|
||||
shapeOfNode->set_is_foldable(false);
|
||||
|
||||
ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};
|
||||
|
||||
@ -197,7 +196,6 @@ protected:
|
||||
const auto tensorWithTargetShapeParam = std::make_shared<ngraph::opset5::Parameter>(shapeType, targetShape);
|
||||
|
||||
const auto shapeOfNode = std::make_shared<ngraph::opset5::ShapeOf>(tensorWithTargetShapeParam, shapeType);
|
||||
shapeOfNode->set_is_foldable(false);
|
||||
|
||||
ngraph::ParameterVector params{tensorParam, tensorWithTargetShapeParam};
|
||||
|
||||
|
@ -33,14 +33,6 @@ namespace ngraph
|
||||
// Overload collision with method on Node
|
||||
using Node::set_output_type;
|
||||
|
||||
// FOR CONSTANT FOLDING INTERNAL USAGE ONLY
|
||||
// Constant folding for cases with static rank but dynamic shape create a subgraph
|
||||
// which contains a Shape of.
|
||||
// In this case we need to prevent constant folding from endless creation of these
|
||||
// subgraphs.
|
||||
// These metods should be removed if better solution will be designed.
|
||||
void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
|
||||
bool get_is_foldable() const { return m_is_foldable; }
|
||||
bool evaluate(const HostTensorVector& output_values,
|
||||
const HostTensorVector& input_values) const override;
|
||||
bool has_evaluate() const override;
|
||||
@ -50,7 +42,6 @@ namespace ngraph
|
||||
const OutputVector& input_values) override;
|
||||
|
||||
private:
|
||||
bool m_is_foldable = true;
|
||||
element::Type m_output_type;
|
||||
};
|
||||
} // namespace v3
|
||||
@ -72,14 +63,6 @@ namespace ngraph
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
// FOR CONSTANT FOLDING INTERNAL USAGE ONLY
|
||||
// Constant folding for cases with static rank but dynamic shape create a subgraph
|
||||
// which contains a Shape of.
|
||||
// In this case we need to prevent constant folding from endless creation of these
|
||||
// subgraphs.
|
||||
// These metods should be removed if better solution will be designed.
|
||||
void set_is_foldable(bool is_foldable) { m_is_foldable = is_foldable; }
|
||||
bool get_is_foldable() const { return m_is_foldable; }
|
||||
bool evaluate(const HostTensorVector& output_values,
|
||||
const HostTensorVector& input_values) const override;
|
||||
bool has_evaluate() const override;
|
||||
@ -87,9 +70,6 @@ namespace ngraph
|
||||
bool evaluate_upper(const HostTensorVector& output_values) const override;
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& input_values) override;
|
||||
|
||||
private:
|
||||
bool m_is_foldable = true;
|
||||
};
|
||||
} // namespace v0
|
||||
using v0::ShapeOf;
|
||||
|
@ -51,7 +51,6 @@ shared_ptr<Node> op::v3::ShapeOf::clone_with_new_inputs(const OutputVector& new_
|
||||
NGRAPH_OP_SCOPE(v3_ShapeOf_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
auto new_shape_of = make_shared<op::v3::ShapeOf>(new_args.at(0), m_output_type);
|
||||
new_shape_of->set_is_foldable(m_is_foldable);
|
||||
return new_shape_of;
|
||||
}
|
||||
|
||||
@ -82,8 +81,7 @@ namespace shape_of
|
||||
|
||||
bool constant_fold_shape_of(Node* shape_of_node,
|
||||
Output<Node>& replacement,
|
||||
const Output<Node>& shape_of_input,
|
||||
bool is_foldable)
|
||||
const Output<Node>& shape_of_input)
|
||||
{
|
||||
auto partial_shape = shape_of_input.get_partial_shape();
|
||||
auto output_type = shape_of_node->get_output_element_type(0);
|
||||
@ -100,46 +98,6 @@ namespace shape_of
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else if (partial_shape.rank().is_static() && is_foldable)
|
||||
{
|
||||
auto shape_of = shape_of_node->copy_with_new_inputs({shape_of_input});
|
||||
// Ugly
|
||||
if (auto ps = as_type_ptr<op::v0::ShapeOf>(shape_of))
|
||||
{
|
||||
ps->set_is_foldable(false);
|
||||
}
|
||||
else if (auto ps = as_type_ptr<op::v3::ShapeOf>(shape_of))
|
||||
{
|
||||
ps->set_is_foldable(false);
|
||||
}
|
||||
auto dimensions = OutputVector{};
|
||||
auto output_dimensions = vector<Dimension>(partial_shape);
|
||||
for (size_t i = 0; i < output_dimensions.size(); ++i)
|
||||
{
|
||||
if (output_dimensions[i].is_static())
|
||||
{
|
||||
auto temp = std::make_shared<op::v0::Constant>(
|
||||
output_type,
|
||||
Shape{1},
|
||||
std::vector<int64_t>{output_dimensions[i].get_length()});
|
||||
temp->set_friendly_name("ConstDim/" + temp->get_name());
|
||||
dimensions.emplace_back(temp);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto index = std::make_shared<op::v0::Constant>(
|
||||
output_type, Shape{1}, std::vector<int64_t>{static_cast<int64_t>(i)});
|
||||
auto axis = std::make_shared<op::v0::Constant>(
|
||||
element::i64, Shape{}, std::vector<int64_t>{0});
|
||||
auto temp = make_shared<op::v1::Gather>(shape_of, index, axis);
|
||||
temp->set_friendly_name("DynDim/" + temp->get_name());
|
||||
dimensions.emplace_back(temp);
|
||||
}
|
||||
}
|
||||
|
||||
replacement = std::make_shared<op::Concat>(dimensions, 0);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -250,7 +208,7 @@ bool op::v3::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v3::ShapeOf::constant_fold");
|
||||
if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
|
||||
return false;
|
||||
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
|
||||
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
|
||||
}
|
||||
|
||||
// op::v0::ShapeOf
|
||||
@ -286,7 +244,6 @@ shared_ptr<Node> op::v0::ShapeOf::clone_with_new_inputs(const OutputVector& new_
|
||||
description(),
|
||||
" operation with name ",
|
||||
get_friendly_name());
|
||||
new_shape_of->set_is_foldable(m_is_foldable);
|
||||
return new_shape_of;
|
||||
}
|
||||
|
||||
@ -318,7 +275,7 @@ bool op::v0::ShapeOf::constant_fold(OutputVector& output_values, const OutputVec
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "op::v0::ShapeOf::constant_fold");
|
||||
if (get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
|
||||
return false;
|
||||
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0], m_is_foldable);
|
||||
return shape_of::constant_fold_shape_of(this, output_values[0], input_values[0]);
|
||||
}
|
||||
|
||||
bool op::v0::ShapeOf::evaluate_lower(const HostTensorVector& output_values) const
|
||||
|
@ -567,16 +567,11 @@ TEST(constant_folding, shape_of_dynamic_v0)
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
|
||||
ASSERT_EQ(f->get_ops().size(), 3);
|
||||
|
||||
auto result_as_concat =
|
||||
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_as_concat);
|
||||
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
|
||||
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
|
||||
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
|
||||
ASSERT_EQ(result_shape_of, shape_of);
|
||||
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
|
||||
}
|
||||
|
||||
TEST(constant_folding, shape_of_dynamic_v3)
|
||||
@ -592,17 +587,11 @@ TEST(constant_folding, shape_of_dynamic_v3)
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
|
||||
ASSERT_EQ(f->get_ops().size(), 3);
|
||||
|
||||
auto result_as_concat =
|
||||
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_as_concat);
|
||||
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
|
||||
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
|
||||
ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i64);
|
||||
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
|
||||
ASSERT_EQ(result_shape_of, shape_of);
|
||||
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
|
||||
}
|
||||
|
||||
TEST(constant_folding, shape_of_dynamic_i32_v3)
|
||||
@ -618,17 +607,11 @@ TEST(constant_folding, shape_of_dynamic_i32_v3)
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
|
||||
ASSERT_EQ(f->get_ops().size(), 3);
|
||||
|
||||
auto result_as_concat =
|
||||
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_as_concat);
|
||||
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
|
||||
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
|
||||
ASSERT_EQ(result_as_concat->get_output_element_type(0), element::i32);
|
||||
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
|
||||
ASSERT_EQ(result_shape_of, shape_of);
|
||||
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
|
||||
}
|
||||
|
||||
// We need to be sure that constant folding won't be calculated endlessly.
|
||||
@ -646,16 +629,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v0)
|
||||
pass_manager.run_passes(f);
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v0::ShapeOf>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
|
||||
ASSERT_EQ(f->get_ops().size(), 3);
|
||||
|
||||
auto result_as_concat =
|
||||
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_as_concat);
|
||||
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
|
||||
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
|
||||
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
|
||||
ASSERT_EQ(result_shape_of, shape_of);
|
||||
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
|
||||
}
|
||||
|
||||
TEST(constant_folding, shape_of_dynamic_double_folding_v3)
|
||||
@ -672,16 +650,11 @@ TEST(constant_folding, shape_of_dynamic_double_folding_v3)
|
||||
pass_manager.run_passes(f);
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v3::ShapeOf>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Gather>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Concat>(f), 1);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 8);
|
||||
ASSERT_EQ(f->get_ops().size(), 3);
|
||||
|
||||
auto result_as_concat =
|
||||
as_type_ptr<op::Concat>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(result_as_concat);
|
||||
ASSERT_EQ(result_as_concat->get_friendly_name(), "test");
|
||||
ASSERT_EQ(result_as_concat->get_output_shape(0), Shape{7});
|
||||
auto result_shape_of = f->get_results().at(0)->get_input_node_shared_ptr(0);
|
||||
ASSERT_EQ(result_shape_of, shape_of);
|
||||
ASSERT_EQ(result_shape_of->get_friendly_name(), "test");
|
||||
}
|
||||
|
||||
// Constant folding will not succeed on ShapeOf if the argument rank is dynamic.
|
||||
|
Loading…
Reference in New Issue
Block a user