diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/transpose_sinking.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/transpose_sinking.hpp
new file mode 100644
index 00000000000..497d9f24230
--- /dev/null
+++ b/inference-engine/src/transformations/include/transformations/common_optimizations/transpose_sinking.hpp
@@ -0,0 +1,69 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <vector>
+#include <memory>
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include "ngraph/pattern/matcher.hpp"
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API TransposeSinking;
+class TRANSFORMATIONS_API TransposeOptimization;
+class TRANSFORMATIONS_API TransposeReduction;
+class TRANSFORMATIONS_API TransposeFQReduction;
+
+}  // namespace pass
+}  // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief TransposeOptimization transformation replaces suitable Transposes with a Reshape operation or optimises them out entirely
+ */
+class ngraph::pass::TransposeOptimization : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    TransposeOptimization();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief TransposeReduction transformation sinks Transpose through Reduce operations
+ */
+class ngraph::pass::TransposeReduction : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    TransposeReduction();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction or squeeze
+ */
+class ngraph::pass::TransposeFQReduction : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    TransposeFQReduction();
+};
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief TransposeSinking transformation sinks Transposes through known operations
+ */
+class ngraph::pass::TransposeSinking: public ngraph::pass::GraphRewrite {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    TransposeSinking() {
+        add_matcher<ngraph::pass::TransposeFQReduction>();
+        add_matcher<ngraph::pass::TransposeReduction>();
+        add_matcher<ngraph::pass::TransposeOptimization>();
+    }
+};
\ No newline at end of file
diff --git a/inference-engine/src/transformations/include/transformations/utils/utils.hpp b/inference-engine/src/transformations/include/transformations/utils/utils.hpp
index 9a4016dfbb7..f932aca04a9 100644
--- a/inference-engine/src/transformations/include/transformations/utils/utils.hpp
+++ b/inference-engine/src/transformations/include/transformations/utils/utils.hpp
@@ -106,6 +106,15 @@ TRANSFORMATIONS_API std::shared_ptr<Node> activation(const std::string& activation_name, const Output<Node>& apply_to);
 
 TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len);
 
+TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);
+
+TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs);
+
+template <typename T, typename... Args>
+std::shared_ptr<Node> make_try_fold(Args&&... args) {
+    auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
+    return try_fold_unary_output(unary_output_node);
+}
+
 template <class T>
 Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
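make_try_fold builds a node and immediately tries to constant-fold it, handing back either the folded Constant or the node itself. A minimal usage sketch, not part of the patch; all names below are local to the example:

    #include <ngraph/opsets/opset6.hpp>
    #include <transformations/utils/utils.hpp>

    using namespace ngraph;

    void example() {
        auto data = opset6::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
        auto indices = opset6::Constant::create(element::i64, Shape{2}, {1, 2});
        auto axis = opset6::Constant::create(element::i64, Shape{}, {0});
        // Every input is constant, so this returns a folded Constant holding {2, 3};
        // with a non-constant input it would return the freshly built Gather node.
        auto gathered = op::util::make_try_fold<opset6::Gather>(data, indices, axis);
    }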
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp
index ce73883ac15..9e36bb8c8f0 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp
@@ -136,93 +136,6 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node) {
     return true;
 }
 
-static bool replace_transpose_with_reshape(shared_ptr<Node> transpose) {
-    auto data = transpose->input_value(0);
-    const auto input_shape = transpose->input(0).get_partial_shape();
-    if (input_shape.rank().is_dynamic()) {
-        return false;
-    }
-
-    const size_t input_shape_rank = input_shape.rank().get_length();
-
-    auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
-    if (!order || !ngraph::shape_size(order->get_shape())) {
-        return false;
-    }
-
-    const auto order_value = order->cast_vector<int64_t>();
-
-    // Check that transpose order without 1 dims has an ascending order
-    int64_t last_dim(-1);
-    for (size_t i = 0; i < input_shape_rank; ++i) {
-        if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
-            if (order_value[i] < last_dim) {
-                return false;
-            }
-            last_dim = order_value[i];
-        }
-    }
-
-    // Transpose operation can be removed if original transpose order is sorted
-    // or dimension that changes their places equal to 1
-    using DimensionToPosition = struct {
-        Dimension dim;
-        size_t pos;
-    };
-    std::vector<DimensionToPosition> dims;
-    for (size_t i = 0; i < input_shape_rank; ++i) {
-        if (order_value[i] != static_cast<int64_t>(i)) {
-            dims.push_back({input_shape[order_value[i]], i});
-        }
-    }
-
-    // If number of dimensions != 1 to move equal to 0 we can remove this Transpose
-    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
-            return !(item.dim.is_static() && item.dim.get_length() == 1);
-        }) == 0) {
-        return replace_output_update_name(transpose->output(0), transpose->input_value(0));
-    }
-
-    // Transpose can be replaced with Reshape in two ways:
-    // 1. Reshape with dims as Constant
-    // 2. Reshape with dims as input (ShapeOf->Gather)
-    //
-    // The first case is possible only if one or less dynamic dimensions changes their position
-    // For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
-    // with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} transpose
-    // cannot be replaced int the same way and in this case its only possible to use Gather(ShapeOf,
-    // order)
-
-    Output<Node> reshape_dim;
-    NodeVector new_ops;
-
-    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
-            return item.dim.is_dynamic();
-        }) < 2) {
-        vector<int64_t> reshape_value(input_shape_rank, 0);
-        for (const auto& item : dims) {
-            reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
-        }
-        reshape_dim =
-            opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
-    } else {
-        auto shape_of = make_shared<opset3::ShapeOf>(data);
-        new_ops.push_back(shape_of);
-        reshape_dim = make_shared<opset3::Gather>(
-            shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
-        new_ops.push_back(reshape_dim.get_node_shared_ptr());
-    }
-
-    auto reshape_op = make_shared<opset3::Reshape>(data, reshape_dim, true);
-    new_ops.push_back(reshape_op);
-
-    reshape_op->set_friendly_name(transpose->get_friendly_name());
-    copy_runtime_info(transpose, new_ops);
-    replace_node(transpose, reshape_op);
-
-    return true;
-}
-
 #define ECHO(NAME) #NAME
 #define STR(NAME) ECHO(NAME)
 #define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
@@ -244,11 +157,9 @@ NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
 SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
 SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf2Gather, opset2::ShapeOf, simplify_gather_shapeof);
 SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf3Gather, opset3::ShapeOf, simplify_gather_shapeof);
-SIMPLE_MATCHER_PASS_DEFINITION(ConvertTransposeToReshape, opset3::Transpose, replace_transpose_with_reshape);
 
 ngraph::pass::AlgebraicSimplification::AlgebraicSimplification() {
     add_matcher<EliminateGather>();
     add_matcher<SimplifyShapeOf2Gather>();
     add_matcher<SimplifyShapeOf3Gather>();
-    add_matcher<ConvertTransposeToReshape>();
 }
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index 9b08b9b93f6..b1d63ff6fe0 100644
--- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -38,6 +38,7 @@
 #include "transformations/common_optimizations/space_to_batch_fusion.hpp"
 #include "transformations/common_optimizations/batch_to_space_fusion.hpp"
 #include "transformations/common_optimizations/dilated_convolution_converter.hpp"
+#include "transformations/common_optimizations/transpose_sinking.hpp"
 #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
 #include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
 #include "transformations/op_conversions/convert_divide.hpp"
@@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
     manager.register_pass<ngraph::pass::ConstantFolding>();
     manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
     manager.register_pass<ngraph::pass::AlgebraicSimplification>();
+    manager.register_pass<ngraph::pass::TransposeSinking>();
 
     auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
     eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp
new file mode 100644
index 00000000000..a83a9a945d7
--- /dev/null
+++ b/inference-engine/src/transformations/src/transformations/common_optimizations/transpose_sinking.cpp
@@ -0,0 +1,273 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "itt.hpp"
+#include "transformations/common_optimizations/transpose_sinking.hpp"
+#include "transformations/utils/utils.hpp"
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+#include <ngraph/validation_util.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeSinking, "TransposeSinking", 0);
+NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeOptimization, "TransposeOptimization", 0);
+NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeReduction, "TransposeReduction", 0);
+NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeFQReduction, "TransposeFQReduction", 0);
+
+using namespace ngraph;
+
+std::shared_ptr<ngraph::opset6::Constant> get_reduced_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& axes_const,
+                                                                     const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
+    auto order = order_const->cast_vector<int64_t>();
+
+    auto axes = axes_const->cast_vector<int64_t>();
+    std::sort(axes.rbegin(), axes.rend());
+    for (const auto& i : axes)
+        order.erase(order.begin() + i);
+
+    const auto& updated_order_size = static_cast<int64_t>(order.size());
+
+    auto order_sorted = order;
+    sort(order_sorted.begin(), order_sorted.end());
+    for (int64_t i = 0; i < updated_order_size; ++i) {
+        auto lowest_greater_eq_i = std::lower_bound(order_sorted.begin(), order_sorted.end(), i);
+        std::replace(order.begin(), order.end(), *lowest_greater_eq_i, i);
+        std::replace(order_sorted.begin(), order_sorted.end(), *lowest_greater_eq_i, i);
+    }
+    return std::make_shared<ngraph::opset6::Constant>(
+            ngraph::element::i64, ngraph::Shape{order.size()}, order);
+}
+
+std::shared_ptr<ngraph::opset6::Constant> get_reversed_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
+    const auto& order = order_const->cast_vector<size_t>();
+    const auto& rank = order.size();
+    const auto& default_order = ngraph::get_default_order(rank);
+    std::vector<size_t> reverse_order(rank);
+    for (size_t i = 0; i < rank; ++i)
+        reverse_order[order[i]] = default_order[i];
+
+    return std::make_shared<ngraph::opset6::Constant>(
+            ngraph::element::i64, ngraph::Shape{reverse_order.size()}, reverse_order);
+}
+
+
+bool replace_transpose_with_reshape(const std::shared_ptr<Node>& transpose) {
+    auto data = transpose->input_value(0);
+    const auto input_shape = transpose->input(0).get_partial_shape();
+
+    const size_t input_shape_rank = input_shape.rank().get_length();
+
+    auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
+    if (!order || !ngraph::shape_size(order->get_shape())) {
+        return false;
+    }
+
+    const auto order_value = order->cast_vector<int64_t>();
+
+    // Check that the transpose order, ignoring size-1 dims, is ascending
+    int64_t last_dim(-1);
+    for (size_t i = 0; i < input_shape_rank; ++i) {
+        if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
+            if (order_value[i] < last_dim) {
+                return false;
+            }
+            last_dim = order_value[i];
+        }
+    }
+
+    // The Transpose operation can be removed if the original transpose order is sorted
+    // or if the dimensions that change their places are all equal to 1
+    using DimensionToPosition = struct {
+        Dimension dim;
+        size_t pos;
+    };
+    std::vector<DimensionToPosition> dims;
+    for (size_t i = 0; i < input_shape_rank; ++i) {
+        if (order_value[i] != static_cast<int64_t>(i)) {
+            dims.push_back({input_shape[order_value[i]], i});
+        }
+    }
+
+    // If the number of moved dimensions with size != 1 is zero, we can remove this Transpose
+    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
+            return !(item.dim.is_static() && item.dim.get_length() == 1);
+        }) == 0) {
+        return replace_output_update_name(transpose->output(0), transpose->input_value(0));
+    }
+
+    // Transpose can be replaced with Reshape in two ways:
+    // 1. Reshape with dims as Constant
+    // 2. Reshape with dims as input (ShapeOf->Gather)
+    //
+    // The first case is possible only if at most one dynamic dimension changes its position
+    // For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
+    // with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} the transpose
+    // cannot be replaced in the same way and in this case it is only possible to use Gather(ShapeOf,
+    // order)
+
+    Output<Node> reshape_dim;
+    NodeVector new_ops;
+
+    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
+            return item.dim.is_dynamic();
+        }) < 2) {
+        std::vector<int64_t> reshape_value(input_shape_rank, 0);
+        for (const auto& item : dims) {
+            reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
+        }
+        reshape_dim =
+                opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
+    } else {
+        auto shape_of = std::make_shared<opset3::ShapeOf>(data);
+        new_ops.push_back(shape_of);
+        reshape_dim = std::make_shared<opset3::Gather>(
+                shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
+        new_ops.push_back(reshape_dim.get_node_shared_ptr());
+    }
+
+    auto reshape_op = std::make_shared<opset3::Reshape>(data, reshape_dim, true);
+    new_ops.push_back(reshape_op);
+
+    reshape_op->set_friendly_name(transpose->get_friendly_name());
+    copy_runtime_info(transpose, new_ops);
+    replace_node(transpose, reshape_op);
+    return true;
+}
+
+ngraph::pass::TransposeOptimization::TransposeOptimization() {
+    MATCHER_SCOPE(TransposeOptimization);
+
+    auto transpose_label = pattern::wrap_type<opset6::Transpose>(
+            {pattern::any_input(pattern::has_static_rank()), pattern::wrap_type<opset6::Constant>()});
+    ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
+        return replace_transpose_with_reshape(m.get_match_root());
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_label, matcher_name);
+    register_matcher(m, matcher_pass_callback);
+}
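Annotation: when a Transpose only relocates size-1 dimensions, it degenerates into a Reshape, which is what the function above emits. Following the example from the comment in the code (input {?, 3, 1, ?}, order {0, 1, 3, 2}), the generated subgraph is equivalent to the sketch below, where `data` is assumed to be the transpose input (0 keeps a dimension, -1 infers one):

    auto pattern = opset3::Constant::create(element::i64, Shape{4}, {0, 3, -1, 1});
    auto reshape = std::make_shared<opset3::Reshape>(data, pattern, true);  // special_zero = true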
+
+ngraph::pass::TransposeReduction::TransposeReduction() {
+    MATCHER_SCOPE(TransposeReduction);
+
+    auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
+    auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
+                                                      op::util::LogicalReductionKeepDims,
+                                                      opset6::Squeeze>({transpose_label, pattern::wrap_type<opset6::Constant>()});
+
+    ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
+        const auto &pattern_to_output = m.get_pattern_value_map();
+
+        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
+        auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
+        auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
+        auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
+        auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
+        if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
+            return false;
+
+        bool keep_dims = false; // squeeze always reduces the number of output dimensions
+        if (logical_reduce)
+            keep_dims = logical_reduce->get_keep_dims();
+        else if (arithmetic_reduce)
+            keep_dims = arithmetic_reduce->get_keep_dims();
+
+        auto transpose_order = std::dynamic_pointer_cast<ngraph::opset6::Constant>(transpose->get_input_node_shared_ptr(1));
+        auto reduction_axes = std::dynamic_pointer_cast<ngraph::opset6::Constant>(reduction->get_input_node_shared_ptr(1));
+        if (!transpose_order || !reduction_axes)
+            return false;
+
+        const auto& non_negative_axes = ngraph::normalize_axes(
+                reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), reduction->get_input_partial_shape(0).rank());
+        reduction_axes = ngraph::opset6::Constant::create(ngraph::element::i64, {non_negative_axes.size()}, non_negative_axes);
+
+        ngraph::NodeVector new_ops;
+        auto new_axes = ngraph::op::util::make_try_fold<ngraph::opset6::Gather>(
+                transpose_order, reduction_axes, ngraph::opset6::Constant::create(ngraph::element::i64, {}, {0}));
+        new_ops.push_back(new_axes);
+        auto new_reduce = reduction->copy_with_new_inputs({transpose->input_value(0), new_axes});
+        new_ops.push_back(new_reduce);
+
+        auto updated_order = transpose_order;
+        if (!keep_dims) {
+            updated_order = get_reduced_order_constant(reduction_axes, transpose_order);
+            new_ops.push_back(updated_order);
+        }
+        auto new_transpose = register_new_node<ngraph::opset6::Transpose>(new_reduce, updated_order);
+        new_ops.push_back(new_transpose);
+        new_transpose->set_friendly_name(reduction->get_friendly_name());
+
+        ngraph::copy_runtime_info({reduction, transpose}, new_ops);
+        ngraph::replace_node(reduction, new_transpose);
+
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
+    register_matcher(m, matcher_pass_callback);
+}
+
+ngraph::pass::TransposeFQReduction::TransposeFQReduction() {
+    MATCHER_SCOPE(TransposeFQReduction);
+
+    auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
+    auto fq_label = pattern::wrap_type<opset6::FakeQuantize>(
+            {transpose_label, pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank()),
+             pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank())});
+    auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims,
+                                                      op::util::LogicalReductionKeepDims,
+                                                      opset6::Squeeze>({fq_label, pattern::wrap_type<opset6::Constant>()});
+
+    ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
+        auto &pattern_to_output = m.get_pattern_value_map();
+
+        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
+        auto transpose_order = std::dynamic_pointer_cast<ngraph::opset6::Constant>(transpose->get_input_node_shared_ptr(1));
+        auto fq = pattern_to_output.at(fq_label).get_node_shared_ptr();
+        if (!transpose || !transpose_order || !fq)
+            return false;
+
+        ngraph::NodeVector new_ops;
+
+        const auto& reverse_order_constant = get_reversed_order_constant(transpose_order);
+        new_ops.push_back(reverse_order_constant);
+
+        const auto& input_rank = fq->get_input_partial_shape(0).rank().get_length();
+        ngraph::OutputVector fq_inputs = {transpose->input_value(0)};
+        for (size_t i = 1; i < fq->inputs().size(); ++i) {
+            auto input = fq->input_value(i);
+            const auto& ranks_diff = input_rank - input.get_partial_shape().rank().get_length();
+            NGRAPH_CHECK(ranks_diff >= 0);
+            if (ranks_diff > 0) {
+                std::vector<int64_t> axes(ranks_diff);
+                std::iota(axes.begin(), axes.end(), 0);
+                const auto& axes_const = opset6::Constant::create(element::i64, Shape{axes.size()}, axes);
+                new_ops.push_back(axes_const);
+                const auto& unsqueezed_input = op::util::make_try_fold<opset6::Unsqueeze>(input, axes_const);
+                new_ops.push_back(unsqueezed_input);
+                input = unsqueezed_input->output(0);
+            }
+            const auto& transposed_input = op::util::make_try_fold<opset6::Transpose>(input, reverse_order_constant);
+            new_ops.push_back(transposed_input);
+            fq_inputs.push_back(transposed_input);
+        }
+        auto new_fq = fq->copy_with_new_inputs(fq_inputs);
+        new_ops.push_back(new_fq);
+
+        auto new_transpose = std::make_shared<ngraph::opset6::Transpose>(new_fq, transpose_order);
+        new_ops.push_back(new_transpose);
+        new_transpose->set_friendly_name(fq->get_friendly_name());
+
+        ngraph::copy_runtime_info({fq, transpose}, new_ops);
+        ngraph::replace_node(fq, new_transpose);
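Taken together, these matchers push a Transpose below a reduction. A before/after sketch of what TransposeReduction does to the first test case further down in this patch (built by hand here; all names are local to the example):

    using namespace ngraph;

    auto input = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 3, 240, 140});
    auto order = opset6::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
    auto transpose = std::make_shared<opset6::Transpose>(input, order);
    auto axes = opset6::Constant::create(element::i64, Shape{2}, {1, 2});
    auto reduce = std::make_shared<opset6::ReduceMean>(transpose, axes, true);
    auto f = std::make_shared<Function>(NodeVector{reduce}, ParameterVector{input});

    pass::Manager manager;
    manager.register_pass<pass::TransposeReduction>();
    manager.run_passes(f);
    // f now computes ReduceMean over axes {2, 3} first and applies
    // Transpose {0, 2, 3, 1} to the already-reduced tensor afterwards.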
+        // The root node (the reduction) is left unchanged by this matcher pass.
+        // We return false so that other MatcherPasses remain applicable to this node as a root node.
+        return false;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
+    register_matcher(m, matcher_pass_callback);
+}
diff --git a/inference-engine/src/transformations/src/transformations/utils/utils.cpp b/inference-engine/src/transformations/src/transformations/utils/utils.cpp
index c5df435a7e3..f8179fbb643 100644
--- a/inference-engine/src/transformations/src/transformations/utils/utils.cpp
+++ b/inference-engine/src/transformations/src/transformations/utils/utils.cpp
@@ -130,6 +130,18 @@ bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len) {
     return true;
 }
 
+std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node) {
+    const auto& num_outputs = node->get_output_size();
+    NGRAPH_CHECK(num_outputs == 1, "Unary has unexpected number of outputs: " + std::to_string(num_outputs));
+    OutputVector output(num_outputs);
+    return node->constant_fold(output, node->input_values()) ? output[0].get_node_shared_ptr() : node;
+}
+
+std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs) {
+    auto unary_output_node = node->clone_with_new_inputs(inputs);
+    return try_fold_unary_output(unary_output_node);
+}
+
 }  // namespace util
 }  // namespace op
 }  // namespace ngraph
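try_fold_unary_output folds a single-output node when all of its inputs are constant, returning either the resulting Constant or the original node; clone_try_fold does the same for a clone of an existing node over new inputs. A minimal sketch (names are local to the example):

    using namespace ngraph;

    auto data = opset6::Constant::create(element::f32, Shape{2, 3}, {1, 2, 3, 4, 5, 6});
    auto order = opset6::Constant::create(element::i64, Shape{2}, {1, 0});
    auto transpose = std::make_shared<opset6::Transpose>(data, order);
    // re-wire the node onto (constant) inputs and fold it in one step;
    // the result here is a 3x2 Constant rather than a Transpose node:
    auto folded = op::util::clone_try_fold(transpose, {data, order});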
diff --git a/inference-engine/tests/functional/inference_engine/transformations/algebraic_simplification.cpp b/inference-engine/tests/functional/inference_engine/transformations/algebraic_simplification.cpp
index ea0b588881b..4f94db681a5 100644
--- a/inference-engine/tests/functional/inference_engine/transformations/algebraic_simplification.cpp
+++ b/inference-engine/tests/functional/inference_engine/transformations/algebraic_simplification.cpp
@@ -19,6 +19,7 @@
 #include <ngraph/pass/manager.hpp>
 #include <transformations/common_optimizations/algebraic_simplification.hpp>
 #include <transformations/init_node_info.hpp>
+#include <transformations/common_optimizations/transpose_sinking.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -311,8 +312,7 @@ TEST(algebraic_simplification, replace_transpose_with_reshape) {
 
         pass::Manager pass_manager;
         pass_manager.register_pass<pass::InitNodeInfo>();
-        pass_manager.register_pass<pass::AlgebraicSimplification>();
-        pass_manager.register_pass<pass::ConstantFolding>();
+        pass_manager.register_pass<pass::TransposeSinking>();
         pass_manager.run_passes(optimized_f);
 
         auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
diff --git a/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp
new file mode 100644
index 00000000000..13cdd901493
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/transpose_sinking_test.cpp
@@ -0,0 +1,203 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset6.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/transpose_sinking.hpp>
+#include <transformations/init_node_info.hpp>
+#include <ngraph_functions/utils/ngraph_helpers.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+using namespace ngraph;
+
+struct TransposeFQReduceParams {
+    // given params
+    PartialShape transpose_input_shape;
+    std::vector<int64_t> transpose_order;
+    Shape il, ih, ol, oh;
+    std::vector<int64_t> reduce_axes;
+    bool reduce_keep_dims;
+
+    // expected params
+    Shape ex_il, ex_ih, ex_ol, ex_oh;
+    std::vector<int64_t> ex_reduce_axes;
+    std::vector<int64_t> ex_transpose_order;
+};
+
+class TransposeSinkingFQ : public CommonTestUtils::TestsCommon,
+                           public testing::WithParamInterface<std::tuple<TransposeFQReduceParams>> {
+public:
+    std::shared_ptr<ngraph::Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& test_case = std::get<0>(GetParam());
+
+        {
+            auto input = std::make_shared<ngraph::opset6::Parameter>(element::f32, test_case.transpose_input_shape);
+
+            auto order = std::make_shared<ngraph::opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
+            auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
+
+            auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.il, std::vector<int64_t>{0});
+            auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ih, std::vector<int64_t>{0});
+            auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ol, std::vector<int64_t>{0});
+            auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.oh, std::vector<int64_t>{0});
+            auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(transpose, i_low, i_high, o_low, o_high, 256);
+
+            auto axes = std::make_shared<ngraph::opset6::Constant>(
+                    element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
+            auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
+
+            f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
+        }
+
+        {
+            auto input = std::make_shared<ngraph::opset6::Parameter>(element::f32, test_case.transpose_input_shape);
+
+            auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_il, std::vector<int64_t>{0});
+            auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ih, std::vector<int64_t>{0});
+            auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ol, std::vector<int64_t>{0});
+            auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_oh, std::vector<int64_t>{0});
+            auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(input, i_low, i_high, o_low, o_high, 256);
+
+            auto axes = std::make_shared<ngraph::opset6::Constant>(
+                    element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
+            auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
+
+            auto order = std::make_shared<ngraph::opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
+            auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduce, order);
+
+            f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
+        }
+    }
+};
+
+TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::TransposeFQReduction>();
+    manager.register_pass<ngraph::pass::TransposeReduction>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    auto res = compare_functions(f, f_ref, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+
+INSTANTIATE_TEST_CASE_P(TransformationTest, TransposeSinkingFQ, testing::Values(
+        TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, true,
+                                {1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 2, 3, 1}},
+        TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, false,
+                                {1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 1}}));
+
+
+
+struct TransposeReduceParams {
+    // given params
+    PartialShape transpose_input_shape;
+    std::vector<int64_t> transpose_order;
+    std::vector<int64_t> reduce_axes;
+    bool reduction_keep_dims;
+
+    // expected params
+    std::vector<int64_t> ex_reduce_axes;
+    std::vector<int64_t> ex_transpose_order;
+};
+
+class TransposeSinking : public CommonTestUtils::TestsCommon,
+                         public testing::WithParamInterface<std::tuple<TransposeReduceParams, ngraph::NodeTypeInfo>> {
+public:
+    std::shared_ptr<ngraph::Function> f, f_ref;
+
+    void SetUp() override {
+        const auto& test_case = std::get<0>(GetParam());
+        const auto& reduction_type_info = std::get<1>(GetParam());
+
+        {
+            auto input = std::make_shared<ngraph::opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
+
+            auto order = std::make_shared<ngraph::opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
+            auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
+
+            auto axes = std::make_shared<ngraph::opset6::Constant>(
+                    element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
+
+            auto reduction = get_reduction(reduction_type_info, {transpose, axes}, test_case.reduction_keep_dims);
+
+            f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduction}, ngraph::ParameterVector{input});
+        }
+
+        {
+            auto input = std::make_shared<ngraph::opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
+
+            auto axes = std::make_shared<ngraph::opset6::Constant>(
+                    element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
+            auto reduction = get_reduction(reduction_type_info, {input, axes}, test_case.reduction_keep_dims);
+
+            auto order = std::make_shared<ngraph::opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
+            auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduction, order);
+
+            f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
+        }
+    }
+private:
+    std::shared_ptr<ngraph::Node> get_reduction(ngraph::NodeTypeInfo reduction_type_info, const OutputVector& inputs, bool keep_dims) {
+        auto reduction = ngraph::helpers::getNodeSharedPtr(reduction_type_info, inputs);
+        if (auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction))
+            arithmetic_reduce->set_keep_dims(keep_dims);
+        else if (auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction))
+            logical_reduce->set_keep_dims(keep_dims);
+        reduction->validate_and_infer_types();
+        return reduction;
+    }
+};
+
+TEST_P(TransposeSinking, TransposeReduction) {
+    ngraph::pass::Manager manager;
+    manager.register_pass<ngraph::pass::InitNodeInfo>();
+    manager.register_pass<ngraph::pass::TransposeReduction>();
+    manager.run_passes(f);
+    ASSERT_NO_THROW(check_rt_info(f));
+
+    auto res = compare_functions(f, f_ref, true);
+    ASSERT_TRUE(res.first) << res.second;
+}
+
+
+INSTANTIATE_TEST_CASE_P(TransposeSinkingReduces, TransposeSinking, testing::Combine(
+        testing::Values(
+            TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, true, {2, 3}, {0, 2, 3, 1}},
+            TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, true, {6, 5, 3}, {0, 6, 1, 5, 2, 4, 3}},
+            TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
+            TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
+            TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, -4, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
+            TransposeReduceParams{{1, 3, 240, 140}, {0, 1, 2, 3}, {0, 1, 2, -1}, false, {0, 1, 2, 3}, {}}),
+        testing::Values(
+            ngraph::opset6::ReduceMax::type_info,
+            ngraph::opset6::ReduceMean::type_info,
+            ngraph::opset6::ReduceMin::type_info,
+            ngraph::opset6::ReduceProd::type_info,
+            ngraph::opset6::ReduceSum::type_info,
+            ngraph::opset6::ReduceL1::type_info,
+            ngraph::opset6::ReduceL2::type_info,
+            ngraph::opset6::ReduceLogicalAnd::type_info,
+            ngraph::opset6::ReduceLogicalOr::type_info)));
+
+INSTANTIATE_TEST_CASE_P(TransposeSinkingSqueeze, TransposeSinking, testing::Combine(
+        testing::Values(
+            TransposeReduceParams{{2, 3, 1, 1}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
+            TransposeReduceParams{{10, 20, 30, 1, 50, 1, 1}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}}),
+        testing::Values(
+            ngraph::opset6::Squeeze::type_info)));
+
diff --git a/inference-engine/tests/functional/inference_engine/transformations/transpose_to_reshape_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/transpose_to_reshape_test.cpp
index 7b6d2384387..61ab9d3964c 100644
--- a/inference-engine/tests/functional/inference_engine/transformations/transpose_to_reshape_test.cpp
+++ b/inference-engine/tests/functional/inference_engine/transformations/transpose_to_reshape_test.cpp
@@ -19,6 +19,7 @@
 #include <transformations/init_node_info.hpp>
 #include <transformations/common_optimizations/algebraic_simplification.hpp>
 #include <ngraph/pass/manager.hpp>
+#include <transformations/common_optimizations/transpose_sinking.hpp>
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
@@ -97,7 +98,7 @@ private:
 
 TEST_P(TransposeToReshapeTests, CompareFunctions) {
     ngraph::pass::InitNodeInfo().run_on_function(f);
-    ngraph::pass::AlgebraicSimplification().run_on_function(f);
+    ngraph::pass::TransposeSinking().run_on_function(f);
     f->validate_nodes_and_infer_types();
     ASSERT_NO_THROW(check_rt_info(f));
     auto res = compare_functions(f, f_ref);
diff --git a/ngraph/core/include/ngraph/op/max.hpp b/ngraph/core/include/ngraph/op/max.hpp
index 5d83e74f31c..94032212f4c 100644
--- a/ngraph/core/include/ngraph/op/max.hpp
+++ b/ngraph/core/include/ngraph/op/max.hpp
@@ -16,8 +16,7 @@ namespace ngraph
             class NGRAPH_API ReduceMax : public util::ArithmeticReductionKeepDims
             {
            public:
-                static constexpr NodeTypeInfo type_info{"ReduceMax", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a summation operation.
                 ReduceMax() = default;
                 /// \brief Constructs a summation operation.
diff --git a/ngraph/core/include/ngraph/op/min.hpp b/ngraph/core/include/ngraph/op/min.hpp
index 78cd5edcbbb..d78d30725d1 100644
--- a/ngraph/core/include/ngraph/op/min.hpp
+++ b/ngraph/core/include/ngraph/op/min.hpp
@@ -16,8 +16,7 @@ namespace ngraph
             class NGRAPH_API ReduceMin : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceMin", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a summation operation.
                 ReduceMin() = default;
                 /// \brief Constructs a summation operation.
diff --git a/ngraph/core/include/ngraph/op/reduce_l1.hpp b/ngraph/core/include/ngraph/op/reduce_l1.hpp
index b09c9398be7..1329b5d9e22 100644
--- a/ngraph/core/include/ngraph/op/reduce_l1.hpp
+++ b/ngraph/core/include/ngraph/op/reduce_l1.hpp
@@ -19,8 +19,7 @@ namespace ngraph
             class NGRAPH_API ReduceL1 : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceL1", 4};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a reducet L1-norm operation.
                 ReduceL1() = default;
                 /// \brief Constructs a reduce L1-norm operation.
diff --git a/ngraph/core/include/ngraph/op/reduce_l2.hpp b/ngraph/core/include/ngraph/op/reduce_l2.hpp
index 3841d359141..1daa8697acd 100644
--- a/ngraph/core/include/ngraph/op/reduce_l2.hpp
+++ b/ngraph/core/include/ngraph/op/reduce_l2.hpp
@@ -18,8 +18,7 @@ namespace ngraph
             class NGRAPH_API ReduceL2 : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceL2", 4};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a reducet L2-norm operation.
                 ReduceL2() = default;
                 /// \brief Constructs a reduce L2-norm operation.
diff --git a/ngraph/core/include/ngraph/op/reduce_mean.hpp b/ngraph/core/include/ngraph/op/reduce_mean.hpp
index 6eca8555be3..9f0f3bdb262 100644
--- a/ngraph/core/include/ngraph/op/reduce_mean.hpp
+++ b/ngraph/core/include/ngraph/op/reduce_mean.hpp
@@ -16,8 +16,7 @@ namespace ngraph
             class NGRAPH_API ReduceMean : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceMean", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 ReduceMean() = default;
 
                 /// \param arg The tensor to be summed.
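Switching the reduce ops from a hand-written static type_info to NGRAPH_RTTI_DECLARATION, paired with the four-argument NGRAPH_RTTI_DEFINITION in the .cpp files below, records each op's base class in its RTTI chain. That is what lets the new passes match whole families of reductions through their common bases. A sketch mirroring the pattern used by TransposeReduction, assuming it sits inside a pass constructor:

    #include <ngraph/opsets/opset6.hpp>
    #include <ngraph/pattern/op/wrap_type.hpp>

    // matches any ReduceSum/ReduceMean/ReduceLogicalAnd/... as well as Squeeze,
    // because the concrete ops now report the reduction bases in their RTTI chain:
    auto reduction_label = ngraph::pattern::wrap_type<ngraph::op::util::ArithmeticReductionKeepDims,
                                                      ngraph::op::util::LogicalReductionKeepDims,
                                                      ngraph::opset6::Squeeze>();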
diff --git a/ngraph/core/include/ngraph/op/reduce_prod.hpp b/ngraph/core/include/ngraph/op/reduce_prod.hpp
index c54b87a64b9..b3904a76da9 100644
--- a/ngraph/core/include/ngraph/op/reduce_prod.hpp
+++ b/ngraph/core/include/ngraph/op/reduce_prod.hpp
@@ -18,8 +18,7 @@ namespace ngraph
             class NGRAPH_API ReduceProd : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceProd", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a product reduction operation.
                 ReduceProd() = default;
                 /// \brief Constructs a product reduction operation.
diff --git a/ngraph/core/include/ngraph/op/reduce_sum.hpp b/ngraph/core/include/ngraph/op/reduce_sum.hpp
index 8becb286f63..2de81ee71ff 100644
--- a/ngraph/core/include/ngraph/op/reduce_sum.hpp
+++ b/ngraph/core/include/ngraph/op/reduce_sum.hpp
@@ -65,8 +65,7 @@ namespace ngraph
             class NGRAPH_API ReduceSum : public util::ArithmeticReductionKeepDims
             {
             public:
-                static constexpr NodeTypeInfo type_info{"ReduceSum", 1};
-                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                NGRAPH_RTTI_DECLARATION;
                 /// \brief Constructs a summation operation.
                 ReduceSum() = default;
                 /// \brief Constructs a summation operation.
diff --git a/ngraph/core/include/ngraph/op/util/arithmetic_reduction.hpp b/ngraph/core/include/ngraph/op/util/arithmetic_reduction.hpp
index acc5e22b9f6..893c54664c0 100644
--- a/ngraph/core/include/ngraph/op/util/arithmetic_reduction.hpp
+++ b/ngraph/core/include/ngraph/op/util/arithmetic_reduction.hpp
@@ -21,11 +21,6 @@ namespace ngraph
                 /// \brief Constructs an arithmetic reduction operation.
                 ArithmeticReduction();
 
-                /// \brief Constructs an arithmetic reduction operation.
-                ///
-                /// \param arg Output that produces the first input tensor.
-                /// \param reduction_axes The axis positions (0-based) to be eliminated.
-                ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
                 /// \brief Constructs an arithmetic reduction operation.
                 ///
                 /// \param arg Output that produces the first input tensor.
@@ -33,6 +28,7 @@ namespace ngraph
                 ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
 
             public:
+                NGRAPH_RTTI_DECLARATION;
                 void validate_and_infer_types() override;
 
                 /// \return true if reduction axes are constant else false.
diff --git a/ngraph/core/include/ngraph/op/util/arithmetic_reductions_keep_dims.hpp b/ngraph/core/include/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
index 5398bb53394..f92d282ce42 100644
--- a/ngraph/core/include/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
+++ b/ngraph/core/include/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
@@ -28,6 +28,7 @@ namespace ngraph
                 bool visit_attributes(AttributeVisitor& visitor) override;
 
             public:
+                NGRAPH_RTTI_DECLARATION;
                 void validate_and_infer_types() override;
 
                 /// \return If set to 1 it holds axes that are used for reduction.
diff --git a/ngraph/core/include/ngraph/op/util/logical_reduction.hpp b/ngraph/core/include/ngraph/op/util/logical_reduction.hpp
index 9508887e4a9..e5d0d95ba38 100644
--- a/ngraph/core/include/ngraph/op/util/logical_reduction.hpp
+++ b/ngraph/core/include/ngraph/op/util/logical_reduction.hpp
@@ -32,6 +32,7 @@ namespace ngraph
                 LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
 
             public:
+                NGRAPH_RTTI_DECLARATION;
                 void validate_and_infer_types() override;
 
                 /// \return true if reduction axes are constant else false.
diff --git a/ngraph/core/include/ngraph/op/util/logical_reduction_keep_dims.hpp b/ngraph/core/include/ngraph/op/util/logical_reduction_keep_dims.hpp
index e7a5d8ca448..340f377f67f 100644
--- a/ngraph/core/include/ngraph/op/util/logical_reduction_keep_dims.hpp
+++ b/ngraph/core/include/ngraph/op/util/logical_reduction_keep_dims.hpp
@@ -28,6 +28,7 @@ namespace ngraph
                 bool visit_attributes(AttributeVisitor& visitor) override;
 
             public:
+                NGRAPH_RTTI_DECLARATION;
                 void validate_and_infer_types() override;
 
                 /// \return If set to 1 it holds axes that are used for reduction.
diff --git a/ngraph/core/include/ngraph/util.hpp b/ngraph/core/include/ngraph/util.hpp
index 3295eae9ba3..49605f00e16 100644
--- a/ngraph/core/include/ngraph/util.hpp
+++ b/ngraph/core/include/ngraph/util.hpp
@@ -215,9 +215,15 @@ namespace ngraph
     NGRAPH_API
     AxisVector get_default_order(size_t rank);
 
+    NGRAPH_API
+    AxisVector get_default_order(const Rank& rank);
+
     NGRAPH_API
     AxisVector get_default_order(const Shape& shape);
 
+    NGRAPH_API
+    AxisVector get_default_order(const PartialShape& shape);
+
     //
     // EnumMask is intended to work with a scoped enum type. It's used to store
     // a combination of enum values and provides easy access and manipulation
diff --git a/ngraph/core/src/op/max.cpp b/ngraph/core/src/op/max.cpp
index 063d55c804f..493810edbac 100644
--- a/ngraph/core/src/op/max.cpp
+++ b/ngraph/core/src/op/max.cpp
@@ -46,7 +46,7 @@ namespace maxop
     }
 }
 
-constexpr NodeTypeInfo op::v1::ReduceMax::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceMax, "ReduceMax", 1, util::ArithmeticReductionKeepDims);
 
 op::v1::ReduceMax::ReduceMax(const Output<Node>& arg,
                              const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/min.cpp b/ngraph/core/src/op/min.cpp
index 25c41d1766c..300bd8add3c 100644
--- a/ngraph/core/src/op/min.cpp
+++ b/ngraph/core/src/op/min.cpp
@@ -46,7 +46,7 @@ namespace minop
     }
 } // namespace minop
 
-constexpr NodeTypeInfo op::v1::ReduceMin::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceMin, "ReduceMin", 1, util::ArithmeticReductionKeepDims);
 
 op::v1::ReduceMin::ReduceMin(const Output<Node>& arg,
                              const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_l1.cpp b/ngraph/core/src/op/reduce_l1.cpp
index 39f3b0f48af..f4c02d6f133 100644
--- a/ngraph/core/src/op/reduce_l1.cpp
+++ b/ngraph/core/src/op/reduce_l1.cpp
@@ -12,7 +12,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v4::ReduceL1::type_info;
+NGRAPH_RTTI_DEFINITION(op::v4::ReduceL1, "ReduceL1", 4, util::ArithmeticReductionKeepDims);
 
 op::v4::ReduceL1::ReduceL1(const Output<Node>& arg,
                            const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_l2.cpp b/ngraph/core/src/op/reduce_l2.cpp
index 567581f4316..8c2498f0c3d 100644
--- a/ngraph/core/src/op/reduce_l2.cpp
+++ b/ngraph/core/src/op/reduce_l2.cpp
@@ -12,7 +12,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v4::ReduceL2::type_info;
+NGRAPH_RTTI_DEFINITION(op::v4::ReduceL2, "ReduceL2", 4, util::ArithmeticReductionKeepDims);
 
 op::v4::ReduceL2::ReduceL2(const Output<Node>& arg,
                            const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_logical_and.cpp b/ngraph/core/src/op/reduce_logical_and.cpp
index 90814b94603..a522131a7d3 100644
--- a/ngraph/core/src/op/reduce_logical_and.cpp
+++ b/ngraph/core/src/op/reduce_logical_and.cpp
@@ -12,7 +12,10 @@
 using namespace ngraph;
 using namespace std;
 
-NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd, "ReduceLogicalAnd", 1);
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd,
+                       "ReduceLogicalAnd",
+                       1,
+                       util::LogicalReductionKeepDims);
 
 op::v1::ReduceLogicalAnd::ReduceLogicalAnd(const Output<Node>& data,
                                            const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_logical_or.cpp b/ngraph/core/src/op/reduce_logical_or.cpp
index 4008863580f..cc09e5c42bf 100644
--- a/ngraph/core/src/op/reduce_logical_or.cpp
+++ b/ngraph/core/src/op/reduce_logical_or.cpp
@@ -12,7 +12,10 @@
 using namespace ngraph;
 using namespace std;
 
-NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr, "ReduceLogicalOr", 1);
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr,
+                       "ReduceLogicalOr",
+                       1,
+                       util::LogicalReductionKeepDims);
 
 op::v1::ReduceLogicalOr::ReduceLogicalOr(const Output<Node>& data,
                                          const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_mean.cpp b/ngraph/core/src/op/reduce_mean.cpp
index 9036766527e..28331a8e905 100644
--- a/ngraph/core/src/op/reduce_mean.cpp
+++ b/ngraph/core/src/op/reduce_mean.cpp
@@ -13,7 +13,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::ReduceMean::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceMean, "ReduceMean", 1, util::ArithmeticReductionKeepDims);
 
 op::v1::ReduceMean::ReduceMean(const Output<Node>& arg,
                                const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_prod.cpp b/ngraph/core/src/op/reduce_prod.cpp
index 7696d9a7299..d24ac763f65 100644
--- a/ngraph/core/src/op/reduce_prod.cpp
+++ b/ngraph/core/src/op/reduce_prod.cpp
@@ -13,7 +13,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::ReduceProd::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceProd, "ReduceProd", 1, util::ArithmeticReductionKeepDims);
 
 op::v1::ReduceProd::ReduceProd(const Output<Node>& arg,
                                const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/reduce_sum.cpp b/ngraph/core/src/op/reduce_sum.cpp
index cbb64dde3b2..935942fe7cf 100644
--- a/ngraph/core/src/op/reduce_sum.cpp
+++ b/ngraph/core/src/op/reduce_sum.cpp
@@ -13,7 +13,7 @@
 using namespace std;
 using namespace ngraph;
 
-constexpr NodeTypeInfo op::v1::ReduceSum::type_info;
+NGRAPH_RTTI_DEFINITION(op::v1::ReduceSum, "ReduceSum", 1, util::ArithmeticReductionKeepDims);
 
 op::v1::ReduceSum::ReduceSum(const Output<Node>& arg,
                              const Output<Node>& reduction_axes,
diff --git a/ngraph/core/src/op/util/arithmetic_reduction.cpp b/ngraph/core/src/op/util/arithmetic_reduction.cpp
index 565f78f970d..2861ef5f287 100644
--- a/ngraph/core/src/op/util/arithmetic_reduction.cpp
+++ b/ngraph/core/src/op/util/arithmetic_reduction.cpp
@@ -10,17 +10,9 @@
 using namespace std;
 using namespace ngraph;
 
-op::util::ArithmeticReduction::ArithmeticReduction() {}
+NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReduction, "ArithmeticReduction", 0);
 
-op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
-                                                   const AxisSet& reduction_axes)
-    : Op({arg,
-          op::Constant::create(
-              element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
-              ->output(0)})
-{
-    add_provenance_group_member(input_value(1).get_node_shared_ptr());
-}
+op::util::ArithmeticReduction::ArithmeticReduction() {}
 
 op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
                                                    const Output<Node>& reduction_axes)
diff --git a/ngraph/core/src/op/util/arithmetic_reductions_keep_dims.cpp b/ngraph/core/src/op/util/arithmetic_reductions_keep_dims.cpp
index 97cdc05bd81..67670c55f15 100644
--- a/ngraph/core/src/op/util/arithmetic_reductions_keep_dims.cpp
+++ b/ngraph/core/src/op/util/arithmetic_reductions_keep_dims.cpp
@@ -11,6 +11,8 @@
 using namespace std;
 using namespace ngraph;
 
+NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReductionKeepDims, "ArithmeticReductionKeepDims", 0);
+
 op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
     const ngraph::Output<ngraph::Node>& arg,
     const ngraph::Output<ngraph::Node>& reduction_axes,
diff --git a/ngraph/core/src/op/util/logical_reduction.cpp b/ngraph/core/src/op/util/logical_reduction.cpp
index 627692eea4b..698dbc32c50 100644
--- a/ngraph/core/src/op/util/logical_reduction.cpp
+++ b/ngraph/core/src/op/util/logical_reduction.cpp
@@ -10,6 +10,8 @@
 using namespace std;
 using namespace ngraph;
 
+NGRAPH_RTTI_DEFINITION(op::util::LogicalReduction, "LogicalReduction", 1);
+
 op::util::LogicalReduction::LogicalReduction() {}
 
 op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
diff --git a/ngraph/core/src/op/util/logical_reduction_keep_dims.cpp b/ngraph/core/src/op/util/logical_reduction_keep_dims.cpp
index f19d87e1872..9c4ae46c055 100644
--- a/ngraph/core/src/op/util/logical_reduction_keep_dims.cpp
+++ b/ngraph/core/src/op/util/logical_reduction_keep_dims.cpp
@@ -11,6 +11,8 @@
 using namespace std;
 using namespace ngraph;
 
+NGRAPH_RTTI_DEFINITION(op::util::LogicalReductionKeepDims, "LogicalReductionKeepDims", 1);
+
 op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
     const ngraph::Output<ngraph::Node>& arg,
     const ngraph::Output<ngraph::Node>& reduction_axes,
diff --git a/ngraph/core/src/util.cpp b/ngraph/core/src/util.cpp
index 93d815e1ed3..c6e87cb6cad 100644
--- a/ngraph/core/src/util.cpp
+++ b/ngraph/core/src/util.cpp
@@ -401,6 +401,11 @@ AxisVector ngraph::get_default_order(const Shape& shape)
     return get_default_order(shape.size());
 }
 
+AxisVector ngraph::get_default_order(const PartialShape& shape)
+{
+    return get_default_order(shape.rank());
+}
+
 AxisVector ngraph::get_default_order(size_t rank)
 {
     AxisVector default_order(rank);
@@ -408,6 +413,15 @@ AxisVector ngraph::get_default_order(size_t rank)
     return default_order;
 }
 
+AxisVector ngraph::get_default_order(const Rank& rank)
+{
+    NGRAPH_CHECK(rank.is_static(), "Can not calculate default order for dynamic rank");
+
+    AxisVector default_order(rank.get_length());
+    std::iota(begin(default_order), end(default_order), 0);
+    return default_order;
+}
+
 void ngraph::parse_version_string(
     std::string version, size_t& major, size_t& minor, size_t& patch, string& extra)