Delete redundant node copies in TSSqueeze, TSUnsqueeze and TSReduction transformations (#16753)

* Delete redundant node copies in TSSqueeze, TSUnsqueeze and TSReduction transformations, add new tests

* codestyle

* codestyle
Ivan Tikhonov 2023-04-12 11:30:48 +04:00 committed by GitHub
parent 4bb9222c6e
commit 132dceb146
14 changed files with 538 additions and 336 deletions

View File

@ -21,7 +21,7 @@ class TRANSFORMATIONS_API TSReductionBackward;
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
* @brief TSReductionForward transformation sinks Transpose through Reduce operations
* in the forward direction.
*/
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
@ -32,7 +32,7 @@ public:
/**
* @ingroup ie_transformation_common_api
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
* @brief TSReductionBackward transformation sinks Transpose through Reduce operations
* in the backward direction.
*/
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
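Editor's illustration of the rewrite these passes perform (a minimal sketch, not part of the commit; the ts_reduction.hpp include path is assumed by analogy with the other transpose_sinking headers in this PR): a Transpose in front of a Reduce op is expected to move behind it, with the reduction axes remapped through the transpose order.

#include <memory>
#include <vector>
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"  // assumed path

int main() {
    using namespace ov::opset10;
    // Parameter {1, 2, 3} -> Transpose(order 0, 2, 1) -> ReduceSum(axis 1, keep_dims)
    auto param = std::make_shared<Parameter>(ov::element::f32, ov::Shape{1, 2, 3});
    auto order = Constant::create(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{0, 2, 1});
    auto transpose = std::make_shared<Transpose>(param, order);
    auto axes = Constant::create(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
    auto reduce = std::make_shared<ReduceSum>(transpose, axes, true);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{reduce}, ov::ParameterVector{param});

    // After the forward pass the Transpose should follow ReduceSum, and the
    // reduced axis should become 2 (axis 1 seen through the 0, 2, 1 order).
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::transpose_sinking::TSReductionForward>();
    manager.run_passes(model);
    return 0;
}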

View File

@ -64,7 +64,9 @@ namespace sink_forward {
* @brief Inserts reversed transposes on @arg main_node inputs. Removes the input transpose specified in @arg
* transpose_input_info
*/
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node,
const TransposeInputsInfo& transpose_input_info,
std::vector<size_t> input_indexes = {});
/**
* @brief Removes @arg input node
@ -86,7 +88,7 @@ namespace sink_backward {
*/
ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_node,
const std::shared_ptr<ov::opset10::Constant>& transpose_const,
std::vector<int> input_indexes = {});
std::vector<size_t> input_indexes = {});
} // namespace sink_backward
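Similarly for the backward helper, a hedged fragment based on the call sites added later in this diff:

// Insert the new Transpose only before input 0 of main_node; an empty
// input_indexes (the default) inserts it before every input.
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
    register_new_node(new_node);
}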
void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);

View File

@ -13,6 +13,7 @@
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
@ -24,9 +25,9 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
bool get_keep_dims(const std::shared_ptr<Node>& main_node) {
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(main_node);
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(main_node);
bool keep_dims = false; // squeeze/unsqueeze always reduces number of output dimensions
if (logical_reduce)
@ -47,24 +48,22 @@ TSReductionForward::TSReductionForward() {
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
if (transformation_callback(reduction)) {
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
auto main_node = pattern_to_output.at(reduce_label);
if (!transpose || transformation_callback(main_node)) {
return false;
}
auto keep_dims = get_keep_dims(reduction);
auto keep_dims = get_keep_dims(main_node);
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto rank = reduction->get_input_partial_shape(0).rank();
auto rank = main_node->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
normalize_axes(main_node->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
auto transpose_order_values = transpose_order->cast_vector<size_t>();
@ -82,16 +81,21 @@ TSReductionForward::TSReductionForward() {
{transpose_order_values.size()},
transpose_order_values);
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
auto new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
auto new_const = Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values);
main_node->input(1).replace_source_output(new_const);
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
// deletes the Transpose from input 0
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}
replace_node(reduction, new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
UpdateForwardSinkingAbility(new_transpose);
register_new_node(new_transpose);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
copy_runtime_info(reduction_axes, new_const);
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
@ -105,28 +109,32 @@ TSReductionBackward::TSReductionBackward() {
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{any_input(), wrap_type<Constant>()},
HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
auto transpose_label =
wrap_type<Transpose>({reduce_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto reduction = pattern_to_output.at(reduce_label);
if (transformation_callback(reduction)) {
auto main_node = pattern_to_output.at(reduce_label);
if (transformation_callback(main_node)) {
return false;
}
auto keep_dims = get_keep_dims(reduction);
auto keep_dims = get_keep_dims(main_node);
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
auto reduction_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
if (!transpose_order || !reduction_axes)
return false;
auto rank = reduction->get_input_partial_shape(0).rank();
auto rank = main_node->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
normalize_axes(main_node->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
auto transpose_order_values = transpose_order->cast_vector<size_t>();
if (!keep_dims) {
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
@ -141,16 +149,15 @@ TSReductionBackward::TSReductionBackward() {
new_values.push_back(reversed_order_values[axis]);
}
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
replace_node(transpose, new_reduction);
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
UpdateForwardSinkingAbility(new_transpose);
new_reduction->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(reduction->get_friendly_name());
register_new_node(new_transpose);
auto new_const = Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values);
main_node->input(1).replace_source_output(new_const);
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
register_new_node(new_node);
}
main_node->validate_and_infer_types();
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
copy_runtime_info(reduction_axes, new_const);
return true;
};

View File

@ -10,6 +10,7 @@
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
@ -102,38 +103,52 @@ TSSqueezeForward::TSSqueezeForward() {
MATCHER_SCOPE(TSSqueezeForward);
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
auto squeeze_with_1_input = wrap_type<Squeeze>({transpose_label});
auto squeeze_label = wrap_type<Squeeze, Reshape>({transpose_label, wrap_type<Constant>()});
auto pattern = std::make_shared<pattern::op::Or>(OutputVector{squeeze_with_1_input, squeeze_label});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto squeeze = pattern_to_output.at(squeeze_label);
if (transformation_callback(squeeze)) {
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
std::shared_ptr<Node> main_node;
if (pattern_to_output.count(squeeze_label)) {
main_node = pattern_to_output.at(squeeze_label);
} else {
main_node = pattern_to_output.at(squeeze_with_1_input);
}
if (!transpose || transformation_callback(main_node)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
if (!transpose_order || !squeeze_axes) {
if (!transpose_order) {
return false;
}
std::vector<size_t> non_negative_axes;
if (as_type_ptr<Reshape>(squeeze)) {
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
if (!success) {
std::shared_ptr<Constant> squeeze_axes;
if (main_node->get_input_size() > 1) {
squeeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
if (!squeeze_axes) {
return false;
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
if (as_type_ptr<Reshape>(main_node)) {
auto success = shape_to_squeeze_axes(main_node, squeeze_axes, non_negative_axes);
if (!success) {
return false;
}
} else {
auto rank = main_node->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(main_node->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
}
}
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
// if 2nd input to main_node is empty then all '1' dims will be deleted.
if (non_negative_axes.empty()) {
auto input_pshape = transpose->output(0).get_partial_shape();
if (input_pshape.is_dynamic()) {
@ -158,7 +173,7 @@ TSSqueezeForward::TSSqueezeForward() {
{transpose_order_values.size()},
transpose_order_values);
if (as_type_ptr<Reshape>(squeeze)) {
if (as_type_ptr<Reshape>(main_node)) {
std::vector<size_t> to_shape;
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
if (!success) {
@ -167,30 +182,39 @@ TSSqueezeForward::TSSqueezeForward() {
new_values = to_shape;
}
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
auto new_squeeze = squeeze->clone_with_new_inputs({transpose->input_value(0), new_const});
auto new_transpose = transpose->clone_with_new_inputs({new_squeeze, new_transpose_order});
if (squeeze_axes) {
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
main_node->input(1).replace_source_output(new_const);
copy_runtime_info(squeeze_axes, new_const);
}
replace_node(squeeze, new_transpose);
new_squeeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(squeeze->get_friendly_name());
UpdateForwardSinkingAbility(new_transpose);
register_new_node(new_transpose);
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
// deletes the Transpose from input 0
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
auto m = std::make_shared<pattern::Matcher>(squeeze_label, matcher_name);
auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
register_matcher(m, matcher_pass_callback);
}
TSSqueezeBackward::TSSqueezeBackward() {
MATCHER_SCOPE(TSSqueezeBackward);
auto squeeze_with_1_input = wrap_type<Squeeze>({any_input()}, HasSameOutputTransposeNodes);
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto pattern = std::make_shared<pattern::op::Or>(OutputVector{squeeze_with_1_input, squeeze_label});
auto transpose_label =
wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
wrap_type<Transpose>({pattern, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
@ -198,34 +222,47 @@ TSSqueezeBackward::TSSqueezeBackward() {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto squeeze = pattern_to_output.at(squeeze_label);
if (transformation_callback(squeeze)) {
std::shared_ptr<Node> main_node;
if (pattern_to_output.count(squeeze_label)) {
main_node = pattern_to_output.at(squeeze_label);
} else {
main_node = pattern_to_output.at(squeeze_with_1_input);
}
if (transformation_callback(main_node)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
if (!transpose_order || !squeeze_axes) {
if (!transpose_order) {
return false;
}
std::vector<size_t> non_negative_axes;
if (as_type_ptr<Reshape>(squeeze)) {
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
if (!success) {
std::shared_ptr<Constant> squeeze_axes;
if (main_node->get_input_size() > 1) {
squeeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
if (!squeeze_axes) {
return false;
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
if (as_type_ptr<Reshape>(main_node)) {
auto success = shape_to_squeeze_axes(main_node, squeeze_axes, non_negative_axes);
if (!success) {
return false;
}
} else {
auto rank = main_node->get_input_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(main_node->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
}
}
bool squeeze_all_dims = false;
if (non_negative_axes.empty()) {
auto input_pshape = squeeze->input_value(0).get_partial_shape();
auto input_pshape = main_node->input_value(0).get_partial_shape();
if (input_pshape.is_dynamic()) {
return false;
}
@ -249,8 +286,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(squeeze)) {
auto new_transpose = transpose->clone_with_new_inputs({main_node->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(main_node)) {
std::vector<size_t> to_shape;
auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
if (!success) {
@ -260,19 +297,19 @@ TSSqueezeBackward::TSSqueezeBackward() {
}
std::shared_ptr<Node> new_squeeze;
if (squeeze_all_dims) {
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
} else {
auto new_const =
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
if (!squeeze_all_dims) {
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
main_node->input(1).replace_source_output(new_const);
copy_runtime_info(squeeze_axes, new_const);
}
replace_node(transpose, new_squeeze);
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
new_squeeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(squeeze->get_friendly_name());
register_new_node(new_transpose);
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
register_new_node(new_node);
}
main_node->validate_and_infer_types();
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
return true;
};

View File

@ -108,29 +108,29 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
if (transformation_callback(unsqueeze)) {
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
auto main_node = pattern_to_output.at(unsqueeze_label);
if (!transpose || transformation_callback(main_node)) {
return false;
}
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
auto unsqueeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
if (!transpose_order || !unsqueeze_axes) {
return false;
}
std::vector<size_t> non_negative_axes;
if (as_type_ptr<Reshape>(unsqueeze)) {
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
if (as_type_ptr<Reshape>(main_node)) {
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
if (!success) {
return false;
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
auto rank = main_node->get_output_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
}
auto ts_order_values = transpose_order->cast_vector<size_t>();
@ -140,25 +140,29 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
std::shared_ptr<Node> new_unsqueeze;
if (as_type_ptr<Reshape>(unsqueeze)) {
if (as_type_ptr<Reshape>(main_node)) {
std::vector<size_t> new_values;
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
if (!success) {
return false;
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
main_node->input(1).replace_source_output(new_const);
copy_runtime_info(unsqueeze_axes, new_const);
}
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
replace_node(unsqueeze, new_transpose);
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
UpdateForwardSinkingAbility(new_transpose);
register_new_node(new_transpose);
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
// deletes the Transpose from input 0
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
if (!success) {
return false;
}
main_node->validate_and_infer_types();
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
register_new_node(new_node);
UpdateForwardSinkingAbility(new_node);
}
return true;
};
@ -181,27 +185,27 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
const auto& pattern_to_output = m.get_pattern_map();
auto transpose = pattern_to_output.at(transpose_label);
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
if (transformation_callback(unsqueeze)) {
auto main_node = pattern_to_output.at(unsqueeze_label);
if (transformation_callback(main_node)) {
return false;
}
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(main_node->get_input_node_shared_ptr(1));
if (!transpose_order || !unsqueeze_axes)
return false;
std::vector<size_t> non_negative_axes;
if (as_type_ptr<Reshape>(unsqueeze)) {
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
if (as_type_ptr<Reshape>(main_node)) {
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
if (!success) {
return false;
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
auto rank = main_node->get_output_partial_shape(0).rank();
OPENVINO_SUPPRESS_DEPRECATED_START
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
OPENVINO_SUPPRESS_DEPRECATED_END
}
@ -210,9 +214,9 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
std::vector<size_t> new_values;
if (non_negative_axes.size() == transpose_order_values.size()) {
// input is a scalar, we unsqueeze all dims
// input is a scalar, all dims are unsqueezed
// it's enough to eliminate such Transpose
transpose->output(0).replace(unsqueeze);
transpose->output(0).replace(main_node);
return true;
}
@ -228,23 +232,24 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
Shape{transpose_order_values.size()},
transpose_order_values);
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(unsqueeze)) {
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
register_new_node(new_node);
}
if (as_type_ptr<Reshape>(main_node)) {
std::vector<size_t> to_shape;
auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
auto success = unsqueeze_axes_to_shape(main_node->input_value(0), new_values, to_shape);
if (!success) {
return false;
}
new_values = to_shape;
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
main_node->input(1).replace_source_output(new_const);
replace_node(transpose, new_unsqueeze);
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
register_new_node(new_transpose);
main_node->validate_and_infer_types();
RemoveSingleOutputConsumers(main_node);
SwapNames(transpose, main_node);
copy_runtime_info(unsqueeze_axes, new_const);
return true;
};

View File

@ -150,7 +150,13 @@ AxisVector AlignTransposeOrder(const Output<Node>& output, const TransposeInputs
return new_transpose_order;
}
bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
bool UpdateInputTransposes(const NodePtr& main_node,
const TransposeInputsInfo& transpose_input_info,
std::vector<size_t> input_indexes) {
if (input_indexes.empty()) {
input_indexes.resize(main_node->get_input_size());
std::iota(input_indexes.begin(), input_indexes.end(), 0);
}
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
return false;
@ -161,7 +167,7 @@ bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo&
const size_t transpose_input_index = transpose_input_info.input_idx;
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
for (const auto& i : input_indexes) {
auto input_node = main_node->input_value(i);
if (i == transpose_input_index) {
auto transpose_parent = input_node.get_node()->input_value(0);
@ -230,7 +236,7 @@ namespace sink_backward {
NodeVector InsertTransposeBeforeNode(const NodePtr& main_node,
const std::shared_ptr<Constant>& transpose_const,
std::vector<int> input_indexes) {
std::vector<size_t> input_indexes) {
if (input_indexes.empty()) {
input_indexes.resize(main_node->get_input_size());
std::iota(input_indexes.begin(), input_indexes.end(), 0);

View File

@ -11,11 +11,13 @@
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_case.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing;
using namespace transpose_sinking::testing::utils;
namespace {

View File

@ -16,12 +16,14 @@
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "ts_test_case.hpp"
#include "ts_test_utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing;
using namespace transpose_sinking::testing::utils;
namespace transpose_sinking {
@ -214,6 +216,7 @@ public:
FactoryPtr CreateReshapeFactory(const std::string& type_name) {
return std::make_shared<ReshapeFactory>(type_name);
}
// ----------------------------------------------------------------------------
#undef CREATE_UNARY_FACTORY
@ -251,72 +254,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
#undef CREATE_RESHAPE_FACTORY
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
// ----------------------------------------------------------------------------
struct Preprocessing {
vector<function<OutputVector(vector<size_t>, OutputVector)>> preprocessing;
vector<vector<size_t>> indices;
OutputVector apply(const OutputVector& inputs) const {
OutputVector new_inputs = inputs;
for (size_t i = 0; i < preprocessing.size(); ++i) {
new_inputs = preprocessing[i](indices[i], new_inputs);
}
return new_inputs;
}
};
struct TestCase;
struct ModelDescription;
using TestParams = tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
using CreateGraphF =
function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
// Describes a model to test.
// Expects to be used in such a scenario:
// 1st Preprocessing inserts Transpose/Gather to the inputs
// of the main node.
// Factory contains the rules how to create the main testing node.
// 2nd Preprocessing inserts Transpose/Gather to the outputs
// of the main node.
// model_template is a function which uses the arguments above.
// Examples of the scenarios:
// ModelDescription model: Param -> (Transpose inserted by 1st Preprocessing) -> Abs (main_node) -> Result
// ModelDescription reference: Param -> Abs (main_node) -> (Transpose inserted by 2nd Preprocessing) -> Result
struct ModelDescription {
Preprocessing preprocess_inputs_to_main;
// @parameterized with multiple values
vector<FactoryPtr> main_op;
Preprocessing preprocess_outputs_of_main;
CreateGraphF model_template;
};
struct TestCase {
OutputVector inputs_to_main;
// @parameterized with multiple values
vector<size_t> num_main_ops;
ModelDescription model;
ModelDescription model_ref;
PassFactoryPtr transformation;
};
class TransposeSinkingTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
static string get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
size_t num_main_ops_idx;
size_t main_op_idx;
TestCase test_case;
tie(num_main_ops_idx, main_op_idx, test_case) = obj.param;
ostringstream test_name;
test_name << "Factory=" << test_case.model.main_op[main_op_idx]->getTypeName() << "/";
test_name << "NumOps=" << test_case.num_main_ops[num_main_ops_idx] << "/";
test_name << "Transformation=" << test_case.transformation->getTypeName() << "/";
return test_name.str();
}
};
vector<FactoryPtr> unary_factories = {
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
@ -355,7 +295,16 @@ std::vector<FactoryPtr> reduction_factories = {
CREATE_REDUCTION_FACTORY(ReduceL2),
};
TEST_P(TransposeSinkingTestFixture, CompareFunctions) {
auto wrapper = [](const TestCase& test_case) {
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
"The number of main op (testing op) creators has to be the same for the testing model and for "
"the reference model.");
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
::testing::Range<size_t>(0, test_case.model.main_op.size()),
::testing::Values(test_case));
};
TEST_P(TSTestFixture, CompareFunctions) {
size_t num_main_ops_idx;
size_t main_op_idx;
TestCase test_case;
@ -387,15 +336,6 @@ shared_ptr<ov::Model> create_model(size_t main_node_idx,
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
}
auto wrapper = [](const TestCase& test_case) {
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
"The number of main op (testing op) creator have to be the same for the testing model and for"
"the reference model.");
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
::testing::Range<size_t>(0, test_case.model.main_op.size()),
::testing::Values(test_case));
};
auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<size_t>& num_main_ops) {
TestCase test_case;
@ -420,10 +360,10 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
TransposeSinkingTestFixture,
TSTestFixture,
test_forward_unary(unary_factories, {1, 10}));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward,
TransposeSinkingTestFixture,
TSTestFixture,
test_forward_unary(logical_unary_factories, {1}));
auto test_forward_binary = []() {
@ -451,7 +391,7 @@ auto test_forward_binary = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TransposeSinkingTestFixture, test_forward_binary());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());
auto test_forward_concat = []() {
TestCase test_case;
@ -479,7 +419,7 @@ auto test_forward_concat = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TransposeSinkingTestFixture, test_forward_concat());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TSTestFixture, test_forward_concat());
auto test_forward_split = []() {
TestCase test_case;
@ -513,7 +453,7 @@ auto test_forward_split = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TransposeSinkingTestFixture, test_forward_split());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TSTestFixture, test_forward_split());
auto test_forward_pad = []() {
TestCase test_case;
@ -541,7 +481,7 @@ auto test_forward_pad = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TransposeSinkingTestFixture, test_forward_pad());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad());
auto test_forward_batch_to_space = []() {
TestCase test_case;
@ -570,9 +510,7 @@ auto test_forward_batch_to_space = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceForward,
TransposeSinkingTestFixture,
test_forward_batch_to_space());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceForward, TSTestFixture, test_forward_batch_to_space());
auto test_forward_space_to_batch = []() {
TestCase test_case;
@ -601,9 +539,7 @@ auto test_forward_space_to_batch = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchForward,
TransposeSinkingTestFixture,
test_forward_space_to_batch());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchForward, TSTestFixture, test_forward_space_to_batch());
auto test_forward_reduction = []() {
TestCase test_case;
@ -637,7 +573,7 @@ auto test_forward_reduction = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TransposeSinkingTestFixture, test_forward_reduction());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TSTestFixture, test_forward_reduction());
auto test_forward_interpolate = []() {
TestCase test_case;
@ -680,9 +616,7 @@ auto test_forward_interpolate = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward,
TransposeSinkingTestFixture,
test_forward_interpolate());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TSTestFixture, test_forward_interpolate());
auto test_forward_squeeze = []() {
TestCase test_case;
@ -716,7 +650,7 @@ auto test_forward_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeForward, TransposeSinkingTestFixture, test_forward_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeForward, TSTestFixture, test_forward_squeeze());
auto test_forward_unsqueeze = []() {
TestCase test_case;
@ -757,7 +691,7 @@ auto test_forward_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TSTestFixture, test_forward_unsqueeze());
auto test_forward_slice = []() {
TestCase test_case;
@ -801,7 +735,7 @@ auto test_forward_slice = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TSTestFixture, test_forward_slice());
auto test_forward_reshape_squeeze = []() {
TestCase test_case;
@ -835,9 +769,7 @@ auto test_forward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TSTestFixture, test_forward_reshape_squeeze());
auto test_forward_reshape_unsqueeze = []() {
TestCase test_case;
@ -879,8 +811,9 @@ auto test_forward_reshape_unsqueeze = []() {
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
TransposeSinkingTestFixture,
TSTestFixture,
test_forward_reshape_unsqueeze());
// ------------------ BACKWARD --------------------
auto test_backward_unary = []() {
@ -906,7 +839,7 @@ auto test_backward_unary = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryBackward, TransposeSinkingTestFixture, test_backward_unary());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryBackward, TSTestFixture, test_backward_unary());
auto test_backward_binary = []() {
TestCase test_case;
@ -932,7 +865,7 @@ auto test_backward_binary = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TransposeSinkingTestFixture, test_backward_binary());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());
auto test_backward_concat = []() {
TestCase test_case;
@ -959,7 +892,7 @@ auto test_backward_concat = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatBackward, TransposeSinkingTestFixture, test_backward_concat());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatBackward, TSTestFixture, test_backward_concat());
auto test_backward_split = []() {
TestCase test_case;
@ -991,7 +924,7 @@ auto test_backward_split = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitBackward, TransposeSinkingTestFixture, test_backward_split());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitBackward, TSTestFixture, test_backward_split());
auto test_backward_pad = []() {
TestCase test_case;
@ -1018,7 +951,7 @@ auto test_backward_pad = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadBackward, TransposeSinkingTestFixture, test_backward_pad());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadBackward, TSTestFixture, test_backward_pad());
auto test_backward_batch_to_space = []() {
TestCase test_case;
@ -1046,9 +979,7 @@ auto test_backward_batch_to_space = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceBackward,
TransposeSinkingTestFixture,
test_backward_batch_to_space());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceBackward, TSTestFixture, test_backward_batch_to_space());
auto test_backward_space_to_batch = []() {
TestCase test_case;
@ -1075,9 +1006,7 @@ auto test_backward_space_to_batch = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchBackward,
TransposeSinkingTestFixture,
test_backward_space_to_batch());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchBackward, TSTestFixture, test_backward_space_to_batch());
auto test_backward_reduction = []() {
TestCase test_case;
@ -1110,9 +1039,7 @@ auto test_backward_reduction = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward,
TransposeSinkingTestFixture,
test_backward_reduction());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward, TSTestFixture, test_backward_reduction());
auto test_backward_interpolate = []() {
TestCase test_case;
@ -1154,42 +1081,7 @@ auto test_backward_interpolate = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
TransposeSinkingTestFixture,
test_backward_interpolate());
auto test_backward_squeeze = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
};
// Test model description:
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;
// Reference model description:
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
auto order = make_shared<Constant>(element::i32, Shape{4}, std::vector<int64_t>{2, 1, 0, 3});
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
new_out_vec[1] = out_vec[1];
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeBackward, TransposeSinkingTestFixture, test_backward_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TSTestFixture, test_backward_interpolate());
auto test_backward_unsqueeze = []() {
TestCase test_case;
@ -1222,9 +1114,7 @@ auto test_backward_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TSTestFixture, test_backward_unsqueeze());
auto test_backward_slice = []() {
TestCase test_case;
@ -1266,7 +1156,7 @@ auto test_backward_slice = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TSTestFixture, test_backward_slice());
auto test_backward_reshape_squeeze = []() {
TestCase test_case;
@ -1306,9 +1196,7 @@ auto test_backward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TSTestFixture, test_backward_reshape_squeeze());
auto test_backward_reshape_unsqueeze = []() {
TestCase test_case;
@ -1343,8 +1231,9 @@ auto test_backward_reshape_unsqueeze = []() {
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TransposeSinkingTestFixture,
TSTestFixture,
test_backward_reshape_unsqueeze());
} // namespace common
} // namespace testing
} // namespace transpose_sinking

View File

@ -13,11 +13,13 @@
#include "openvino/pass/manager.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "ts_test_case.hpp"
#include "ts_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing;
using namespace transpose_sinking::testing::utils;
namespace {

View File

@ -0,0 +1,160 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "gtest/gtest.h"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_case.hpp"
#include "ts_test_utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::element;
using namespace ov::opset10;
using namespace ov::pass::transpose_sinking;
using namespace transpose_sinking::testing::utils;
namespace transpose_sinking {
namespace testing {
namespace squeeze {
class SqueezeFactory : public IFactory {
public:
explicit SqueezeFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
if (parent_nodes.size() == 2) {
return std::make_shared<Squeeze>(parent_nodes[0], parent_nodes[1]);
} else if (parent_nodes.size() == 1) {
return std::make_shared<Squeeze>(parent_nodes[0]);
}
OPENVINO_ASSERT(false, "Unexpected number of inputs to Squeeze operation.");
}
};
FactoryPtr CreateSqueezeFactory(const std::string& type_name) {
return std::make_shared<SqueezeFactory>(type_name);
}
// ----------------------------------------------------------------------------
#undef CREATE_SQUEEZE_FACTORY
#define CREATE_SQUEEZE_FACTORY(type_name) CreateSqueezeFactory(#type_name)
// ----------------------------------------------------------------------------
shared_ptr<ov::Model> create_model(size_t main_node_idx,
const ModelDescription& model_desc,
size_t num_ops,
const OutputVector& inputs_to_main) {
auto new_inputs = model_desc.preprocess_inputs_to_main.apply(inputs_to_main);
auto main_node = create_main_node(new_inputs, num_ops, model_desc.main_op[main_node_idx]);
auto outputs = model_desc.preprocess_outputs_of_main.apply(main_node->outputs());
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
}
auto wrapper = [](const TestCase& test_case) {
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
"The number of main op (testing op) creators has to be the same for the testing model and for "
"the reference model.");
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
::testing::Range<size_t>(0, test_case.model.main_op.size()),
::testing::Values(test_case));
};
struct SqueezeArguments {
OutputVector inputs_to_main;
Output<Node> new_constant;
};
auto test_forward_squeeze = [](const SqueezeArguments& test_arguments) {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = test_arguments.inputs_to_main;
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
test_case.model.model_template = create_model;
// Reference model description:
auto new_constant = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
if (out_vec.size() > 1) {
new_out_vec[1] = test_arguments.new_constant;
}
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
test_case.model_ref.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
vector<SqueezeArguments> tests_forward_arguments{
{{parameter(f32, {1, 2}), constant<int64_t>(i32, {1}, {1})}, constant<int64_t>(i32, {1}, {0})},
{{parameter(f32, {1, 2, 1}), constant<int64_t>(i32, {2}, {0, 2})}, constant<int64_t>(i32, {2}, {2, 0})},
{{parameter(f32, {1, 1, 2, 1}), constant<int64_t>(i32, {3}, {0, 2, 3})}, constant<int64_t>(i32, {3}, {3, 1, 0})},
{{constant<int64_t>(i32, {1, 2}, {1}), constant<int64_t>(i32, {1}, {1})}, constant<int64_t>(i32, {1}, {0})},
{{parameter(f32, {1, 2})}},
};
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_0, TSTestFixture, test_forward_squeeze(tests_forward_arguments[0]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_1, TSTestFixture, test_forward_squeeze(tests_forward_arguments[1]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_2, TSTestFixture, test_forward_squeeze(tests_forward_arguments[2]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_3, TSTestFixture, test_forward_squeeze(tests_forward_arguments[3]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_4, TSTestFixture, test_forward_squeeze(tests_forward_arguments[4]));
auto test_backward_squeeze = [](const SqueezeArguments& test_arguments) {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = test_arguments.inputs_to_main;
// Test model description:
test_case.model.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;
// Reference model description:
auto new_transpose = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
auto order = test_arguments.new_constant;
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
if (out_vec.size() > 1) {
new_out_vec[1] = out_vec[1];
}
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
test_case.model_ref.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
vector<SqueezeArguments> tests_backward_arguments{
{{parameter(f32, {1, 2}), constant<int64_t>(i32, {1}, {0})}, constant<int64_t>(i32, {2}, {0, 1})},
{{parameter(f32, {1, 2, 1}), constant<int64_t>(i32, {2}, {0, 2})}, constant<int64_t>(i32, {3}, {0, 1, 2})},
{{parameter(f32, {1, 1, 2, 1}), constant<int64_t>(i32, {3}, {0, 1, 3})}, constant<int64_t>(i32, {4}, {0, 1, 2, 3})},
{{constant<int64_t>(i32, {1, 2}, {1}), constant<int64_t>(i32, {1}, {0})}, constant<int64_t>(i32, {2}, {0, 1})},
{{parameter(f32, {1, 2})}, constant<int64_t>(i32, {2}, {0, 1})},
};
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_0, TSTestFixture, test_backward_squeeze(tests_backward_arguments[0]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_1, TSTestFixture, test_backward_squeeze(tests_backward_arguments[1]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_2, TSTestFixture, test_backward_squeeze(tests_backward_arguments[2]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_3, TSTestFixture, test_backward_squeeze(tests_backward_arguments[3]));
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_4, TSTestFixture, test_backward_squeeze(tests_backward_arguments[4]));
} // namespace squeeze
} // namespace testing
} // namespace transpose_sinking

View File

@ -0,0 +1,20 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ts_test_case.hpp"
using namespace transpose_sinking::testing;
std::string TSTestFixture::get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
size_t num_main_ops_idx;
size_t main_op_idx;
TestCase test_case;
std::tie(num_main_ops_idx, main_op_idx, test_case) = obj.param;
std::ostringstream test_name;
test_name << "Factory=" << test_case.model.main_op[main_op_idx]->getTypeName() << "/";
test_name << "NumOps=" << test_case.num_main_ops[num_main_ops_idx] << "/";
test_name << "Transformation=" << test_case.transformation->getTypeName() << "/";
return test_name.str();
}

View File

@ -0,0 +1,109 @@
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
namespace transpose_sinking {
namespace testing {
struct Preprocessing {
std::vector<std::function<ov::OutputVector(std::vector<size_t>, ov::OutputVector)>> preprocessing;
std::vector<std::vector<size_t>> indices;
ov::OutputVector apply(const ov::OutputVector& inputs) const {
ov::OutputVector new_inputs = inputs;
for (size_t i = 0; i < preprocessing.size(); ++i) {
new_inputs = preprocessing[i](indices[i], new_inputs);
}
return new_inputs;
}
};
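Editor's sketch of how Preprocessing is typically filled by the tests in this PR (inputs_to_main is an assumed OutputVector; set_transpose_for is declared in ts_test_utils.hpp): each callback in preprocessing is applied to the outputs whose indices are listed in the matching entry of indices.

// Apply set_transpose_for to output 0 only, leaving the other outputs unchanged.
Preprocessing insert_transpose{{set_transpose_for}, {{0}}};
ov::OutputVector preprocessed = insert_transpose.apply(inputs_to_main);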
class IFactory {
public:
explicit IFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IFactory() = default;
virtual std::shared_ptr<ov::Node> create(const ov::OutputVector& parent_nodes) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
using FactoryPtr = std::shared_ptr<IFactory>;
class IPassFactory {
public:
explicit IPassFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
template <typename PassT>
class PassFactory : public IPassFactory {
public:
explicit PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
struct TestCase;
struct ModelDescription;
using TestParams = std::tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
using CreateGraphF = std::function<
std::shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const ov::OutputVector&)>;
// Describes a model to test.
// It is expected to be used in the following scenario:
// the 1st Preprocessing inserts Transpose/Gather nodes on the inputs
// of the main node,
// Factory contains the rules for creating the main testing node,
// the 2nd Preprocessing inserts Transpose/Gather nodes on the outputs
// of the main node,
// and model_template is a function which uses the arguments above.
// Examples of the scenarios:
// ModelDescription model: Param -> (Transpose inserted by 1st Preprocessing) -> Abs (main_node) -> Result
// ModelDescription reference: Param -> Abs (main_node) -> (Transpose inserted by 2nd Preprocessing) -> Result
struct ModelDescription {
Preprocessing preprocess_inputs_to_main;
// @parameterized with multiple values
std::vector<FactoryPtr> main_op;
Preprocessing preprocess_outputs_of_main;
CreateGraphF model_template;
};
struct TestCase {
ov::OutputVector inputs_to_main;
// @parameterized with multiple values
std::vector<size_t> num_main_ops;
ModelDescription model;
ModelDescription model_ref;
PassFactoryPtr transformation;
};
class TSTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
static std::string get_test_name(const ::testing::TestParamInfo<TestParams>& obj);
};
} // namespace testing
} // namespace transpose_sinking
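Editor's note, to show how the pieces above combine in practice; the lines below mirror the forward squeeze test added in ts_squeeze_test.cpp in this PR (parameter, constant, set_transpose_for, CREATE_SQUEEZE_FACTORY, create_model and wrapper all come from the test files in this commit):

TestCase test_case;
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(f32, {1, 2}), constant<int64_t>(i32, {1}, {1})};
// Test model: a Transpose is inserted on input 0 in front of the main Squeeze op.
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
test_case.model.model_template = create_model;
// Reference model: Squeeze with adjusted axes and the Transpose moved to the output;
// model_ref is filled analogously in the actual test, then wrapper(test_case) is returned.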

View File

@ -79,7 +79,7 @@ std::string to_string(const Shape& shape) {
return result.str();
}
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
Output<Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
return std::make_shared<Parameter>(el_type, ps);
}

View File

@ -9,6 +9,7 @@
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_case.hpp"
namespace transpose_sinking {
namespace testing {
@ -16,54 +17,16 @@ namespace utils {
using NodePtr = std::shared_ptr<ov::Node>;
class IFactory {
public:
explicit IFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IFactory() = default;
virtual NodePtr create(const ov::OutputVector& parent_nodes) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
using FactoryPtr = std::shared_ptr<IFactory>;
class IPassFactory {
public:
explicit IPassFactory(const std::string& type_name) : type_name_(type_name) {}
virtual ~IPassFactory() = default;
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
const std::string& getTypeName() const {
return type_name_;
}
private:
const std::string type_name_;
};
template <typename PassT>
class PassFactory : public IPassFactory {
public:
explicit PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
void registerPass(ov::pass::Manager& pass_manager) const override {
pass_manager.register_pass<PassT>();
}
};
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
std::string to_string(const ov::Shape& shape);
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
ov::OutputVector set_transpose_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
ov::OutputVector set_gather_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_t num_ops, const FactoryPtr& creator);
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
ov::Output<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
template <class T>
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
ov::Output<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
return ov::opset10::Constant::create<T>(el_type, shape, value);
}