[ Transpose sinking ] Transpose->FQ->Reduce (#5026)
* Utils: make_try_fold, clone_try_fold. Template node creation and attempt to fold it * RTTI for ArithmeticReduction(KeepDims) * Enriched ngraph::get_default_order overloads with ones for dynamic shape and rank * [ Transpose sinking ] Transpose->FQ->Reduce to FQ->Reduce->Transpose * Style: deleted empty line * RTTI in Reduction operations * RTTI for LogicalReductionKeepDims * Transpose: optimizations moved from algebraic simplification to TransposeSinking * renamed file * Fix test * keep_dims is always initialized * Apply suggestions from code review Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com> Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
This commit is contained in:
parent
71d56ee149
commit
4673cc2d25
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeSinking;
|
||||
class TRANSFORMATIONS_API TransposeOptimization;
|
||||
class TRANSFORMATIONS_API TransposeReduction;
|
||||
class TRANSFORMATIONS_API TransposeFQReduction;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeOptimization transformation replaces suitable Transposes with Reshape operation or optimises them out
|
||||
*/
|
||||
class ngraph::pass::TransposeOptimization : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
TransposeOptimization();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
|
||||
*/
|
||||
class ngraph::pass::TransposeReduction : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
TransposeReduction();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction or squeeze
|
||||
*/
|
||||
class ngraph::pass::TransposeFQReduction : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
TransposeFQReduction();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinking transformation sinks Transposes through known operations
|
||||
*/
|
||||
class ngraph::pass::TransposeSinking: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
TransposeSinking() {
|
||||
add_matcher<ngraph::pass::TransposeFQReduction>();
|
||||
add_matcher<ngraph::pass::TransposeReduction>();
|
||||
add_matcher<ngraph::pass::TransposeOptimization>();
|
||||
}
|
||||
};
|
@ -106,6 +106,15 @@ TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string&
|
||||
|
||||
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max_seq_len);
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs);
|
||||
|
||||
template <typename T, typename... Args>
|
||||
std::shared_ptr<Node> make_try_fold(Args&&... args) {
|
||||
auto unary_output_node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
return try_fold_unary_output(unary_output_node);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
|
||||
|
@ -136,93 +136,6 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool replace_transpose_with_reshape(shared_ptr<Node> transpose) {
|
||||
auto data = transpose->input_value(0);
|
||||
const auto input_shape = transpose->input(0).get_partial_shape();
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t input_shape_rank = input_shape.rank().get_length();
|
||||
|
||||
auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
|
||||
if (!order || !ngraph::shape_size(order->get_shape())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto order_value = order->cast_vector<int64_t>();
|
||||
|
||||
// Check that transpose order without 1 dims has an ascending order
|
||||
int64_t last_dim(-1);
|
||||
for (size_t i = 0; i < input_shape_rank; ++i) {
|
||||
if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
|
||||
if (order_value[i] < last_dim) {
|
||||
return false;
|
||||
}
|
||||
last_dim = order_value[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose operation can be removed if original transpose order is sorted
|
||||
// or dimension that changes their places equal to 1
|
||||
using DimensionToPosition = struct {
|
||||
Dimension dim;
|
||||
size_t pos;
|
||||
};
|
||||
std::vector<DimensionToPosition> dims;
|
||||
for (size_t i = 0; i < input_shape_rank; ++i) {
|
||||
if (order_value[i] != static_cast<int64_t>(i)) {
|
||||
dims.push_back({input_shape[order_value[i]], i});
|
||||
}
|
||||
}
|
||||
|
||||
// If number of dimensions != 1 to move equal to 0 we can remove this Transpose
|
||||
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
|
||||
return !(item.dim.is_static() && item.dim.get_length() == 1);
|
||||
}) == 0) {
|
||||
return replace_output_update_name(transpose->output(0), transpose->input_value(0));
|
||||
}
|
||||
|
||||
// Transpose can be replaced with Reshape in two ways:
|
||||
// 1. Reshape with dims as Constant
|
||||
// 2. Reshape with dims as input (ShapeOf->Gather)
|
||||
//
|
||||
// The first case is possible only if one or less dynamic dimensions changes their position
|
||||
// For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
|
||||
// with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} transpose
|
||||
// cannot be replaced int the same way and in this case its only possible to use Gather(ShapeOf,
|
||||
// order)
|
||||
|
||||
Output<Node> reshape_dim;
|
||||
NodeVector new_ops;
|
||||
|
||||
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
|
||||
return item.dim.is_dynamic();
|
||||
}) < 2) {
|
||||
vector<int64_t> reshape_value(input_shape_rank, 0);
|
||||
for (const auto& item : dims) {
|
||||
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
|
||||
}
|
||||
reshape_dim =
|
||||
opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
|
||||
} else {
|
||||
auto shape_of = make_shared<opset3::ShapeOf>(data);
|
||||
new_ops.push_back(shape_of);
|
||||
reshape_dim = make_shared<opset3::Gather>(
|
||||
shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(reshape_dim.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto reshape_op = make_shared<opset3::Reshape>(data, reshape_dim, true);
|
||||
new_ops.push_back(reshape_op);
|
||||
|
||||
reshape_op->set_friendly_name(transpose->get_friendly_name());
|
||||
copy_runtime_info(transpose, new_ops);
|
||||
replace_node(transpose, reshape_op);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#define ECHO(NAME) #NAME
|
||||
#define STR(NAME) ECHO(NAME)
|
||||
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, OP, FUNC) \
|
||||
@ -244,11 +157,9 @@ NGRAPH_RTTI_DEFINITION(NAME, STR(NAME), 0);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, opset3::Gather, simplify_gather);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf2Gather, opset2::ShapeOf, simplify_gather_shapeof);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(SimplifyShapeOf3Gather, opset3::ShapeOf, simplify_gather_shapeof);
|
||||
SIMPLE_MATCHER_PASS_DEFINITION(ConvertTransposeToReshape, opset3::Transpose, replace_transpose_with_reshape);
|
||||
|
||||
ngraph::pass::AlgebraicSimplification::AlgebraicSimplification() {
|
||||
add_matcher<EliminateGather>();
|
||||
add_matcher<SimplifyShapeOf2Gather>();
|
||||
add_matcher<SimplifyShapeOf3Gather>();
|
||||
add_matcher<ConvertTransposeToReshape>();
|
||||
}
|
||||
|
@ -38,6 +38,7 @@
|
||||
#include "transformations/common_optimizations/space_to_batch_fusion.hpp"
|
||||
#include "transformations/common_optimizations/batch_to_space_fusion.hpp"
|
||||
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/op_conversions/convert_divide.hpp"
|
||||
@ -85,6 +86,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
|
||||
manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
|
||||
manager.register_pass<ngraph::pass::TransposeSinking>();
|
||||
|
||||
auto eliminations = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
eliminations->add_matcher<ngraph::pass::EliminateUnsqueezeGather>();
|
||||
|
@ -0,0 +1,273 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <numeric>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeSinking, "TransposeSinking", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeOptimization, "TransposeOptimization", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeReduction, "TransposeReduction", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeFQReduction, "TransposeFQReduction", 0);
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
std::shared_ptr<ngraph::opset6::Constant> get_reduced_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& axes_const,
|
||||
const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
|
||||
auto order = order_const->cast_vector<int64_t>();
|
||||
|
||||
auto axes = axes_const->cast_vector<int64_t>();
|
||||
std::sort(axes.rbegin(), axes.rend());
|
||||
for (const auto& i : axes)
|
||||
order.erase(order.begin() + i);
|
||||
|
||||
const auto& updated_order_size = static_cast<int64_t>(order.size());
|
||||
|
||||
auto order_sorted = order;
|
||||
sort(order_sorted.begin(), order_sorted.end());
|
||||
for (int64_t i = 0; i < updated_order_size; ++i) {
|
||||
auto lowest_greater_eq_i = std::lower_bound(order_sorted.begin(), order_sorted.end(), i);
|
||||
std::replace(order.begin(), order.end(), *lowest_greater_eq_i, i);
|
||||
std::replace(order_sorted.begin(), order_sorted.end(), *lowest_greater_eq_i, i);
|
||||
}
|
||||
return std::make_shared<ngraph::opset6::Constant>(
|
||||
ngraph::element::i64, ngraph::Shape{order.size()}, order);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::opset6::Constant> get_reversed_order_constant(const std::shared_ptr<ngraph::opset6::Constant>& order_const) {
|
||||
const auto& order = order_const->cast_vector<size_t>();
|
||||
const auto& rank = order.size();
|
||||
const auto& default_order = ngraph::get_default_order(rank);
|
||||
std::vector<size_t> reverse_order(rank);
|
||||
for (size_t i = 0; i < rank; ++i)
|
||||
reverse_order[order[i]] = default_order[i];
|
||||
|
||||
return std::make_shared<ngraph::opset6::Constant>(
|
||||
ngraph::element::i64, ngraph::Shape{reverse_order.size()}, reverse_order);
|
||||
}
|
||||
|
||||
|
||||
bool replace_transpose_with_reshape(const std::shared_ptr<Node>& transpose) {
|
||||
auto data = transpose->input_value(0);
|
||||
const auto input_shape = transpose->input(0).get_partial_shape();
|
||||
|
||||
const size_t input_shape_rank = input_shape.rank().get_length();
|
||||
|
||||
auto order = as_type_ptr<opset6::Constant>(transpose->input_value(1).get_node_shared_ptr());
|
||||
if (!order || !ngraph::shape_size(order->get_shape())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto order_value = order->cast_vector<int64_t>();
|
||||
|
||||
// Check that transpose order without 1 dims has an ascending order
|
||||
int64_t last_dim(-1);
|
||||
for (size_t i = 0; i < input_shape_rank; ++i) {
|
||||
if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1) {
|
||||
if (order_value[i] < last_dim) {
|
||||
return false;
|
||||
}
|
||||
last_dim = order_value[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose operation can be removed if original transpose order is sorted
|
||||
// or dimension that changes their places equal to 1
|
||||
using DimensionToPosition = struct {
|
||||
Dimension dim;
|
||||
size_t pos;
|
||||
};
|
||||
std::vector<DimensionToPosition> dims;
|
||||
for (size_t i = 0; i < input_shape_rank; ++i) {
|
||||
if (order_value[i] != static_cast<int64_t>(i)) {
|
||||
dims.push_back({input_shape[order_value[i]], i});
|
||||
}
|
||||
}
|
||||
|
||||
// If number of dimensions != 1 to move equal to 0 we can remove this Transpose
|
||||
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
|
||||
return !(item.dim.is_static() && item.dim.get_length() == 1);
|
||||
}) == 0) {
|
||||
return replace_output_update_name(transpose->output(0), transpose->input_value(0));
|
||||
}
|
||||
|
||||
// Transpose can be replaced with Reshape in two ways:
|
||||
// 1. Reshape with dims as Constant
|
||||
// 2. Reshape with dims as input (ShapeOf->Gather)
|
||||
//
|
||||
// The first case is possible only if one or less dynamic dimensions changes their position
|
||||
// For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
|
||||
// with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} transpose
|
||||
// cannot be replaced int the same way and in this case its only possible to use Gather(ShapeOf,
|
||||
// order)
|
||||
|
||||
Output<Node> reshape_dim;
|
||||
NodeVector new_ops;
|
||||
|
||||
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
|
||||
return item.dim.is_dynamic();
|
||||
}) < 2) {
|
||||
std::vector<int64_t> reshape_value(input_shape_rank, 0);
|
||||
for (const auto& item : dims) {
|
||||
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
|
||||
}
|
||||
reshape_dim =
|
||||
opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
|
||||
} else {
|
||||
auto shape_of = std::make_shared<opset3::ShapeOf>(data);
|
||||
new_ops.push_back(shape_of);
|
||||
reshape_dim = std::make_shared<opset3::Gather>(
|
||||
shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(reshape_dim.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto reshape_op = std::make_shared<opset3::Reshape>(data, reshape_dim, true);
|
||||
new_ops.push_back(reshape_op);
|
||||
|
||||
reshape_op->set_friendly_name(transpose->get_friendly_name());
|
||||
copy_runtime_info(transpose, new_ops);
|
||||
replace_node(transpose, reshape_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
ngraph::pass::TransposeOptimization::TransposeOptimization() {
|
||||
MATCHER_SCOPE(TransposeOptimization);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<opset6::Transpose>(
|
||||
{pattern::any_input(pattern::has_static_rank()), pattern::wrap_type<opset6::Constant>()});
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
|
||||
return replace_transpose_with_reshape(m.get_match_root());
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ngraph::pass::TransposeReduction::TransposeReduction() {
|
||||
MATCHER_SCOPE(TransposeReduction);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
|
||||
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
{transpose_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
|
||||
const auto &pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
|
||||
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
|
||||
return false;
|
||||
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<ngraph::opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<ngraph::opset6::Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
const auto& non_negative_axes = ngraph::normalize_axes(
|
||||
reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), reduction->get_input_partial_shape(0).rank());
|
||||
reduction_axes = ngraph::opset6::Constant::create(ngraph::element::i64, {non_negative_axes.size()}, non_negative_axes);
|
||||
|
||||
ngraph::NodeVector new_ops;
|
||||
auto new_axes = ngraph::op::util::make_try_fold<ngraph::opset6::Gather>(
|
||||
transpose_order, reduction_axes, ngraph::opset6::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
new_ops.push_back(new_axes);
|
||||
auto new_reduce = reduction->copy_with_new_inputs({transpose->input_value(0), new_axes});
|
||||
new_ops.push_back(new_reduce);
|
||||
|
||||
auto updated_order = transpose_order;
|
||||
if (!keep_dims) {
|
||||
updated_order = get_reduced_order_constant(reduction_axes, transpose_order);
|
||||
new_ops.push_back(updated_order);
|
||||
}
|
||||
auto new_transpose = register_new_node<opset6::Transpose>(new_reduce, updated_order);
|
||||
new_ops.push_back(new_transpose);
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info({reduction, transpose}, new_ops);
|
||||
ngraph::replace_node(reduction, new_transpose);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ngraph::pass::TransposeFQReduction::TransposeFQReduction() {
|
||||
MATCHER_SCOPE(TransposeFQReduction);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
|
||||
auto fq_label = pattern::wrap_type<opset6::FakeQuantize>(
|
||||
{transpose_label, pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank()),
|
||||
pattern::any_input(pattern::has_static_rank()), pattern::any_input(pattern::has_static_rank())});
|
||||
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
{fq_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher &m) {
|
||||
auto &pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto fq = pattern_to_output.at(fq_label).get_node_shared_ptr();
|
||||
if (!transpose || !transpose_order || !fq)
|
||||
return false;
|
||||
|
||||
ngraph::NodeVector new_ops;
|
||||
|
||||
const auto& reverse_order_constant = get_reversed_order_constant(transpose_order);
|
||||
new_ops.push_back(reverse_order_constant);
|
||||
|
||||
const auto& input_rank = fq->get_input_partial_shape(0).rank().get_length();
|
||||
ngraph::OutputVector fq_inputs = {transpose->input_value(0)};
|
||||
for (size_t i = 1; i < fq->inputs().size(); ++i) {
|
||||
auto input = fq->input_value(i);
|
||||
const auto& ranks_diff = input_rank - input.get_partial_shape().rank().get_length();
|
||||
NGRAPH_CHECK(ranks_diff >= 0);
|
||||
if (ranks_diff > 0) {
|
||||
std::vector<int64_t> axes(ranks_diff);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
const auto& axes_const = opset6::Constant::create(element::i64, Shape{axes.size()}, axes);
|
||||
new_ops.push_back(axes_const);
|
||||
const auto& unsqueezed_input = op::util::make_try_fold<opset6::Unsqueeze>(input, axes_const);
|
||||
new_ops.push_back(unsqueezed_input);
|
||||
input = unsqueezed_input->output(0);
|
||||
}
|
||||
const auto& transposed_input = op::util::make_try_fold<opset6::Transpose>(input, reverse_order_constant);
|
||||
new_ops.push_back(transposed_input);
|
||||
fq_inputs.push_back(transposed_input);
|
||||
}
|
||||
auto new_fq = fq->copy_with_new_inputs(fq_inputs);
|
||||
new_ops.push_back(new_fq);
|
||||
|
||||
auto new_transpose = std::make_shared<ngraph::opset6::Transpose>(new_fq, transpose_order);
|
||||
new_ops.push_back(new_transpose);
|
||||
new_transpose->set_friendly_name(fq->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info({fq, transpose}, new_ops);
|
||||
ngraph::replace_node(fq, new_transpose);
|
||||
// The root node (reduction) left unchanged during current matcher pass.
|
||||
// We return false here for further MatcherPasses to be applicable for this node as a root node
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -130,6 +130,18 @@ bool is_seq_len_provided(const std::shared_ptr<Node> &seq_len_input, int64_t max
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node) {
|
||||
const auto& num_outputs = node->get_output_size();
|
||||
NGRAPH_CHECK(num_outputs == 1, "Unary has unexpected number of outputs:" + std::to_string(num_outputs));
|
||||
OutputVector output(num_outputs);
|
||||
return node->constant_fold(output, node->input_values()) ? output[0].get_node_shared_ptr() : node;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> clone_try_fold(const std::shared_ptr<Node>& node, const OutputVector& inputs) {
|
||||
auto unary_output_node = node->clone_with_new_inputs(inputs);
|
||||
return try_fold_unary_output(unary_output_node);
|
||||
}
|
||||
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <transformations/common_optimizations/algebraic_simplification.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -311,8 +312,7 @@ TEST(algebraic_simplification, replace_transpose_with_reshape) {
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.register_pass<pass::TransposeSinking>();
|
||||
pass_manager.run_passes(optimized_f);
|
||||
|
||||
auto ps = baseline_f->get_results()[0]->get_output_partial_shape(0);
|
||||
|
@ -0,0 +1,203 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
struct TransposeFQReduceParams {
|
||||
// given params
|
||||
PartialShape transpose_input_shape;
|
||||
std::vector<int32_t> transpose_order;
|
||||
Shape il, ih, ol, oh;
|
||||
std::vector<int32_t> reduce_axes;
|
||||
bool reduce_keep_dims;
|
||||
|
||||
// expected params
|
||||
Shape ex_il, ex_ih, ex_ol, ex_oh;
|
||||
std::vector<int32_t> ex_reduce_axes;
|
||||
std::vector<int32_t> ex_transpose_order;
|
||||
};
|
||||
|
||||
class TransposeSinkingFQ : public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<TransposeFQReduceParams>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& test_case = std::get<0>(GetParam());
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset6::Parameter>(element::f32, test_case.transpose_input_shape);
|
||||
|
||||
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
|
||||
|
||||
auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.il, std::vector<int32_t>{0});
|
||||
auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ih, std::vector<int32_t>{0});
|
||||
auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ol, std::vector<int32_t>{0});
|
||||
auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.oh, std::vector<int32_t>{0});
|
||||
auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(transpose, i_low, i_high, o_low, o_high, 256);
|
||||
|
||||
auto axes = std::make_shared<ngraph::opset6::Constant>(
|
||||
element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
|
||||
auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset6::Parameter>(element::f32, test_case.transpose_input_shape);
|
||||
|
||||
auto i_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_il, std::vector<int32_t>{0});
|
||||
auto i_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ih, std::vector<int32_t>{0});
|
||||
auto o_low = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_ol, std::vector<int32_t>{0});
|
||||
auto o_high = std::make_shared<ngraph::opset6::Constant>(element::i64, test_case.ex_oh, std::vector<int32_t>{0});
|
||||
auto fq = std::make_shared<ngraph::opset6::FakeQuantize>(input, i_low, i_high, o_low, o_high, 256);
|
||||
|
||||
auto axes = std::make_shared<ngraph::opset6::Constant>(
|
||||
element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
|
||||
auto reduce = std::make_shared<ngraph::opset6::ReduceMean>(fq, axes, test_case.reduce_keep_dims);
|
||||
|
||||
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduce, order);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::TransposeFQReduction>();
|
||||
manager.register_pass<ngraph::pass::TransposeReduction>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(TransformationTest, TransposeSinkingFQ, testing::Values(
|
||||
TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, true,
|
||||
{1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 2, 3, 1}},
|
||||
TransposeFQReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1}, {3}, {1, 1, 1, 1}, {1, 1, 1, 3}, {1, 2}, false,
|
||||
{1, 1, 1, 1}, {1, 3, 1, 1}, {1, 1, 1, 1}, {1, 3, 1, 1}, {2, 3}, {0, 1}}));
|
||||
|
||||
|
||||
|
||||
struct TransposeReduceParams {
|
||||
// given params
|
||||
PartialShape transpose_input_shape;
|
||||
std::vector<int32_t> transpose_order;
|
||||
std::vector<int32_t> reduce_axes;
|
||||
bool reduction_keep_dims;
|
||||
|
||||
// expected params
|
||||
std::vector<int32_t> ex_reduce_axes;
|
||||
std::vector<int32_t> ex_transpose_order;
|
||||
};
|
||||
|
||||
class TransposeSinking : public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<TransposeReduceParams, ngraph::NodeTypeInfo>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& test_case = std::get<0>(GetParam());
|
||||
const auto& reduction_type_info = std::get<1>(GetParam());
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
|
||||
|
||||
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.transpose_order.size()}, test_case.transpose_order);
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
|
||||
|
||||
auto axes = std::make_shared<ngraph::opset6::Constant>(
|
||||
element::i64, Shape{test_case.reduce_axes.size()}, test_case.reduce_axes);
|
||||
|
||||
auto reduction = get_reduction(reduction_type_info, {transpose, axes}, test_case.reduction_keep_dims);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reduction}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<opset6::Parameter>(element::dynamic, test_case.transpose_input_shape);
|
||||
|
||||
auto axes = std::make_shared<ngraph::opset6::Constant>(
|
||||
element::i64, Shape{test_case.ex_reduce_axes.size()}, test_case.ex_reduce_axes);
|
||||
auto reduction = get_reduction(reduction_type_info, {input, axes}, test_case.reduction_keep_dims);
|
||||
|
||||
auto order = std::make_shared<opset6::Constant>(element::i64, Shape{test_case.ex_transpose_order.size()}, test_case.ex_transpose_order);
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(reduction, order);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{input});
|
||||
}
|
||||
}
|
||||
private:
|
||||
std::shared_ptr<Node> get_reduction(ngraph::NodeTypeInfo reduction_type_info, const OutputVector& inputs, bool keep_dims) {
|
||||
auto reduction = ngraph::helpers::getNodeSharedPtr(reduction_type_info, inputs);
|
||||
if (auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction))
|
||||
arithmetic_reduce->set_keep_dims(keep_dims);
|
||||
else if (auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction))
|
||||
logical_reduce->set_keep_dims(keep_dims);
|
||||
reduction->validate_and_infer_types();
|
||||
return reduction;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinking, TransposeReduction) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::TransposeReduction>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(TransposeSinkingReduces, TransposeSinking, testing::Combine(
|
||||
testing::Values(
|
||||
TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, true, {2, 3}, {0, 2, 3, 1}},
|
||||
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, true, {6, 5, 3}, {0, 6, 1, 5, 2, 4, 3}},
|
||||
TransposeReduceParams{{1, 3, 240, 140}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
|
||||
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
|
||||
TransposeReduceParams{{10, 20, 30, 40, 50, 60, 70}, {0, 6, 1, 5, 2, 4, 3}, {1, -4, 6}, false, {6, 5, 3}, {0, 1, 2, 3}},
|
||||
TransposeReduceParams{{1, 3, 240, 140}, {0, 1, 2, 3}, {0, 1, 2, -1}, false, {0, 1, 2, 3}, {}}),
|
||||
testing::Values(
|
||||
ngraph::opset6::ReduceMax::type_info,
|
||||
ngraph::opset6::ReduceMean::type_info,
|
||||
ngraph::opset6::ReduceMin::type_info,
|
||||
ngraph::opset6::ReduceProd::type_info,
|
||||
ngraph::opset6::ReduceSum::type_info,
|
||||
ngraph::opset6::ReduceL1::type_info,
|
||||
ngraph::opset6::ReduceL2::type_info,
|
||||
ngraph::opset6::ReduceLogicalAnd::type_info,
|
||||
ngraph::opset6::ReduceLogicalOr::type_info)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(TransposeSinkingSqueeze, TransposeSinking, testing::Combine(
|
||||
testing::Values(
|
||||
TransposeReduceParams{{2, 3, 1, 1}, {0, 2, 3, 1}, {1, 2}, false, {2, 3}, {0, 1}},
|
||||
TransposeReduceParams{{10, 20, 30, 1, 50, 1, 1}, {0, 6, 1, 5, 2, 4, 3}, {1, 3, 6}, false, {6, 5, 3}, {0, 1, 2, 3}}),
|
||||
testing::Values(
|
||||
ngraph::opset6::Squeeze::type_info)));
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/common_optimizations/algebraic_simplification.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -97,7 +98,7 @@ private:
|
||||
|
||||
TEST_P(TransposeToReshapeTests, CompareFunctions) {
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::AlgebraicSimplification().run_on_function(f);
|
||||
ngraph::pass::TransposeSinking().run_on_function(f);
|
||||
f->validate_nodes_and_infer_types();
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
auto res = compare_functions(f, f_ref);
|
||||
|
@ -16,8 +16,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceMax : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceMax", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a summation operation.
|
||||
ReduceMax() = default;
|
||||
/// \brief Constructs a summation operation.
|
||||
|
@ -16,8 +16,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceMin : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceMin", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a summation operation.
|
||||
ReduceMin() = default;
|
||||
/// \brief Constructs a summation operation.
|
||||
|
@ -19,8 +19,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceL1 : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceL1", 4};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a reducet L1-norm operation.
|
||||
ReduceL1() = default;
|
||||
/// \brief Constructs a reduce L1-norm operation.
|
||||
|
@ -18,8 +18,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceL2 : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceL2", 4};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a reducet L2-norm operation.
|
||||
ReduceL2() = default;
|
||||
/// \brief Constructs a reduce L2-norm operation.
|
||||
|
@ -16,8 +16,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceMean : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceMean", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ReduceMean() = default;
|
||||
|
||||
/// \param arg The tensor to be summed.
|
||||
|
@ -18,8 +18,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceProd : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceProd", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a product reduction operation.
|
||||
ReduceProd() = default;
|
||||
/// \brief Constructs a product reduction operation.
|
||||
|
@ -65,8 +65,7 @@ namespace ngraph
|
||||
class NGRAPH_API ReduceSum : public util::ArithmeticReductionKeepDims
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReduceSum", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Constructs a summation operation.
|
||||
ReduceSum() = default;
|
||||
/// \brief Constructs a summation operation.
|
||||
|
@ -21,11 +21,6 @@ namespace ngraph
|
||||
/// \brief Constructs an arithmetic reduction operation.
|
||||
ArithmeticReduction();
|
||||
|
||||
/// \brief Constructs an arithmetic reduction operation.
|
||||
///
|
||||
/// \param arg Output that produces the first input tensor.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
ArithmeticReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
|
||||
/// \brief Constructs an arithmetic reduction operation.
|
||||
///
|
||||
/// \param arg Output that produces the first input tensor.
|
||||
@ -33,6 +28,7 @@ namespace ngraph
|
||||
ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return true if reduction axes are constant else false.
|
||||
|
@ -28,6 +28,7 @@ namespace ngraph
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return If set to 1 it holds axes that are used for reduction.
|
||||
|
@ -32,6 +32,7 @@ namespace ngraph
|
||||
LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return true if reduction axes are constant else false.
|
||||
|
@ -28,6 +28,7 @@ namespace ngraph
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return If set to 1 it holds axes that are used for reduction.
|
||||
|
@ -215,9 +215,15 @@ namespace ngraph
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(size_t rank);
|
||||
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(const Rank& rank);
|
||||
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(const Shape& shape);
|
||||
|
||||
NGRAPH_API
|
||||
AxisVector get_default_order(const PartialShape& shape);
|
||||
|
||||
//
|
||||
// EnumMask is intended to work with a scoped enum type. It's used to store
|
||||
// a combination of enum values and provides easy access and manipulation
|
||||
|
@ -46,7 +46,7 @@ namespace maxop
|
||||
}
|
||||
}
|
||||
|
||||
constexpr NodeTypeInfo op::v1::ReduceMax::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMax, "ReduceMax", 1, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v1::ReduceMax::ReduceMax(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -46,7 +46,7 @@ namespace minop
|
||||
}
|
||||
} // namespace minop
|
||||
|
||||
constexpr NodeTypeInfo op::v1::ReduceMin::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMin, "ReduceMin", 1, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v1::ReduceMin::ReduceMin(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -12,7 +12,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v4::ReduceL1::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v4::ReduceL1, "ReduceL1", 4, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v4::ReduceL1::ReduceL1(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -12,7 +12,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v4::ReduceL2::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v4::ReduceL2, "ReduceL2", 4, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v4::ReduceL2::ReduceL2(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -12,7 +12,10 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd, "ReduceLogicalAnd", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalAnd,
|
||||
"ReduceLogicalAnd",
|
||||
1,
|
||||
util::LogicalReductionKeepDims);
|
||||
|
||||
op::v1::ReduceLogicalAnd::ReduceLogicalAnd(const Output<Node>& data,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -12,7 +12,10 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr, "ReduceLogicalOr", 1);
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceLogicalOr,
|
||||
"ReduceLogicalOr",
|
||||
1,
|
||||
util::LogicalReductionKeepDims);
|
||||
|
||||
op::v1::ReduceLogicalOr::ReduceLogicalOr(const Output<Node>& data,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -13,7 +13,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::ReduceMean::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceMean, "ReduceMean", 1, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v1::ReduceMean::ReduceMean(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -13,7 +13,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::ReduceProd::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceProd, "ReduceProd", 1, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v1::ReduceProd::ReduceProd(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -13,7 +13,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::ReduceSum::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::ReduceSum, "ReduceSum", 1, util::ArithmeticReductionKeepDims);
|
||||
|
||||
op::v1::ReduceSum::ReduceSum(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes,
|
||||
|
@ -10,17 +10,9 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
op::util::ArithmeticReduction::ArithmeticReduction() {}
|
||||
NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReduction, "ArithmeticReduction", 0);
|
||||
|
||||
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
|
||||
const AxisSet& reduction_axes)
|
||||
: Op({arg,
|
||||
op::Constant::create(
|
||||
element::i64, Shape{reduction_axes.size()}, reduction_axes.to_vector())
|
||||
->output(0)})
|
||||
{
|
||||
add_provenance_group_member(input_value(1).get_node_shared_ptr());
|
||||
}
|
||||
op::util::ArithmeticReduction::ArithmeticReduction() {}
|
||||
|
||||
op::util::ArithmeticReduction::ArithmeticReduction(const Output<Node>& arg,
|
||||
const Output<Node>& reduction_axes)
|
||||
|
@ -11,6 +11,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::ArithmeticReductionKeepDims, "ArithmeticReductionKeepDims", 0);
|
||||
|
||||
op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
|
||||
const ngraph::Output<ngraph::Node>& arg,
|
||||
const ngraph::Output<ngraph::Node>& reduction_axes,
|
||||
|
@ -10,6 +10,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::LogicalReduction, "LogicalReduction", 1);
|
||||
|
||||
op::util::LogicalReduction::LogicalReduction() {}
|
||||
|
||||
op::util::LogicalReduction::LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes)
|
||||
|
@ -11,6 +11,8 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::LogicalReductionKeepDims, "LogicalReductionKeepDims", 1);
|
||||
|
||||
op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
|
||||
const ngraph::Output<ngraph::Node>& arg,
|
||||
const ngraph::Output<ngraph::Node>& reduction_axes,
|
||||
|
@ -401,6 +401,11 @@ AxisVector ngraph::get_default_order(const Shape& shape)
|
||||
return get_default_order(shape.size());
|
||||
}
|
||||
|
||||
AxisVector ngraph::get_default_order(const PartialShape& shape)
|
||||
{
|
||||
return get_default_order(shape.rank());
|
||||
}
|
||||
|
||||
AxisVector ngraph::get_default_order(size_t rank)
|
||||
{
|
||||
AxisVector default_order(rank);
|
||||
@ -408,6 +413,15 @@ AxisVector ngraph::get_default_order(size_t rank)
|
||||
return default_order;
|
||||
}
|
||||
|
||||
AxisVector ngraph::get_default_order(const Rank& rank)
|
||||
{
|
||||
NGRAPH_CHECK(rank.is_static(), "Can not calculate default order for dynamic rank");
|
||||
|
||||
AxisVector default_order(rank.get_length());
|
||||
std::iota(begin(default_order), end(default_order), 0);
|
||||
return default_order;
|
||||
}
|
||||
|
||||
void ngraph::parse_version_string(
|
||||
std::string version, size_t& major, size_t& minor, size_t& patch, string& extra)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user