TransposeSinking: add support for Slice and Reshape ops (#16208)
* Resolve the performance issues in TransposeSinking transformation * codestyle * fix warning as error, fix tests failures * fix ts for Concat and Reduce * Fix TransposeReduceBackward * fix the issue in TransposeFuse transformation * fix TransposeReduce transformations * Fix TransposeReduction, fix TransposeSinkingSplit, add unsqueeze support * delete debug print * Add additional validations * fix node validation * Fix validate for split, revert changes for concat, add BatchToSpace/SpaceToBatch * Add SpaceToBatch/BatchToSpace * fix TS for Interpolate + codestyle * fix gna build * Support TS for Interpolate, VariadicSplit, IsInf, IsNan, IsFinite + refactoring * add the missed line * add include * TransposeSinking tests refactoring: part1 * TransposeSinking tests refactoring: part2 * Add limited support for StridedSlice op * codestye * TransposeReduction: skip the case when 2nd input for Squeeze is not provided * Transpose sinking tests refactoring: part 3. + Revert changes in MOC. * fix build * codestyle * Add tests for TS backward transformations, update TransposeSinkingFuse transformation, delete StridedSlice transformation prototype + tests refactoring * fix unary tests * Fix warning as error on Windows * Add new tests for Unsqueeze/Squeeze; refactoring; remove debug code * TransposeSinking: add support for Slice op * Add descriptions to the transformations, add additional checks * fix a warning * TransposeSinking Rafactoring part2: move the transformations to a separate folder, align namespaces * TransposeSinking refactoring: class names, namespaces * codestyle * resolve merge conflicts * codestyle * TSReduction refactoring, move Unsqueeze/Squeeze transformations to separate files, added limited support for Reshape op + tests * fix minor mistakes * fix warnings * Added TSSlice transformation to TSGeneral, created TransposeSinkingGeneral alias in ov::pass namespace * refactoring * codestyle * fix TSSqueeze/TSUnsqueeze transformations * delete debug serialize * remove TransposeSinking from MOC * fix TSSqueeze/TSUnsqueeze transformations in case of Reshape op * delete debug code * codestyle * fix unit tests, revert changes for TSSlice transformation * fix TSSqueeze transformation * resolve review comments * codestyle
This commit is contained in:
parent
c5b348dd4f
commit
5a8a195dad
@ -16,6 +16,9 @@ class TRANSFORMATIONS_API TSGeneralBackward;
|
||||
class TRANSFORMATIONS_API TSGeneral;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
|
||||
using TransposeSinkingGeneral = ov::pass::transpose_sinking::TSGeneral;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
|
@ -21,7 +21,7 @@ class TRANSFORMATIONS_API TSReductionBackward;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
|
||||
@ -32,7 +32,7 @@ public:
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
|
||||
|
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSSliceForward;
|
||||
class TRANSFORMATIONS_API TSSliceBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSliceForward", "0");
|
||||
TSSliceForward();
|
||||
};
|
||||
|
||||
class ov::pass::transpose_sinking::TSSliceBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSliceBackward", "0");
|
||||
TSSliceBackward();
|
||||
};
|
@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSSqueezeForward;
|
||||
class TRANSFORMATIONS_API TSSqueezeBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSSqueezeForward transformation sinks Transpose through Reshape, Squeeze operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSqueezeForward", "0");
|
||||
TSSqueezeForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSSqueezeBackward transformation sinks Transpose through Reshape, Squeeze operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSSqueezeBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSSqueezeBackward", "0");
|
||||
TSSqueezeBackward();
|
||||
};
|
@ -0,0 +1,42 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSUnsqueezeForward;
|
||||
class TRANSFORMATIONS_API TSUnsqueezeBackward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSUnsqueezeForward transformation sinks Transpose through Unsqueeze, Reshape operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSUnsqueezeForward", "0");
|
||||
TSUnsqueezeForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSUnsqueezeBackward transformation sinks Transpose through Unsqueeze, Reshape operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSUnsqueezeBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::TSUnsqueezeBackward", "0");
|
||||
TSUnsqueezeBackward();
|
||||
};
|
@ -98,17 +98,43 @@ void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
|
||||
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);
|
||||
|
||||
/**
|
||||
* Removes all direct node consumers that have one output
|
||||
* @brief Removes all direct node consumers that have one output
|
||||
*/
|
||||
void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);
|
||||
|
||||
/**
|
||||
* Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
|
||||
* @brief Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
|
||||
*/
|
||||
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
|
||||
const ov::AxisVector& transpose_axis_order,
|
||||
const std::shared_ptr<ov::opset10::Constant>& axis);
|
||||
|
||||
/**
|
||||
* @brief Returns the updated axes order for case when the initial axes order has more elements
|
||||
* than after TransposeSinking, e.g.:
|
||||
*
|
||||
* before: Transpose(the initial axes order) -> ReduceMax
|
||||
* after : ReduceMax -> Transpose (the updated axes order)
|
||||
*
|
||||
* before: Unsqueeze -> Transpose (the initial axes order)
|
||||
* after : Transpose (the updated axes order) -> Unsqueeze
|
||||
*/
|
||||
std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values);
|
||||
|
||||
/**
|
||||
* @brief Returns the updated axes order for case when the initial axes order has less elements
|
||||
* than after TransposeSinking, e.g.:
|
||||
*
|
||||
* before : ReduceMax -> Transpose (the updated axes order)
|
||||
* after: Transpose(the initial axes order) -> ReduceMax
|
||||
*
|
||||
* before: Transpose (the updated axes order) -> Unsqueeze
|
||||
* after : Unsqueeze -> Transpose (the initial axes order)
|
||||
*/
|
||||
std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
|
@ -236,7 +236,6 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
REGISTER_PASS(manager, ReverseInputChannelsFusion)
|
||||
REGISTER_PASS(manager, AlignEltwiseInputRanks)
|
||||
REGISTER_PASS(manager, ConstantFolding)
|
||||
|
||||
manager.run_passes(f);
|
||||
|
||||
if (!m_use_shapes) {
|
||||
|
@ -15,8 +15,11 @@
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
@ -29,7 +32,10 @@ TSGeneralForward::TSGeneralForward() {
|
||||
add_matcher<TSSplitForward>();
|
||||
add_matcher<TSDataMovementForward>();
|
||||
add_matcher<TSReductionForward>();
|
||||
add_matcher<TSSqueezeForward>();
|
||||
add_matcher<TSUnsqueezeForward>();
|
||||
add_matcher<TSInterpolateForward>();
|
||||
add_matcher<TSSliceForward>();
|
||||
add_matcher<TSFuse>();
|
||||
}
|
||||
|
||||
@ -41,7 +47,10 @@ TSGeneralBackward::TSGeneralBackward() {
|
||||
add_matcher<TSSplitBackward>();
|
||||
add_matcher<TSDataMovementBackward>();
|
||||
add_matcher<TSReductionBackward>();
|
||||
add_matcher<TSSqueezeBackward>();
|
||||
add_matcher<TSUnsqueezeBackward>();
|
||||
add_matcher<TSInterpolateBackward>();
|
||||
add_matcher<TSSliceBackward>();
|
||||
add_matcher<TSFuse>();
|
||||
}
|
||||
|
||||
|
@ -18,60 +18,15 @@
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() - axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size, 0);
|
||||
std::vector<size_t> values_to_reduce(axes_values);
|
||||
for (size_t i = 0; i < values_to_reduce.size(); ++i) {
|
||||
values_to_reduce[i] = order_values[axes_values[i]];
|
||||
}
|
||||
std::sort(values_to_reduce.begin(), values_to_reduce.end());
|
||||
for (size_t i = 0, j = 0; i < order_values.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() + axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size);
|
||||
|
||||
std::vector<int64_t> cnt_deleted(buffer_size);
|
||||
int64_t cnt = 0;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(cnt_deleted.size()); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
cnt++;
|
||||
}
|
||||
cnt_deleted[i] = i - cnt;
|
||||
}
|
||||
|
||||
for (size_t i = 0, j = 0; i < aligned_order.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
aligned_order[i] = i;
|
||||
continue;
|
||||
}
|
||||
|
||||
aligned_order[i] = std::find(cnt_deleted.begin(), cnt_deleted.end(), order_values[j]) - cnt_deleted.begin();
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
|
||||
bool keep_dims = false; // squeeze/unsqueeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
@ -80,32 +35,29 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
return keep_dims;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
auto reduce_or_squeeze_label = pattern::
|
||||
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
|
||||
{transpose_label, pattern::wrap_type<Constant>()});
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
|
||||
auto rank =
|
||||
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
|
||||
@ -117,38 +69,17 @@ TSReductionForward::TSReductionForward() {
|
||||
}
|
||||
|
||||
if (!keep_dims) {
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->input_value(0).get_partial_shape();
|
||||
|
||||
if (input_pshape.is_static()) {
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (unsqueeze) {
|
||||
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
|
||||
} else {
|
||||
transpose_order_values = get_updated_order_forward(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::shared_ptr<Node> new_reduction;
|
||||
if (!unsqueeze) {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
|
||||
}
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
|
||||
|
||||
replace_node(reduction, new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
@ -158,98 +89,55 @@ TSReductionForward::TSReductionForward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
auto m = std::make_shared<pattern::Matcher>(reduce_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSReductionBackward::TSReductionBackward() {
|
||||
MATCHER_SCOPE(TSReductionBackward);
|
||||
|
||||
auto reduce_or_squeeze_label = pattern::
|
||||
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
|
||||
{pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{any_input(), wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
|
||||
auto rank =
|
||||
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto old_transpose_order_values = transpose_order_values;
|
||||
std::vector<size_t> new_values;
|
||||
if (unsqueeze) {
|
||||
if (non_negative_axes.size() == transpose_order_values.size()) {
|
||||
// input is a scalar, we unsqueeze all dims
|
||||
// it's enough to eliminate such Transpose
|
||||
transpose->output(0).replace(reduction);
|
||||
return true;
|
||||
}
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
|
||||
if (it != old_transpose_order_values.end()) {
|
||||
new_values.push_back(it - old_transpose_order_values.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
bool squeeze_all_dims = false;
|
||||
if (!keep_dims) {
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = reduction->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_static()) {
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
squeeze_all_dims = true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (unsqueeze) {
|
||||
transpose_order_values = get_updated_order_forward(new_values, transpose_order_values);
|
||||
} else {
|
||||
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
if (!unsqueeze) {
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
}
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
std::shared_ptr<Node> new_transpose, new_reduction;
|
||||
if (squeeze_all_dims) {
|
||||
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
new_reduction = reduction->clone_with_new_inputs({new_transpose, reduction->input_value(1)});
|
||||
} else {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
replace_node(transpose, new_reduction);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
@ -0,0 +1,116 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
TSSliceForward::TSSliceForward() {
|
||||
MATCHER_SCOPE(TSSliceForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_node = m.get_pattern_map();
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose || main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_node.at(const_label));
|
||||
if (!transpose_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = transpose->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
|
||||
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
|
||||
const auto& indices = main_node->input_value(4);
|
||||
auto new_axis = std::make_shared<Gather>(data, indices, axis);
|
||||
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSSliceBackward::TSSliceBackward() {
|
||||
MATCHER_SCOPE(TSSliceBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Slice>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
if (main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||
transpose_const,
|
||||
/* input_indexes= */ {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
|
||||
// remove output transposes
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(main_node, transpose);
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
|
||||
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
|
||||
auto data =
|
||||
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
|
||||
const auto& indices = main_node->input_value(4);
|
||||
auto new_axis = std::make_shared<Gather>(data, indices, axis);
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,271 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @brief Checks that Reshape operation is equal to Squeeze:
|
||||
* Only 1 dims are deleted, all other dims must be the same.
|
||||
* Converts these 1 dims to axes format.
|
||||
* @arg reshape Reshape operation.
|
||||
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
|
||||
* @arg result_axes Contains axes which will be squeezed.
|
||||
*/
|
||||
bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Squeeze
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() < input_shape.size()) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
const auto input_dim = static_cast<int64_t>(input_shape[i]);
|
||||
if (j < new_shape.size() && new_shape[j] == input_dim) {
|
||||
j++;
|
||||
} else if (input_dim != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
if (j != new_shape.size()) {
|
||||
// not all new_shape values are in input_shape
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Squeeze
|
||||
// todo: move this checks in the pattern
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts squeezed_axes to actual shape (2nd input) for Reshape operation
|
||||
* using the shape of the 1st input to Reshape.
|
||||
* @arg input_node 1st input to Reshape op.
|
||||
* @arg squeeze_axes In case of Reshape op is equal to squeeze, these axes indicate the places where 1 dims have
|
||||
* to be deleted.
|
||||
*/
|
||||
bool squeeze_axes_to_shape(const Output<Node>& input_node,
|
||||
std::vector<size_t> squeeze_axes,
|
||||
std::vector<size_t>& to_shape) {
|
||||
to_shape.clear();
|
||||
std::sort(squeeze_axes.begin(), squeeze_axes.end());
|
||||
const auto& input_pshape = input_node.get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& input_shape = input_pshape.get_shape();
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
|
||||
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
|
||||
++j;
|
||||
continue;
|
||||
}
|
||||
to_shape.push_back(input_shape[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->output(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i].get_length() == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(transpose_order_values[axis]);
|
||||
}
|
||||
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
auto new_squeeze = squeeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_squeeze, new_transpose_order});
|
||||
|
||||
replace_node(squeeze, new_transpose);
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSSqueezeBackward::TSSqueezeBackward() {
|
||||
MATCHER_SCOPE(TSSqueezeBackward);
|
||||
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
bool squeeze_all_dims = false;
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = squeeze->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
squeeze_all_dims = true;
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> new_squeeze;
|
||||
if (squeeze_all_dims) {
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
|
||||
} else {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
|
||||
replace_node(transpose, new_squeeze);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,243 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @brief Checks that Reshape operation is equal to Unsqueeze:
|
||||
* Only 1 dims are inserted, all other dims must be the same.
|
||||
* Converts these 1 dims to axes format.
|
||||
* @arg reshape Reshape operation.
|
||||
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
|
||||
* @arg result_axes contains axes which will be unsqueezed.
|
||||
*/
|
||||
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Unsqueeze
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() > input_shape.size()) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < new_shape.size(); ++i) {
|
||||
if (j < input_shape.size() && static_cast<int64_t>(input_shape[j]) == new_shape[i]) {
|
||||
j++;
|
||||
} else if (new_shape[i] != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
if (j != input_shape.size()) {
|
||||
// not all input_shape values are in new_shape
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Unsqueeze
|
||||
// todo: move this checks in the pattern
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Converts unsqueeze_axes to actual shape (2nd input) for Reshape operation
|
||||
* using the shape of the 1st input to Reshape.
|
||||
* @arg input_node 1st input to Reshape op.
|
||||
* @arg unsqueeze_axes In case of Reshape op is equal to Unsqueeze, these axes indicate the places where 1 dims have
|
||||
* to be inserted.
|
||||
*/
|
||||
bool unsqueeze_axes_to_shape(const Output<Node>& input_node,
|
||||
std::vector<size_t> unsqueeze_axes,
|
||||
std::vector<size_t>& to_shape) {
|
||||
to_shape.clear();
|
||||
const auto& input_pshape = input_node.get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& input_shape = input_pshape.get_shape();
|
||||
to_shape.resize(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
|
||||
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
|
||||
to_shape[i] = 1;
|
||||
j++;
|
||||
} else if (k < input_shape.size()) {
|
||||
to_shape[i] = input_shape[k];
|
||||
k++;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
||||
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
|
||||
auto new_transpose_order =
|
||||
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
|
||||
|
||||
std::shared_ptr<Node> new_unsqueeze;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
std::vector<size_t> new_values;
|
||||
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
|
||||
}
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
|
||||
|
||||
replace_node(unsqueeze, new_transpose);
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(unsqueeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeBackward);
|
||||
|
||||
auto unsqueeze_label =
|
||||
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes)
|
||||
return false;
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto old_transpose_order_values = transpose_order_values;
|
||||
std::vector<size_t> new_values;
|
||||
|
||||
if (non_negative_axes.size() == transpose_order_values.size()) {
|
||||
// input is a scalar, we unsqueeze all dims
|
||||
// it's enough to eliminate such Transpose
|
||||
transpose->output(0).replace(unsqueeze);
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
|
||||
if (it != old_transpose_order_values.end()) {
|
||||
new_values.push_back(it - old_transpose_order_values.begin());
|
||||
}
|
||||
}
|
||||
|
||||
transpose_order_values = GetOrderAfterReduction(new_values, transpose_order_values);
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
replace_node(transpose, new_unsqueeze);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -379,6 +379,53 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() - axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size, 0);
|
||||
std::vector<size_t> values_to_reduce(axes_values);
|
||||
for (size_t i = 0; i < values_to_reduce.size(); ++i) {
|
||||
values_to_reduce[i] = order_values[axes_values[i]];
|
||||
}
|
||||
std::sort(values_to_reduce.begin(), values_to_reduce.end());
|
||||
for (size_t i = 0, j = 0; i < order_values.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() + axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size);
|
||||
|
||||
std::vector<int64_t> cnt_deleted(buffer_size);
|
||||
int64_t cnt = 0;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(cnt_deleted.size()); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
cnt++;
|
||||
}
|
||||
cnt_deleted[i] = i - cnt;
|
||||
}
|
||||
|
||||
for (size_t i = 0, j = 0; i < aligned_order.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
aligned_order[i] = i;
|
||||
continue;
|
||||
}
|
||||
|
||||
aligned_order[i] = std::find(cnt_deleted.begin(), cnt_deleted.end(), order_values[j]) - cnt_deleted.begin();
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
|
@ -11,8 +11,11 @@
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -183,6 +186,34 @@ private:
|
||||
FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_reference) {
|
||||
return std::make_shared<InterpolateFactory>(type_name, is_reference);
|
||||
}
|
||||
|
||||
class SliceFactory : public IFactory {
|
||||
public:
|
||||
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Slice>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
parent_nodes[3],
|
||||
parent_nodes[4]);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateSliceFactory(const std::string& type_name) {
|
||||
return std::make_shared<SliceFactory>(type_name);
|
||||
}
|
||||
|
||||
class ReshapeFactory : public IFactory {
|
||||
public:
|
||||
explicit ReshapeFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Reshape>(parent_nodes[0], parent_nodes[1], false);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
return std::make_shared<ReshapeFactory>(type_name);
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -214,6 +245,12 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
|
||||
|
||||
#undef CREATE_INTERPOLATE_FACTORY
|
||||
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)
|
||||
|
||||
#undef CREATE_SLICE_FACTORY
|
||||
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
|
||||
|
||||
#undef CREATE_RESHAPE_FACTORY
|
||||
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
struct Preprocessing {
|
||||
@ -651,7 +688,7 @@ auto test_forward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@ -685,7 +722,7 @@ auto test_forward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@ -722,6 +759,128 @@ auto test_forward_unsqueeze = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
|
||||
|
||||
auto test_forward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto& out = out_vec[idx];
|
||||
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
|
||||
iota(transpose_order.begin(), transpose_order.end(), 0);
|
||||
reverse(transpose_order.begin(), transpose_order.end());
|
||||
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto transpose = make_shared<Gather>(data, out, axis);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
|
||||
|
||||
auto test_forward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_squeeze());
|
||||
|
||||
auto test_forward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
return new_out_vec;
|
||||
};
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
|
||||
out_vec[1].get_shape(),
|
||||
std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_unsqueeze());
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -1003,7 +1162,7 @@ auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@ -1036,7 +1195,7 @@ auto test_backward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@ -1066,6 +1225,126 @@ auto test_backward_unsqueeze = []() {
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_unsqueeze());
|
||||
|
||||
auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto& out = out_vec[idx];
|
||||
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
|
||||
iota(transpose_order.begin(), transpose_order.end(), 0);
|
||||
reverse(transpose_order.begin(), transpose_order.end());
|
||||
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto transpose = make_shared<Gather>(data, out, axis);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
|
||||
|
||||
auto test_backward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
new_out_vec[1] = out_vec[1];
|
||||
return new_out_vec;
|
||||
};
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_squeeze());
|
||||
|
||||
auto test_backward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
|
||||
out_vec[1].get_shape(),
|
||||
std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_unsqueeze());
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
} // namespace transpose_sinking
|
||||
|
@ -281,7 +281,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
{
|
||||
// perform transpose sinking and reverse infer if the model contains only OpenVINO operations
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
|
||||
manager.register_pass<ov::frontend::tensorflow_lite::pass::TFLQuantizeResolver>();
|
||||
manager.register_pass<ov::frontend::tensorflow_lite::pass::Rfft2dSimplifier>();
|
||||
manager.register_pass<ov::pass::TransposeSinking>();
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.run_passes(function);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user