[ Transpose sinking ] Transpose->FQ->Reduce (#5026)

* Utils: make_try_fold and clone_try_fold, which create a node (by template or by cloning) and attempt to constant-fold it

* RTTI for ArithmeticReduction and ArithmeticReductionKeepDims

* Enriched ngraph::get_default_order with overloads for PartialShape and Rank

* [ Transpose sinking ] Transpose->FQ->Reduce to FQ->Reduce->Transpose

* Style: deleted empty line

* RTTI in Reduction operations

* RTTI for LogicalReductionKeepDims

* Transpose: optimizations moved from algebraic simplification to TransposeSinking

* Renamed file

* Fix test

* keep_dims is always initialized

* Apply suggestions from code review

Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>

Evgenya Stepyreva 2021-04-05 13:29:21 +03:00 committed by GitHub
parent 71d56ee149
commit 4673cc2d25
35 changed files with 626 additions and 130 deletions

@@ -0,0 +1,69 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include "ngraph/pattern/matcher.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API TransposeSinking;
class TRANSFORMATIONS_API TransposeOptimization;
class TRANSFORMATIONS_API TransposeReduction;
class TRANSFORMATIONS_API TransposeFQReduction;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief TransposeOptimization transformation replaces suitable Transposes with a Reshape operation or optimizes them away
*/
class ngraph::pass::TransposeOptimization : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeOptimization();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
*/
class ngraph::pass::TransposeReduction : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeReduction();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize when it is followed by a reduction or Squeeze
*/
class ngraph::pass::TransposeFQReduction : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
TransposeFQReduction();
};
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinking transformation sinks Transposes through known operations
*/
class ngraph::pass::TransposeSinking : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
TransposeSinking() {
add_matcher<ngraph::pass::TransposeFQReduction>();
add_matcher<ngraph::pass::TransposeReduction>();
add_matcher<ngraph::pass::TransposeOptimization>();
}
};
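For orientation, a minimal sketch of running the new pass stand-alone (assumption: f is a std::shared_ptr<ngraph::Function>); the tests and CommonOptimizations further down register it the same way:

#include <ngraph/pass/manager.hpp>
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::TransposeSinking>();
manager.run_passes(f);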

@@ -106,6 +106,15 @@ TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string&
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len);
TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs);
template <typename T, typename... Args>
std::shared_ptr<Node> make_try_fold(Args&&... args) {
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
return try_fold_unary_output(unary_output_node);
}
template <class T>
Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {

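make_try_fold creates a node of type T and immediately attempts constant folding via try_fold_unary_output: when all inputs are constant it returns the folded Constant, otherwise the newly created node. A short usage sketch, taken from the transpose-sinking pass added below:

// Folds to a Constant when transpose_order and reduction_axes are both Constants;
// otherwise the Gather node itself is returned.
auto new_axes = ngraph::op::util::make_try_fold<ngraph::opset6::Gather>(
    transpose_order, reduction_axes, ngraph::opset6::Constant::create(ngraph::element::i64, {}, {0}));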
@@ -136,93 +136,6 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node) {
return true;
}
static bool replace_transpose_with_reshape(shared_ptr<Node> transpose) {
auto data = transpose->input_value(0);
const auto input_shape = transpose->input(0).get_partial_shape();
if (input_shape.rank().is_dynamic()) {
return false;
}
const size_t input_shape_rank = input_shape.rank().get_length();
auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
if (!order || !ngraph::shape_size(order->get_shape())) {
return false;
}
const auto order_value = order->cast_vector<int64_t>();
// Check that transpose order without 1 dims has an ascending order
int64_t last_dim(-1);
for (size_t i = 0; i < input_shape_rank; ++i) {
if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
if (order_value[i] < last_dim) {
return false;
}
last_dim = order_value[i];
}
}
// Transpose operation can be removed if original transpose order is sorted
// or dimension that changes their places equal to 1
using DimensionToPosition = struct {
Dimension dim;
size_t pos;
};
std::vector<DimensionToPosition> dims;
for (size_t i = 0; i < input_shape_rank; ++i) {
if (order_value[i] != static_cast<int64_t>(i)) {
dims.push_back({input_shape[order_value[i]], i});
}
}
// If number of dimensions != 1 to move equal to 0 we can remove this Transpose
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return !(item.dim.is_static() && item.dim.get_length() == 1);
}) == 0) {
return replace_output_update_name(transpose->output(0), transpose->input_value(0));
}
// Transpose can be replaced with Reshape in two ways:
// 1. Reshape with dims as Constant
// 2. Reshape with dims as input (ShapeOf->Gather)
//
// The first case is possible only if one or less dynamic dimensions changes their position
// For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
// with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} transpose
// cannot be replaced int the same way and in this case its only possible to use Gather(ShapeOf,
// order)
Output<Node> reshape_dim;
NodeVector new_ops;
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return item.dim.is_dynamic();
}) < 2) {
vector<int64_t> reshape_value(input_shape_rank, 0);
for (const auto& item : dims) {
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
}
reshape_dim =
opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
} else {
auto shape_of = make_shared<opset3::ShapeOf>(data);
new_ops.push_back(shape_of);
reshape_dim = make_shared<opset3::Gather>(
shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(reshape_dim.get_node_shared_ptr());
}
auto reshape_op = make_shared<opset3::Reshape>(data, reshape_dim, true);
new_ops.push_back(reshape_op);
reshape_op->set_friendly_name(transpose->get_friendly_name());
copy_runtime_info(transpose, new_ops);
replace_node(transpose, reshape_op);
return true;
}
#define ECHO(NAME) #NAME
#define STR(NAME) ECHO(NAME)
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
@@ -244,11 +157,9 @@ NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf2Gather, opset2::ShapeOf, simplify_gather_shapeof);
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf3Gather, opset3::ShapeOf, simplify_gather_shapeof);
SIMPLE_MATCHER_PASS_DEFINITION(ConvertTransposeToReshape, opset3::Transpose, replace_transpose_with_reshape);
ngraph::pass::AlgebraicSimplification::AlgebraicSimplification() {
add_matcher<EliminateGather>();
add_matcher<SimplifyShapeOf2Gather>();
add_matcher<SimplifyShapeOf3Gather>();
add_matcher<ConvertTransposeToReshape>();
}

@@ -38,6 +38,7 @@
#include "transformations/common_optimizations/space_to_batch_fusion.hpp"
#include "transformations/common_optimizations/batch_to_space_fusion.hpp"
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
manager.register_pass<ngraph::pass::TransposeSinking>();
auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();

@@ -0,0 +1,273 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <numeric>
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeSinking, "TransposeSinking", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeOptimization, "TransposeOptimization", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeReduction, "TransposeReduction", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeFQReduction, "TransposeFQReduction", 0);
using namespace ngraph;
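// get_reduced_order_constant drops the reduced axes from a transpose order and
// re-enumerates the remaining values into a dense permutation.
// Worked example (matches the tests below): order {0, 6, 1, 5, 2, 4, 3} with axes {1, 3, 6}:
// erasing positions 6, 3, 1 leaves {0, 1, 2, 4}, which is re-enumerated to {0, 1, 2, 3}.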
std::shared_ptr<ngraph::opset6::Constant> get_reduced_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& axes_const,
const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
auto order = order_const->cast_vector<int64_t>();
auto axes = axes_const->cast_vector<int64_t>();
std::sort(axes.rbegin(), axes.rend());
for (const auto& i : axes)
order.erase(order.begin() + i);
const auto& updated_order_size = static_cast<int64_t>(order.size());
auto order_sorted = order;
sort(order_sorted.begin(), order_sorted.end());
for (int64_t i = 0; i < updated_order_size; ++i) {
auto lowest_greater_eq_i = std::lower_bound(order_sorted.begin(), order_sorted.end(), i);
std::replace(order.begin(), order.end(), *lowest_greater_eq_i, i);
std::replace(order_sorted.begin(), order_sorted.end(), *lowest_greater_eq_i, i);
}
return std::make_shared<ngraph::opset6::Constant>(
ngraph::element::i64, ngraph::Shape{order.size()}, order);
}
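// get_reversed_order_constant builds the inverse permutation: reverse_order[order[i]] = i.
// E.g. order {0, 2, 3, 1} gives {0, 3, 1, 2}, so applying the reversed order undoes the
// original Transpose; it is used below to pre-transpose the FakeQuantize range inputs.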
std::shared_ptr<ngraph::opset6::Constant> get_reversed_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
const auto& order = order_const->cast_vector<size_t>();
const auto& rank = order.size();
const auto& default_order = ngraph::get_default_order(rank);
std::vector<size_t> reverse_order(rank);
for (size_t i = 0; i < rank; ++i)
reverse_order[order[i]] = default_order[i];
return std::make_shared<ngraph::opset6::Constant>(
ngraph::element::i64, ngraph::Shape{reverse_order.size()}, reverse_order);
}
bool replace_transpose_with_reshape(const std::shared_ptr<Node>& transpose) {
auto data = transpose->input_value(0);
const auto input_shape = transpose->input(0).get_partial_shape();
const size_t input_shape_rank = input_shape.rank().get_length();
auto order = as_type_ptr<opset6::Constant>(transpose->input_value(1).get_node_shared_ptr());
if (!order || !ngraph::shape_size(order->get_shape())) {
return false;
}
const auto order_value = order->cast_vector<int64_t>();
// Check that the transpose order, ignoring dimensions equal to 1, is in ascending order
int64_t last_dim(-1);
for (size_t i = 0; i < input_shape_rank; ++i) {
if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
if (order_value[i] < last_dim) {
return false;
}
last_dim = order_value[i];
}
}
// The Transpose can be removed if the original order is sorted,
// or if every dimension that changes its place is equal to 1
using DimensionToPosition = struct {
Dimension dim;
size_t pos;
};
std::vector<DimensionToPosition> dims;
for (size_t i = 0; i < input_shape_rank; ++i) {
if (order_value[i] != static_cast<int64_t>(i)) {
dims.push_back({input_shape[order_value[i]], i});
}
}
// If no dimension other than 1 changes its position, we can remove this Transpose
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return !(item.dim.is_static() && item.dim.get_length() == 1);
}) == 0) {
return replace_output_update_name(transpose->output(0), transpose->input_value(0));
}
// Transpose can be replaced with Reshape in two ways:
// 1. Reshape with dims as Constant
// 2. Reshape with dims as input (ShapeOf->Gather)
//
// The first case is possible only if at most one dynamic dimension changes its position.
// For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with a Reshape
// whose target shape is the Constant {0, 3, -1, 1}; but for input_shape {?, 1, 1, ?} and order
// {1, 0, 3, 2} the Transpose cannot be replaced in the same way, and in this case it is only
// possible to use Gather(ShapeOf, order)
Output<Node> reshape_dim;
NodeVector new_ops;
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return item.dim.is_dynamic();
}) < 2) {
std::vector<int64_t> reshape_value(input_shape_rank, 0);
for (const auto& item : dims) {
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
}
reshape_dim =
opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
} else {
auto shape_of = std::make_shared<opset3::ShapeOf>(data);
new_ops.push_back(shape_of);
reshape_dim = std::make_shared<opset3::Gather>(
shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(reshape_dim.get_node_shared_ptr());
}
auto reshape_op = std::make_shared<opset3::Reshape>(data, reshape_dim, true);
new_ops.push_back(reshape_op);
reshape_op->set_friendly_name(transpose->get_friendly_name());
copy_runtime_info(transpose, new_ops);
replace_node(transpose, reshape_op);
return true;
}
ngraph::pass::TransposeOptimization::TransposeOptimization() {
MATCHER_SCOPE(TransposeOptimization);
auto transpose_label = pattern::wrap_type<opset6::Transpose>(
{pattern::any_input(pattern::has_static_rank()), pattern::wrap_type<opset6::Constant>()});
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
return replace_transpose_with_reshape(m.get_match_root());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ngraph::pass::TransposeReduction::TransposeReduction() {
MATCHER_SCOPE(TransposeReduction);
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
{transpose_label, pattern::wrap_type<opset6::Constant>()});
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
const auto &pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
return false;
bool keep_dims = false; // squeeze always reduces number of output dimensions
if (logical_reduce)
keep_dims = logical_reduce->get_keep_dims();
else if (arithmetic_reduce)
keep_dims = arithmetic_reduce->get_keep_dims();
auto transpose_order = std::dynamic_pointer_cast<ngraph::opset6::Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = std::dynamic_pointer_cast<ngraph::opset6::Constant>(reduction->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
const auto& non_negative_axes = ngraph::normalize_axes(
reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), reduction->get_input_partial_shape(0).rank());
reduction_axes = ngraph::opset6::Constant::create(ngraph::element::i64, {non_negative_axes.size()}, non_negative_axes);
ngraph::NodeVector new_ops;
auto new_axes = ngraph::op::util::make_try_fold<ngraph::opset6::Gather>(
transpose_order, reduction_axes, ngraph::opset6::Constant::create(ngraph::element::i64, {}, {0}));
new_ops.push_back(new_axes);
auto new_reduce = reduction->copy_with_new_inputs({transpose->input_value(0), new_axes});
new_ops.push_back(new_reduce);
auto updated_order = transpose_order;
if (!keep_dims) {
updated_order = get_reduced_order_constant(reduction_axes, transpose_order);
new_ops.push_back(updated_order);
}
auto new_transpose = register_new_node<opset6::Transpose>(new_reduce, updated_order);
new_ops.push_back(new_transpose);
new_transpose->set_friendly_name(reduction->get_friendly_name());
ngraph::copy_runtime_info({reduction, transpose}, new_ops);
ngraph::replace_node(reduction, new_transpose);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
ngraph::pass::TransposeFQReduction::TransposeFQReduction() {
MATCHER_SCOPE(TransposeFQReduction);
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
auto fq_label = pattern::wrap_type<opset6::FakeQuantize>(
{transpose_label, pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank())});
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
{fq_label, pattern::wrap_type<opset6::Constant>()});
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
auto &pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
auto fq = pattern_to_output.at(fq_label).get_node_shared_ptr();
if (!transpose || !transpose_order || !fq)
return false;
ngraph::NodeVector new_ops;
const auto& reverse_order_constant = get_reversed_order_constant(transpose_order);
new_ops.push_back(reverse_order_constant);
const auto& input_rank = fq->get_input_partial_shape(0).rank().get_length();
ngraph::OutputVector fq_inputs = {transpose->input_value(0)};
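// FakeQuantize range inputs may have a smaller rank than the data input (broadcasting);
// prepend size-1 dimensions up to the data rank so the reversed order can be applied to them.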
for (size_t i = 1; i < fq->inputs().size(); ++i) {
auto input = fq->input_value(i);
const auto& ranks_diff = input_rank - input.get_partial_shape().rank().get_length();
NGRAPH_CHECK(ranks_diff >= 0);
if (ranks_diff > 0) {
std::vector<int64_t> axes(ranks_diff);
std::iota(axes.begin(), axes.end(), 0);
const auto& axes_const = opset6::Constant::create(element::i64, Shape{axes.size()}, axes);
new_ops.push_back(axes_const);
const auto& unsqueezed_input = op::util::make_try_fold<opset6::Unsqueeze>(input, axes_const);
new_ops.push_back(unsqueezed_input);
input = unsqueezed_input->output(0);
}
const auto& transposed_input = op::util::make_try_fold<opset6::Transpose>(input, reverse_order_constant);
new_ops.push_back(transposed_input);
fq_inputs.push_back(transposed_input);
}
auto new_fq = fq->copy_with_new_inputs(fq_inputs);
new_ops.push_back(new_fq);
auto new_transpose = std::make_shared<ngraph::opset6::Transpose>(new_fq, transpose_order);
new_ops.push_back(new_transpose);
new_transpose->set_friendly_name(fq->get_friendly_name());
ngraph::copy_runtime_info({fq, transpose}, new_ops);
ngraph::replace_node(fq, new_transpose);
// The root node (the reduction) is left unchanged by this matcher pass.
// We return false so that further MatcherPasses remain applicable to this node as a root node
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
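On the first test case below this works out as follows: for input {1, 3, 240, 140}, transpose order {0, 2, 3, 1} and reduction axes {1, 2} with keep_dims=false, TransposeFQReduction rebuilds FakeQuantize on the untransposed data (transposing its range inputs by the reversed order {0, 3, 1, 2}), then TransposeReduction maps the axes through the order to {2, 3} and appends a final Transpose with the reduced order {0, 1}.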

@@ -130,6 +130,18 @@ bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max
return true;
}
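// Returns the folded Constant when constant_fold succeeds; otherwise returns the node unchanged.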
std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node) {
const auto& num_outputs = node->get_output_size();
NGRAPH_CHECK(num_outputs == 1, "Unary has unexpected number of outputs:" + std::to_string(num_outputs));
OutputVector output(num_outputs);
return node->constant_fold(output, node->input_values()) ? output[0].get_node_shared_ptr() : node;
}
std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs) {
auto unary_output_node = node->clone_with_new_inputs(inputs);
return try_fold_unary_output(unary_output_node);
}
} // namespace util
} // namespace op
} // namespace ngraph

@@ -19,6 +19,7 @@
#include <transformations/common_optimizations/algebraic_simplification.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@@ -311,8 +312,7 @@ TEST(algebraic_simplification, replace_transpose_with_reshape) {
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.register_pass<pass::TransposeSinking>();
pass_manager.run_passes(optimized_f);
auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);

@@ -0,0 +1,203 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
struct TransposeFQReduceParams {
// given params
PartialShape transpose_input_shape;
std::vector<int32_t> transpose_order;
Shape il, ih, ol, oh;
std::vector<int32_t> reduce_axes;
bool reduce_keep_dims;
// expected params
Shape ex_il, ex_ih, ex_ol, ex_oh;
std::vector<int32_t> ex_reduce_axes;
std::vector<int32_t> ex_transpose_order;
};
class TransposeSinkingFQ : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<TransposeFQReduceParams>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& test_case = std::get<0>(GetParam());
{
auto input = std::make_shared<opset6::Parameter>(element::f32, test_case.transpose_input_shape);
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.il, std::vector<int32_t>{0});
auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ih, std::vector<int32_t>{0});
auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ol, std::vector<int32_t>{0});
auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.oh, std::vector<int32_t>{0});
auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(transpose, i_low, i_high, o_low, o_high, 256);
auto axes = std::make_shared<ngraph::opset6::Constant>(
element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
}
{
auto input = std::make_shared<opset6::Parameter>(element::f32, test_case.transpose_input_shape);
auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_il, std::vector<int32_t>{0});
auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ih, std::vector<int32_t>{0});
auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ol, std::vector<int32_t>{0});
auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_oh, std::vector<int32_t>{0});
auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(input, i_low, i_high, o_low, o_high, 256);
auto axes = std::make_shared<ngraph::opset6::Constant>(
element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduce, order);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
}
}
};
TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::TransposeFQReduction>();
manager.register_pass<ngraph::pass::TransposeReduction>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
INSTANTIATE_TEST_CASE_P(TransformationTest, TransposeSinkingFQ, testing::Values(
TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, true,
{1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 2, 3, 1}},
TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, false,
{1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 1}}));
struct TransposeReduceParams {
// given params
PartialShape transpose_input_shape;
std::vector<int32_t> transpose_order;
std::vector<int32_t> reduce_axes;
bool reduction_keep_dims;
// expected params
std::vector<int32_t> ex_reduce_axes;
std::vector<int32_t> ex_transpose_order;
};
class TransposeSinking : public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<TransposeReduceParams, ngraph::NodeTypeInfo>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& test_case = std::get<0>(GetParam());
const auto& reduction_type_info = std::get<1>(GetParam());
{
auto input = std::make_shared<opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
auto axes = std::make_shared<ngraph::opset6::Constant>(
element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
auto reduction = get_reduction(reduction_type_info, {transpose, axes}, test_case.reduction_keep_dims);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduction}, ngraph::ParameterVector{input});
}
{
auto input = std::make_shared<opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
auto axes = std::make_shared<ngraph::opset6::Constant>(
element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
auto reduction = get_reduction(reduction_type_info, {input, axes}, test_case.reduction_keep_dims);
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduction, order);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
}
}
private:
std::shared_ptr<Node> get_reduction(ngraph::NodeTypeInfo reduction_type_info, const OutputVector& inputs, bool keep_dims) {
auto reduction = ngraph::helpers::getNodeSharedPtr(reduction_type_info, inputs);
if (auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction))
arithmetic_reduce->set_keep_dims(keep_dims);
else if (auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction))
logical_reduce->set_keep_dims(keep_dims);
reduction->validate_and_infer_types();
return reduction;
}
};
TEST_P(TransposeSinking, TransposeReduction) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::TransposeReduction>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
INSTANTIATE_TEST_CASE_P(TransposeSinkingReduces, TransposeSinking, testing::Combine(
testing::Values(
TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, true, {2, 3}, {0, 2, 3, 1}},
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, true, {6, 5, 3}, {0, 6, 1, 5, 2, 4, 3}},
TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, -4, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
TransposeReduceParams{{1, 3, 240, 140}, {0, 1, 2, 3}, {0, 1, 2, -1}, false, {0, 1, 2, 3}, {}}),
testing::Values(
ngraph::opset6::ReduceMax::type_info,
ngraph::opset6::ReduceMean::type_info,
ngraph::opset6::ReduceMin::type_info,
ngraph::opset6::ReduceProd::type_info,
ngraph::opset6::ReduceSum::type_info,
ngraph::opset6::ReduceL1::type_info,
ngraph::opset6::ReduceL2::type_info,
ngraph::opset6::ReduceLogicalAnd::type_info,
ngraph::opset6::ReduceLogicalOr::type_info)));
INSTANTIATE_TEST_CASE_P(TransposeSinkingSqueeze, TransposeSinking, testing::Combine(
testing::Values(
TransposeReduceParams{{2, 3, 1, 1}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
TransposeReduceParams{{10, 20, 30, 1, 50, 1, 1}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}}),
testing::Values(
ngraph::opset6::Squeeze::type_info)));

@@ -19,6 +19,7 @@
#include <transformations/init_node_info.hpp>
#include <transformations/common_optimizations/algebraic_simplification.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@@ -97,7 +98,7 @@ private:
TEST_P(TransposeToReshapeTests, CompareFunctions) {
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::AlgebraicSimplification().run_on_function(f);
ngraph::pass::TransposeSinking().run_on_function(f);
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, f_ref);

@@ -16,8 +16,7 @@ namespace ngraph
class NGRAPH_API ReduceMax : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceMax", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a maximum-reduction operation.
ReduceMax() = default;
/// \brief Constructs a maximum-reduction operation.

@@ -16,8 +16,7 @@ namespace ngraph
class NGRAPH_API ReduceMin : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceMin", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a minimum-reduction operation.
ReduceMin() = default;
/// \brief Constructs a minimum-reduction operation.

@@ -19,8 +19,7 @@ namespace ngraph
class NGRAPH_API ReduceL1 : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceL1", 4};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a reduce L1-norm operation.
ReduceL1() = default;
/// \brief Constructs a reduce L1-norm operation.

@@ -18,8 +18,7 @@ namespace ngraph
class NGRAPH_API ReduceL2 : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceL2", 4};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a reduce L2-norm operation.
ReduceL2() = default;
/// \brief Constructs a reduce L2-norm operation.

@@ -16,8 +16,7 @@ namespace ngraph
class NGRAPH_API ReduceMean : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceMean", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
ReduceMean() = default;
/// \param arg The tensor to be averaged.

@@ -18,8 +18,7 @@ namespace ngraph
class NGRAPH_API ReduceProd : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceProd", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a product reduction operation.
ReduceProd() = default;
/// \brief Constructs a product reduction operation.

@@ -65,8 +65,7 @@ namespace ngraph
class NGRAPH_API ReduceSum : public util::ArithmeticReductionKeepDims
{
public:
static constexpr NodeTypeInfo type_info{"ReduceSum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
/// \brief Constructs a summation operation.
ReduceSum() = default;
/// \brief Constructs a summation operation.

@@ -21,11 +21,6 @@ namespace ngraph
/// \brief Constructs an arithmetic reduction operation.
ArithmeticReduction();
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Output that produces the first input tensor.
@@ -33,6 +28,7 @@ namespace ngraph
ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return true if reduction axes are constant else false.

@@ -28,6 +28,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return If set to 1 it holds axes that are used for reduction.

@@ -32,6 +32,7 @@ namespace ngraph
LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return true if reduction axes are constant else false.

@@ -28,6 +28,7 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;
/// \return If set to 1 it holds axes that are used for reduction.

@@ -215,9 +215,15 @@ namespace ngraph
NGRAPH_API
AxisVector get_default_order(size_t rank);
NGRAPH_API
AxisVector get_default_order(const Rank& rank);
NGRAPH_API
AxisVector get_default_order(const Shape& shape);
NGRAPH_API
AxisVector get_default_order(const PartialShape& shape);
//
// EnumMask is intended to work with a scoped enum type. It's used to store
// a combination of enum values and provides easy access and manipulation

@@ -46,7 +46,7 @@ namespace maxop
}
}
constexpr NodeTypeInfo op::v1::ReduceMax::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMax, "ReduceMax", 1, util::ArithmeticReductionKeepDims);
op::v1::ReduceMax::ReduceMax(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -46,7 +46,7 @@ namespace minop
}
} // namespace minop
constexpr NodeTypeInfo op::v1::ReduceMin::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMin, "ReduceMin", 1, util::ArithmeticReductionKeepDims);
op::v1::ReduceMin::ReduceMin(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -12,7 +12,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v4::ReduceL1::type_info;
NGRAPH_RTTI_DEFINITION(op::v4::ReduceL1, "ReduceL1", 4, util::ArithmeticReductionKeepDims);
op::v4::ReduceL1::ReduceL1(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -12,7 +12,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v4::ReduceL2::type_info;
NGRAPH_RTTI_DEFINITION(op::v4::ReduceL2, "ReduceL2", 4, util::ArithmeticReductionKeepDims);
op::v4::ReduceL2::ReduceL2(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -12,7 +12,10 @@
using namespace ngraph;
using namespace std;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd, "ReduceLogicalAnd", 1);
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd,
"ReduceLogicalAnd",
1,
util::LogicalReductionKeepDims);
op::v1::ReduceLogicalAnd::ReduceLogicalAnd(const Output<Node>& data,
const Output<Node>& reduction_axes,

@@ -12,7 +12,10 @@
using namespace ngraph;
using namespace std;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr, "ReduceLogicalOr", 1);
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr,
"ReduceLogicalOr",
1,
util::LogicalReductionKeepDims);
op::v1::ReduceLogicalOr::ReduceLogicalOr(const Output<Node>& data,
const Output<Node>& reduction_axes,

@@ -13,7 +13,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::ReduceMean::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMean, "ReduceMean", 1, util::ArithmeticReductionKeepDims);
op::v1::ReduceMean::ReduceMean(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -13,7 +13,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::ReduceProd::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceProd, "ReduceProd", 1, util::ArithmeticReductionKeepDims);
op::v1::ReduceProd::ReduceProd(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -13,7 +13,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::ReduceSum::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::ReduceSum, "ReduceSum", 1, util::ArithmeticReductionKeepDims);
op::v1::ReduceSum::ReduceSum(const Output<Node>& arg,
const Output<Node>& reduction_axes,

@@ -10,17 +10,9 @@
using namespace std;
using namespace ngraph;
op::util::ArithmeticReduction::ArithmeticReduction() {}
NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReduction, "ArithmeticReduction", 0);
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
const AxisSet& reduction_axes)
: Op({arg,
op::Constant::create(
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
->output(0)})
{
add_provenance_group_member(input_value(1).get_node_shared_ptr());
}
op::util::ArithmeticReduction::ArithmeticReduction() {}
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
const Output<Node>& reduction_axes)

@@ -11,6 +11,8 @@
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReductionKeepDims, "ArithmeticReductionKeepDims", 0);
op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
const ngraph::Output<ngraph::Node>& arg,
const ngraph::Output<ngraph::Node>& reduction_axes,

@@ -10,6 +10,8 @@
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::util::LogicalReduction, "LogicalReduction", 1);
op::util::LogicalReduction::LogicalReduction() {}
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)

@@ -11,6 +11,8 @@
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::util::LogicalReductionKeepDims, "LogicalReductionKeepDims", 1);
op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
const ngraph::Output<ngraph::Node>& arg,
const ngraph::Output<ngraph::Node>& reduction_axes,

@@ -401,6 +401,11 @@ AxisVector ngraph::get_default_order(const Shape& shape)
return get_default_order(shape.size());
}
AxisVector ngraph::get_default_order(const PartialShape& shape)
{
return get_default_order(shape.rank());
}
AxisVector ngraph::get_default_order(size_t rank)
{
AxisVector default_order(rank);
@@ -408,6 +413,15 @@ AxisVector ngraph::get_default_order(size_t rank)
return default_order;
}
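// E.g. get_default_order(Rank(4)) yields {0, 1, 2, 3}; a dynamic rank fails the check below.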
AxisVector ngraph::get_default_order(const Rank& rank)
{
NGRAPH_CHECK(rank.is_static(), "Can not calculate default order for dynamic rank");
AxisVector default_order(rank.get_length());
std::iota(begin(default_order), end(default_order), 0);
return default_order;
}
void ngraph::parse_version_string(
std::string version, size_t& major, size_t& minor, size_t& patch, string& extra)
{