TSReduction refactoring: move the Unsqueeze/Squeeze transformations to separate files, add limited support for the Reshape op, and extend the tests
Parent: 2bc1334f65
Commit: f1dc3702f1
@ -1,30 +0,0 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API TransposeSinkingSliceForward;
class TRANSFORMATIONS_API TransposeSinkingSliceBackward;

}  // namespace pass
}  // namespace ov

class ov::pass::TransposeSinkingSliceForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TransposeSinkingSliceForward", "0");
    TransposeSinkingSliceForward();
};

class ov::pass::TransposeSinkingSliceBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TransposeSinkingSliceBackward", "0");
    TransposeSinkingSliceBackward();
};
@ -21,7 +21,7 @@ class TRANSFORMATIONS_API TSReductionBackward;

/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
 * @brief TransposeReductionForward transformation sinks Transpose through Reduce, Squeeze operations
 * in the forward direction.
 */
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
@ -32,7 +32,7 @@ public:

/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze, Unsqueeze operations
 * @brief TransposeReductionBackward transformation sinks Transpose through Reduce, Squeeze operations
 * in the backward direction.
 */
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
@ -0,0 +1,32 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSSliceForward;
class TRANSFORMATIONS_API TSSliceBackward;

}  // namespace transpose_sinking
}  // namespace pass
}  // namespace ov

class ov::pass::transpose_sinking::TSSliceForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSliceForward", "0");
    TSSliceForward();
};

class ov::pass::transpose_sinking::TSSliceBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSliceBackward", "0");
    TSSliceBackward();
};
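// Illustrative sketch, not part of this commit: based on the matchers in ts_slice.cpp below,
// TSSliceForward rewrites
//   Transpose(order) -> Slice(start, stop, step, axes)
// into
//   Slice(start, stop, step, axes remapped through `order`) -> Transpose(order)
// so the Transpose keeps sinking toward the model outputs; TSSliceBackward performs the
// inverse rewrite. The remapping of the `axes` input corresponds to the Gather inserted in
// the reference models of the tests further below.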
@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSSqueezeForward;
class TRANSFORMATIONS_API TSSqueezeBackward;

}  // namespace transpose_sinking
}  // namespace pass
}  // namespace ov

/**
 * @ingroup ie_transformation_common_api
 * @brief TSSqueezeForward transformation sinks Transpose through Reshape, Squeeze operations
 * in the forward direction.
 */
class ov::pass::transpose_sinking::TSSqueezeForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSqueezeForward", "0");
    TSSqueezeForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSSqueezeBackward transformation sinks Transpose through Reshape, Squeeze operations
 * in the backward direction.
 */
class ov::pass::transpose_sinking::TSSqueezeBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSSqueezeBackward", "0");
    TSSqueezeBackward();
};
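// Illustrative sketch, not part of this commit: a forward-sinking example for TSSqueezeForward,
// worked out by hand on static shapes (the order update follows GetOrderAfterReduction from
// ts_utils.hpp):
//
//   before: Parameter[32,1,2,1] -> Transpose(order={0,2,3,1}) -> Squeeze(axes={2,3})    // out: [32,2]
//   after : Parameter[32,1,2,1] -> Squeeze(axes={3,1})        -> Transpose(order={0,1}) // out: [32,2]
//
// The squeezed axes are mapped back through the Transpose order (order[2]=3, order[3]=1), and
// the remaining order {0,1} is the original order with the squeezed positions removed and the
// surviving values renumbered to stay contiguous.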
@ -0,0 +1,42 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSUnsqueezeForward;
class TRANSFORMATIONS_API TSUnsqueezeBackward;

}  // namespace transpose_sinking
}  // namespace pass
}  // namespace ov

/**
 * @ingroup ie_transformation_common_api
 * @brief TSUnsqueezeForward transformation sinks Transpose through Unsqueeze, Reshape operations
 * in the forward direction.
 */
class ov::pass::transpose_sinking::TSUnsqueezeForward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSUnsqueezeForward", "0");
    TSUnsqueezeForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSUnsqueezeBackward transformation sinks Transpose through Unsqueeze, Reshape operations
 * in the backward direction.
 */
class ov::pass::transpose_sinking::TSUnsqueezeBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSUnsqueezeBackward", "0");
    TSUnsqueezeBackward();
};
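// Illustrative sketch, not part of this commit: a small hand-worked example for
// TSUnsqueezeForward, assuming static shapes:
//
//   before: Parameter[3,2] -> Transpose(order={1,0}) -> Unsqueeze(axes={0})      // out: [1,2,3]
//   after : Parameter[3,2] -> Unsqueeze(axes={0})    -> Transpose(order={0,2,1}) // out: [1,2,3]
//
// The Unsqueeze keeps its original axes input; only the Transpose order is recomputed
// (GetOrderBeforeReduction keeps the new axes as fixed points and shifts the rest).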
@ -98,17 +98,43 @@ void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
bool HasSameOutputTransposeNodes(const ov::Output<ov::Node>&);

/**
 * Removes all direct node consumers that have one output
 * @brief Removes all direct node consumers that have one output
 */
void RemoveSingleOutputConsumers(const std::shared_ptr<ov::Node>&);

/**
 * Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
 * @brief Changes the order of values in @arg input according to @arg transpose_axis_order along @arg axis
 */
ov::Output<ov::Node> ChangeValuesOrder(const ov::Output<ov::Node>& input,
                                       const ov::AxisVector& transpose_axis_order,
                                       const std::shared_ptr<ov::opset10::Constant>& axis);

/**
 * @brief Returns the updated axes order for the case when the initial axes order has more elements
 * than after TransposeSinking, e.g.:
 *
 * before: Transpose (the initial axes order) -> ReduceMax
 * after : ReduceMax -> Transpose (the updated axes order)
 *
 * before: Unsqueeze -> Transpose (the initial axes order)
 * after : Transpose (the updated axes order) -> Unsqueeze
 */
std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
                                           const std::vector<size_t>& order_values);

/**
 * @brief Returns the updated axes order for the case when the initial axes order has fewer elements
 * than after TransposeSinking, e.g.:
 *
 * before: ReduceMax -> Transpose (the updated axes order)
 * after : Transpose (the initial axes order) -> ReduceMax
 *
 * before: Transpose (the updated axes order) -> Unsqueeze
 * after : Unsqueeze -> Transpose (the initial axes order)
 */
std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
                                            const std::vector<size_t>& order_values);

}  // namespace utils
}  // namespace transpose_sinking
}  // namespace pass
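// Illustrative sketch, not part of this commit: a minimal standalone check of the two helpers
// declared above, assuming the header is reachable under the same path used by the sources in
// this commit; the expected values were worked out by hand from the doc-comment examples.
#include <cassert>
#include <vector>

#include "transformations/transpose_sinking/ts_utils.hpp"

int main() {
    using namespace ov::pass::transpose_sinking::utils;

    // Forward: Transpose{1,2,0} -> ReduceSum(axis=0, keep_dims=false)
    // becomes  ReduceSum(axis=1) -> Transpose{1,0}.
    assert((GetOrderAfterReduction({0}, {1, 2, 0}) == std::vector<size_t>{1, 0}));

    // Backward: ReduceSum(axis=1, keep_dims=false) -> Transpose{1,0}
    // becomes   Transpose{2,1,0} -> ReduceSum(axis=1).
    assert((GetOrderBeforeReduction({1}, {1, 0}) == std::vector<size_t>{2, 1, 0}));
    return 0;
}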
@ -15,6 +15,8 @@
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/utils/utils.hpp"
@ -29,6 +31,8 @@ TSGeneralForward::TSGeneralForward() {
    add_matcher<TSSplitForward>();
    add_matcher<TSDataMovementForward>();
    add_matcher<TSReductionForward>();
    add_matcher<TSSqueezeForward>();
    add_matcher<TSUnsqueezeForward>();
    add_matcher<TSInterpolateForward>();
    add_matcher<TSFuse>();
}
@ -41,6 +45,8 @@ TSGeneralBackward::TSGeneralBackward() {
    add_matcher<TSSplitBackward>();
    add_matcher<TSDataMovementBackward>();
    add_matcher<TSReductionBackward>();
    add_matcher<TSSqueezeBackward>();
    add_matcher<TSUnsqueezeBackward>();
    add_matcher<TSInterpolateBackward>();
    add_matcher<TSFuse>();
}
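// Illustrative sketch, not part of this commit: the new matchers are normally driven through
// TSGeneralForward/TSGeneralBackward above, but they can also be run standalone on a model via
// the public ov::pass::Manager API (class names taken from the headers included above):
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"

void run_new_sinking_passes(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::transpose_sinking::TSSqueezeForward>();
    manager.register_pass<ov::pass::transpose_sinking::TSUnsqueezeForward>();
    manager.run_passes(model);
}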
@ -18,60 +18,15 @@
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() - axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size, 0);
|
||||
std::vector<size_t> values_to_reduce(axes_values);
|
||||
for (size_t i = 0; i < values_to_reduce.size(); ++i) {
|
||||
values_to_reduce[i] = order_values[axes_values[i]];
|
||||
}
|
||||
std::sort(values_to_reduce.begin(), values_to_reduce.end());
|
||||
for (size_t i = 0, j = 0; i < order_values.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() + axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size);
|
||||
|
||||
std::vector<int64_t> cnt_deleted(buffer_size);
|
||||
int64_t cnt = 0;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(cnt_deleted.size()); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
cnt++;
|
||||
}
|
||||
cnt_deleted[i] = i - cnt;
|
||||
}
|
||||
|
||||
for (size_t i = 0, j = 0; i < aligned_order.size(); ++i) {
|
||||
if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
|
||||
aligned_order[i] = i;
|
||||
continue;
|
||||
}
|
||||
|
||||
aligned_order[i] = std::find(cnt_deleted.begin(), cnt_deleted.end(), order_values[j]) - cnt_deleted.begin();
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
|
||||
bool keep_dims = false; // squeeze/unsqueeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
@ -80,32 +35,29 @@ bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
return keep_dims;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
}
|
||||
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
auto reduce_or_squeeze_label = pattern::
|
||||
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
|
||||
{transpose_label, pattern::wrap_type<Constant>()});
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
|
||||
auto rank =
|
||||
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
|
||||
@ -117,38 +69,16 @@ TSReductionForward::TSReductionForward() {
|
||||
}
|
||||
|
||||
if (!keep_dims) {
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->input_value(0).get_partial_shape();
|
||||
|
||||
if (input_pshape.is_static()) {
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (unsqueeze) {
|
||||
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
|
||||
} else {
|
||||
transpose_order_values = get_updated_order_forward(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::shared_ptr<Node> new_reduction;
|
||||
if (!unsqueeze) {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), reduction->input_value(1)});
|
||||
}
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
|
||||
|
||||
replace_node(reduction, new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
@ -158,98 +88,53 @@ TSReductionForward::TSReductionForward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
auto m = std::make_shared<pattern::Matcher>(reduce_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
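// Illustrative sketch, not part of this commit: a hand-worked example of the forward rewrite
// above for a reduction with keep_dims == false, assuming static shapes:
//   before: Parameter[2,3,4] -> Transpose(order={1,2,0}) -> ReduceSum(axes={0})    // out: [4,2]
//   after : Parameter[2,3,4] -> ReduceSum(axes={1})      -> Transpose(order={1,0}) // out: [4,2]
// The reduced axis is mapped back through the Transpose order (order[0] = 1) and the remaining
// order comes from GetOrderAfterReduction({0}, {1,2,0}) = {1,0}.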
|
||||
|
||||
TSReductionBackward::TSReductionBackward() {
|
||||
MATCHER_SCOPE(TSReductionBackward);
|
||||
|
||||
auto reduce_or_squeeze_label = pattern::
|
||||
wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, Squeeze, Unsqueeze>(
|
||||
{pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({reduce_or_squeeze_label, pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto unsqueeze = std::dynamic_pointer_cast<Unsqueeze>(reduction);
|
||||
auto rank =
|
||||
unsqueeze ? reduction->get_output_partial_shape(0).rank() : reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto old_transpose_order_values = transpose_order_values;
|
||||
std::vector<size_t> new_values;
|
||||
if (unsqueeze) {
|
||||
if (non_negative_axes.size() == transpose_order_values.size()) {
|
||||
// input is a scalar, we unsqueeze all dims
|
||||
// it's enough to eliminate such Transpose
|
||||
transpose->output(0).replace(reduction);
|
||||
return true;
|
||||
}
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
|
||||
if (it != old_transpose_order_values.end()) {
|
||||
new_values.push_back(it - old_transpose_order_values.begin());
|
||||
}
|
||||
}
|
||||
}
|
||||
bool squeeze_all_dims = false;
|
||||
if (!keep_dims) {
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = reduction->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_static()) {
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
squeeze_all_dims = true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (unsqueeze) {
|
||||
transpose_order_values = get_updated_order_forward(new_values, transpose_order_values);
|
||||
} else {
|
||||
transpose_order_values = get_updated_order_backward(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
if (!unsqueeze) {
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
}
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
std::shared_ptr<Node> new_transpose, new_reduction;
|
||||
if (squeeze_all_dims) {
|
||||
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
new_reduction = reduction->clone_with_new_inputs({new_transpose, reduction->input_value(1)});
|
||||
} else {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
replace_node(transpose, new_reduction);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
@ -2,25 +2,26 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/transpose_sinking_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
ov::pass::TransposeSinkingSliceForward::TransposeSinkingSliceForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSliceForward);
|
||||
TSSliceForward::TSSliceForward() {
|
||||
MATCHER_SCOPE(TSSliceForward);
|
||||
auto const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
|
||||
auto main_node_label = wrap_type<Slice>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
@ -56,7 +57,7 @@ ov::pass::TransposeSinkingSliceForward::TransposeSinkingSliceForward() {
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@ -65,8 +66,8 @@ ov::pass::TransposeSinkingSliceForward::TransposeSinkingSliceForward() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingSliceBackward::TransposeSinkingSliceBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingSliceBackward);
|
||||
TSSliceBackward::TSSliceBackward() {
|
||||
MATCHER_SCOPE(TSSliceBackward);
|
||||
|
||||
auto main_node_label = wrap_type<Slice>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
@ -0,0 +1,228 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// only the case where the Reshape is equivalent to a Squeeze is supported
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() < input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < new_shape.size(); j++) {
|
||||
if (new_shape[i] == input_shape[j]) {
|
||||
i++;
|
||||
} else if (new_shape[i] != input_shape[j] && input_shape[j] != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(j);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Squeeze
|
||||
// todo: move these checks into the pattern
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
|
||||
std::vector<size_t> to_shape;
|
||||
std::sort(squeeze_axes.begin(), squeeze_axes.end());
|
||||
const auto& input_shape = input_node->input(0).get_shape();  // the input shape is expected to be static here
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
|
||||
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
|
||||
++j;
|
||||
continue;
|
||||
}
|
||||
to_shape.push_back(input_shape[i]);
|
||||
}
|
||||
return to_shape;
|
||||
}
|
||||
|
||||
}  // namespace
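// Illustrative sketch, not part of this commit: a hand-worked trace of the two helpers above,
// assuming a static input shape.
//   shape_to_squeeze_axes:  input shape {6, 1, 5, 1, 4}, Reshape target {6, 5, 4}
//                           -> result_axes = {1, 3} (the dropped unit dimensions),
//                           i.e. this Reshape behaves exactly like Squeeze(axes={1, 3}).
//   squeeze_axes_to_shape:  input shape {6, 1, 5, 1, 4}, axes {1, 3}
//                           -> target shape {6, 5, 4}; this is the inverse conversion used when
//                           the matched node is a Reshape and its shape input has to be rebuilt.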
|
||||
|
||||
TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
// if the 2nd input of Squeeze is empty, all dimensions equal to 1 are removed.
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i].get_length() == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(transpose_order_values[axis]);
|
||||
}
|
||||
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
new_values = squeeze_axes_to_shape(transpose, new_values);
|
||||
}
|
||||
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
auto new_squeeze = squeeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_squeeze, new_transpose_order});
|
||||
|
||||
replace_node(squeeze, new_transpose);
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSSqueezeBackward::TSSqueezeBackward() {
|
||||
MATCHER_SCOPE(TSSqueezeBackward);
|
||||
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
bool squeeze_all_dims = false;
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = squeeze->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < input_pshape.size(); ++i) {
|
||||
if (input_pshape[i] == 1) {
|
||||
non_negative_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
squeeze_all_dims = true;
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
new_values = squeeze_axes_to_shape(squeeze, new_values);
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> new_squeeze;
|
||||
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
|
||||
if (squeeze_all_dims) {
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
|
||||
} else {
|
||||
auto new_const = std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
|
||||
replace_node(transpose, new_squeeze);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,204 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// only the case where the Reshape is equivalent to an Unsqueeze is supported
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() > input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
|
||||
if (input_shape[i] == new_shape[j]) {
|
||||
i++;
|
||||
} else if (input_shape[i] != new_shape[j] && new_shape[j] != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(j);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Unsqueeze
|
||||
// todo: move these checks into the pattern
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node->input(0).get_shape();  // the input shape is expected to be static here
|
||||
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
|
||||
for (size_t i = 0, j = 0; i < to_shape.size(); ++i) {
|
||||
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
|
||||
to_shape[i] = 1;
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
to_shape[i] = shape_to_add.top();
|
||||
shape_to_add.pop();
|
||||
}
|
||||
return to_shape;
|
||||
}
|
||||
} // namespace
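// Illustrative sketch, not part of this commit: a hand-worked trace of shape_to_unsqueeze_axes,
// assuming a static input shape.
//   input shape {4, 5, 6}, Reshape target {4, 1, 5, 1, 6} -> result_axes = {1, 3},
//   i.e. this Reshape behaves exactly like Unsqueeze(axes={1, 3}).
// unsqueeze_axes_to_shape rebuilds a target shape from such axes; note that it consumes the
// remaining dimensions from the back of the input shape (std::stack over the vector), which
// lines up with the transposed shapes expected by the reference models in the tests below.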
|
||||
|
||||
TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
||||
/* std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(ts_order_values[axis]);
|
||||
}*/
|
||||
|
||||
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{ts_order_values.size()},
|
||||
ts_order_values);
|
||||
|
||||
/*if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
|
||||
}*/
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
|
||||
|
||||
replace_node(unsqueeze, new_transpose);
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(unsqueeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeBackward);
|
||||
|
||||
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes)
|
||||
return false;
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
auto old_transpose_order_values = transpose_order_values;
|
||||
std::vector<size_t> new_values;
|
||||
|
||||
if (non_negative_axes.size() == transpose_order_values.size()) {
|
||||
// input is a scalar, we unsqueeze all dims
|
||||
// it's enough to eliminate such Transpose
|
||||
transpose->output(0).replace(unsqueeze);
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
auto it = std::find(old_transpose_order_values.begin(), old_transpose_order_values.end(), axis);
|
||||
if (it != old_transpose_order_values.end()) {
|
||||
new_values.push_back(it - old_transpose_order_values.begin());
|
||||
}
|
||||
}
|
||||
|
||||
transpose_order_values = GetOrderAfterReduction(new_values, transpose_order_values);
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
replace_node(transpose, new_unsqueeze);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -379,6 +379,53 @@ void RemoveSingleOutputConsumers(const NodePtr& node) {
    }
}

std::vector<size_t> GetOrderAfterReduction(const std::vector<size_t>& axes_values,
                                           const std::vector<size_t>& order_values) {
    size_t buffer_size = order_values.size() - axes_values.size();
    std::vector<size_t> aligned_order(buffer_size, 0);
    std::vector<size_t> values_to_reduce(axes_values);
    for (size_t i = 0; i < values_to_reduce.size(); ++i) {
        values_to_reduce[i] = order_values[axes_values[i]];
    }
    std::sort(values_to_reduce.begin(), values_to_reduce.end());
    for (size_t i = 0, j = 0; i < order_values.size(); ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            continue;
        }

        auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
        aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
        ++j;
    }
    return aligned_order;
}

std::vector<size_t> GetOrderBeforeReduction(const std::vector<size_t>& axes_values,
                                            const std::vector<size_t>& order_values) {
    size_t buffer_size = order_values.size() + axes_values.size();
    std::vector<size_t> aligned_order(buffer_size);

    std::vector<int64_t> cnt_deleted(buffer_size);
    int64_t cnt = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(cnt_deleted.size()); ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            cnt++;
        }
        cnt_deleted[i] = i - cnt;
    }

    for (size_t i = 0, j = 0; i < aligned_order.size(); ++i) {
        if (std::find(axes_values.begin(), axes_values.end(), i) != axes_values.end()) {
            aligned_order[i] = i;
            continue;
        }

        aligned_order[i] = std::find(cnt_deleted.begin(), cnt_deleted.end(), order_values[j]) - cnt_deleted.begin();
        ++j;
    }
    return aligned_order;
}

}  // namespace utils
}  // namespace transpose_sinking
}  // namespace pass
@ -11,8 +11,11 @@
|
||||
#include "transformations/transpose_sinking/ts_data_movement.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -195,6 +198,18 @@ public:
|
||||
FactoryPtr CreateSliceFactory(const std::string& type_name) {
|
||||
return std::make_shared<SliceFactory>(type_name);
|
||||
}
|
||||
|
||||
class ReshapeFactory : public IFactory {
|
||||
public:
|
||||
explicit ReshapeFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Reshape>(parent_nodes[0], parent_nodes[1], false);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
return std::make_shared<ReshapeFactory>(type_name);
|
||||
}
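// Illustrative sketch, not part of this commit: the factory above is what lets the shared test
// template instantiate a Reshape as the "main op"; a test case consumes it the same way the
// existing Slice/Squeeze factories are consumed:
//   test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
// which, for parent nodes {data, shape_const}, ends up calling
//   std::make_shared<Reshape>(data, shape_const, false);  // special_zero = false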
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -229,6 +244,9 @@ FactoryPtr CreateSliceFactory(const std::string& type_name) {
|
||||
|
||||
#undef CREATE_SLICE_FACTORY
|
||||
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
|
||||
|
||||
#undef CREATE_RESHAPE_FACTORY
|
||||
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
struct Preprocessing {
|
||||
@ -666,7 +684,7 @@ auto test_forward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@ -700,7 +718,7 @@ auto test_forward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@ -741,7 +759,7 @@ auto test_forward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceForward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
@ -754,7 +772,7 @@ auto test_forward_slice = []() {
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@ -774,12 +792,78 @@ auto test_forward_slice = []() {
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
|
||||
|
||||
auto test_forward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_squeeze());
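// Illustrative sketch, not part of this commit: a shape trace for the test case above, assuming
// set_transpose_for inserts a reverse-order Transpose (which is what makes the reference
// constant {6, 5, 4} consistent):
//   tested model   : Parameter[6,1,5,1,4] -> Transpose -> Reshape{4,5,6}           // out: [4,5,6]
//   reference model: Parameter[6,1,5,1,4] -> Reshape{6,5,4} -> Transpose           // out: [4,5,6]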
|
||||
|
||||
auto test_forward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_unsqueeze());
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -1062,7 +1146,7 @@ auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
@ -1095,7 +1179,7 @@ auto test_backward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSReductionBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
@ -1128,7 +1212,7 @@ auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSliceBackward);
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
@ -1141,7 +1225,7 @@ auto test_backward_slice = []() {
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SLICE_FACTORY(Slice)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
@ -1160,11 +1244,77 @@ auto test_backward_slice = []() {
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {4}}};
|
||||
test_case.model_ref.main_op = {CREATE_SLICE_FACTORY(SliceFactory)};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
test_case.model_ref.model_template = create_model;
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
|
||||
|
||||
auto test_backward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{5}, std::vector<int64_t>{4, 1, 2, 3, 0});
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
new_out_vec[1] = out_vec[1];
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_squeeze());
|
||||
|
||||
auto test_backward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_unsqueeze());
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|