Delete redandant node copies in TSSqueeze, TSUnsqueeze and TSReduction transformations (#16753)
* Delete redandant node copies in TSSqueeze, TSUnsqueeze and TSReduction transformations, add new tests * codestyle * codestyle
This commit is contained in:
parent
4bb9222c6e
commit
132dceb146
@ -21,7 +21,7 @@ class TRANSFORMATIONS_API TSReductionBackward;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionForward transformation sinks Transpose through Reduce operations
|
||||
* @brief TSReductionForward transformation sinks Transpose through Reduce operations
|
||||
* in the forward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSReductionForward : public ov::pass::MatcherPass {
|
||||
@ -32,7 +32,7 @@ public:
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReductionBackward transformation sinks Transpose through Reduce operations
|
||||
* @brief TSReductionBackward transformation sinks Transpose through Reduce operations
|
||||
* in the backward direction.
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSReductionBackward : public ov::pass::MatcherPass {
|
||||
|
@ -64,7 +64,9 @@ namespace sink_forward {
|
||||
* @brief Inserts reversed transposed on @args main_node inputs. Removes input transpose specified in @arg
|
||||
* transpose_input_info
|
||||
*/
|
||||
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node, const TransposeInputsInfo& transpose_input_info);
|
||||
bool UpdateInputTransposes(const std::shared_ptr<ov::Node>& main_node,
|
||||
const TransposeInputsInfo& transpose_input_info,
|
||||
std::vector<size_t> input_indexes = {});
|
||||
|
||||
/**
|
||||
* @brief Removes @arg input node
|
||||
@ -86,7 +88,7 @@ namespace sink_backward {
|
||||
*/
|
||||
ov::NodeVector InsertTransposeBeforeNode(const std::shared_ptr<ov::Node>& main_node,
|
||||
const std::shared_ptr<ov::opset10::Constant>& transpose_const,
|
||||
std::vector<int> input_indexes = {});
|
||||
std::vector<size_t> input_indexes = {});
|
||||
} // namespace sink_backward
|
||||
|
||||
void UpdateForwardSinkingAbility(const std::shared_ptr<ov::Node>&);
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
@ -24,9 +25,9 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& main_node) {
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(main_node);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(main_node);
|
||||
|
||||
bool keep_dims = false; // squeeze/unsqueeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
@ -47,24 +48,22 @@ TSReductionForward::TSReductionForward() {
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
if (transformation_callback(reduction)) {
|
||||
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
|
||||
auto main_node = pattern_to_output.at(reduce_label);
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
|
||||
auto keep_dims = get_keep_dims(main_node);
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = main_node->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(main_node->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
@ -82,16 +81,21 @@ TSReductionForward::TSReductionForward() {
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_reduction = reduction->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_reduction, new_transpose_order});
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
replace_node(reduction, new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
copy_runtime_info(reduction_axes, new_const);
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
@ -105,28 +109,32 @@ TSReductionBackward::TSReductionBackward() {
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{any_input(), wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({reduce_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto reduction = pattern_to_output.at(reduce_label);
|
||||
if (transformation_callback(reduction)) {
|
||||
auto main_node = pattern_to_output.at(reduce_label);
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto keep_dims = get_keep_dims(reduction);
|
||||
auto keep_dims = get_keep_dims(main_node);
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto rank = main_node->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(main_node->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
if (!keep_dims) {
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
@ -141,16 +149,15 @@ TSReductionBackward::TSReductionBackward() {
|
||||
new_values.push_back(reversed_order_values[axis]);
|
||||
}
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({reduction->input_value(0), new_transpose_order});
|
||||
auto new_reduction = reduction->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
replace_node(transpose, new_reduction);
|
||||
copy_runtime_info({transpose, reduction}, {new_transpose, new_reduction});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_reduction->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(transpose, main_node);
|
||||
copy_runtime_info(reduction_axes, new_const);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
@ -102,38 +103,52 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto squeeze_with_1_input = wrap_type<Squeeze>({transpose_label});
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({transpose_label, wrap_type<Constant>()});
|
||||
auto pattern = std::make_shared<pattern::op::Or>(OutputVector{squeeze_with_1_input, squeeze_label});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
if (transformation_callback(squeeze)) {
|
||||
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
|
||||
std::shared_ptr<Node> main_node;
|
||||
if (pattern_to_output.count(squeeze_label)) {
|
||||
main_node = pattern_to_output.at(squeeze_label);
|
||||
} else {
|
||||
main_node = pattern_to_output.at(squeeze_with_1_input);
|
||||
}
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
|
||||
if (!transpose_order) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
std::shared_ptr<Constant> squeeze_axes;
|
||||
if (main_node->get_input_size() > 1) {
|
||||
squeeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
auto success = shape_to_squeeze_axes(main_node, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = main_node->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(main_node->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
}
|
||||
|
||||
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
|
||||
// if 2nd input to main_node is empty then all '1' dims will be deleted.
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = transpose->output(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
@ -158,7 +173,7 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
@ -167,30 +182,39 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
new_values = to_shape;
|
||||
}
|
||||
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
auto new_squeeze = squeeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_squeeze, new_transpose_order});
|
||||
if (squeeze_axes) {
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
copy_runtime_info(squeeze_axes, new_const);
|
||||
}
|
||||
|
||||
replace_node(squeeze, new_transpose);
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(squeeze_label, matcher_name);
|
||||
auto m = std::make_shared<pattern::Matcher>(pattern, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
TSSqueezeBackward::TSSqueezeBackward() {
|
||||
MATCHER_SCOPE(TSSqueezeBackward);
|
||||
|
||||
auto squeeze_with_1_input = wrap_type<Squeeze>({any_input()}, HasSameOutputTransposeNodes);
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto pattern = std::make_shared<pattern::op::Or>(OutputVector{squeeze_with_1_input, squeeze_label});
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
wrap_type<Transpose>({pattern, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
@ -198,34 +222,47 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto squeeze = pattern_to_output.at(squeeze_label);
|
||||
if (transformation_callback(squeeze)) {
|
||||
std::shared_ptr<Node> main_node;
|
||||
if (pattern_to_output.count(squeeze_label)) {
|
||||
main_node = pattern_to_output.at(squeeze_label);
|
||||
} else {
|
||||
main_node = pattern_to_output.at(squeeze_with_1_input);
|
||||
}
|
||||
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto squeeze_axes = as_type_ptr<Constant>(squeeze->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !squeeze_axes) {
|
||||
|
||||
if (!transpose_order) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto success = shape_to_squeeze_axes(squeeze, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
std::shared_ptr<Constant> squeeze_axes;
|
||||
if (main_node->get_input_size() > 1) {
|
||||
squeeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!squeeze_axes) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
auto success = shape_to_squeeze_axes(main_node, squeeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = main_node->get_input_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(main_node->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
}
|
||||
|
||||
bool squeeze_all_dims = false;
|
||||
if (non_negative_axes.empty()) {
|
||||
auto input_pshape = squeeze->input_value(0).get_partial_shape();
|
||||
auto input_pshape = main_node->input_value(0).get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
@ -249,8 +286,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
auto new_transpose = transpose->clone_with_new_inputs({main_node->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
@ -260,19 +297,19 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> new_squeeze;
|
||||
if (squeeze_all_dims) {
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
|
||||
} else {
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
if (!squeeze_all_dims) {
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
copy_runtime_info(squeeze_axes, new_const);
|
||||
}
|
||||
|
||||
replace_node(transpose, new_squeeze);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
main_node->validate_and_infer_types();
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(transpose, main_node);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -108,29 +108,29 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
if (transformation_callback(unsqueeze)) {
|
||||
auto transpose = as_type_ptr<Transpose>(pattern_to_output.at(transpose_label));
|
||||
auto main_node = pattern_to_output.at(unsqueeze_label);
|
||||
if (!transpose || transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = as_type_ptr<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = as_type_ptr<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
auto rank = main_node->get_output_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
@ -140,25 +140,29 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
|
||||
|
||||
std::shared_ptr<Node> new_unsqueeze;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
std::vector<size_t> new_values;
|
||||
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
copy_runtime_info(unsqueeze_axes, new_const);
|
||||
}
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
|
||||
|
||||
replace_node(unsqueeze, new_transpose);
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
register_new_node(new_transpose);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
TransposeInputsInfo transpose_input_info = {transpose, new_transpose_order, 0};
|
||||
// deletes Transpose from 0 input
|
||||
auto success = sink_forward::UpdateInputTransposes(main_node, transpose_input_info, {0});
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
register_new_node(new_node);
|
||||
UpdateForwardSinkingAbility(new_node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
@ -181,27 +185,27 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label);
|
||||
auto unsqueeze = pattern_to_output.at(unsqueeze_label);
|
||||
if (transformation_callback(unsqueeze)) {
|
||||
auto main_node = pattern_to_output.at(unsqueeze_label);
|
||||
if (transformation_callback(main_node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(unsqueeze->get_input_node_shared_ptr(1));
|
||||
auto unsqueeze_axes = std::dynamic_pointer_cast<Constant>(main_node->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !unsqueeze_axes)
|
||||
return false;
|
||||
|
||||
std::vector<size_t> non_negative_axes;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto success = shape_to_unsqueeze_axes(unsqueeze, unsqueeze_axes, non_negative_axes);
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
auto success = shape_to_unsqueeze_axes(main_node, unsqueeze_axes, non_negative_axes);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
auto rank = main_node->get_output_partial_shape(0).rank();
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(main_node->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
@ -210,9 +214,9 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
std::vector<size_t> new_values;
|
||||
|
||||
if (non_negative_axes.size() == transpose_order_values.size()) {
|
||||
// input is a scalar, we unsqueeze all dims
|
||||
// input is a scalar, we main_node all dims
|
||||
// it's enough to eliminate such Transpose
|
||||
transpose->output(0).replace(unsqueeze);
|
||||
transpose->output(0).replace(main_node);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -228,23 +232,24 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node, new_transpose_order, {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
if (as_type_ptr<Reshape>(main_node)) {
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
auto success = unsqueeze_axes_to_shape(main_node->input_value(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
main_node->input(1).replace_source_output(new_const);
|
||||
|
||||
replace_node(transpose, new_unsqueeze);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
main_node->validate_and_infer_types();
|
||||
RemoveSingleOutputConsumers(main_node);
|
||||
SwapNames(transpose, main_node);
|
||||
copy_runtime_info(unsqueeze_axes, new_const);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -150,7 +150,13 @@ AxisVector AlignTransposeOrder(const Output<Node>& output, const TransposeInputs
|
||||
return new_transpose_order;
|
||||
}
|
||||
|
||||
bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo& transpose_input_info) {
|
||||
bool UpdateInputTransposes(const NodePtr& main_node,
|
||||
const TransposeInputsInfo& transpose_input_info,
|
||||
std::vector<size_t> input_indexes) {
|
||||
if (input_indexes.empty()) {
|
||||
input_indexes.resize(main_node->get_input_size());
|
||||
std::iota(input_indexes.begin(), input_indexes.end(), 0);
|
||||
}
|
||||
if (transpose_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
||||
return false;
|
||||
|
||||
@ -161,7 +167,7 @@ bool UpdateInputTransposes(const NodePtr& main_node, const TransposeInputsInfo&
|
||||
const size_t transpose_input_index = transpose_input_info.input_idx;
|
||||
const auto transpose_element_type = transpose_input_info.transpose_const->get_element_type();
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
for (const auto& i : input_indexes) {
|
||||
auto input_node = main_node->input_value(i);
|
||||
if (i == transpose_input_index) {
|
||||
auto transpose_parent = input_node.get_node()->input_value(0);
|
||||
@ -230,7 +236,7 @@ namespace sink_backward {
|
||||
|
||||
NodeVector InsertTransposeBeforeNode(const NodePtr& main_node,
|
||||
const std::shared_ptr<Constant>& transpose_const,
|
||||
std::vector<int> input_indexes) {
|
||||
std::vector<size_t> input_indexes) {
|
||||
if (input_indexes.empty()) {
|
||||
input_indexes.resize(main_node->get_input_size());
|
||||
std::iota(input_indexes.begin(), input_indexes.end(), 0);
|
||||
|
@ -11,11 +11,13 @@
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "ts_test_case.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace {
|
||||
|
@ -16,12 +16,14 @@
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "ts_test_case.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace transpose_sinking {
|
||||
@ -214,6 +216,7 @@ public:
|
||||
FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
return std::make_shared<ReshapeFactory>(type_name);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -251,72 +254,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) {
|
||||
|
||||
#undef CREATE_RESHAPE_FACTORY
|
||||
#define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
struct Preprocessing {
|
||||
vector<function<OutputVector(vector<size_t>, OutputVector)>> preprocessing;
|
||||
vector<vector<size_t>> indices;
|
||||
|
||||
OutputVector apply(const OutputVector& inputs) const {
|
||||
OutputVector new_inputs = inputs;
|
||||
for (size_t i = 0; i < preprocessing.size(); ++i) {
|
||||
new_inputs = preprocessing[i](indices[i], new_inputs);
|
||||
}
|
||||
return new_inputs;
|
||||
}
|
||||
};
|
||||
|
||||
struct TestCase;
|
||||
struct ModelDescription;
|
||||
using TestParams = tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
|
||||
using CreateGraphF =
|
||||
function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
|
||||
|
||||
// Describes a model to test.
|
||||
// Expects to be used in such a scenario:
|
||||
// 1st Preprocessing inserts Transpose/Gather to the inputs
|
||||
// of the main node.
|
||||
// Factory contains the rules how to create the main testing node.
|
||||
// 2nd Preprocessing inserts Transpose/Gather to the outputs
|
||||
// of the main node.
|
||||
// model_template is a function which uses the arguments above.
|
||||
// Examples of the scenarios:
|
||||
// ModelDescription model: Param -> (Transpose inserted by 1st Preprocessing) -> Abs (main_node) -> Result
|
||||
// ModelDescription reference: Param -> Abs (main_node) -> (Transpose inserted by 2nd Preprocessing) -> Result
|
||||
struct ModelDescription {
|
||||
Preprocessing preprocess_inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
vector<FactoryPtr> main_op;
|
||||
Preprocessing preprocess_outputs_of_main;
|
||||
CreateGraphF model_template;
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
OutputVector inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
vector<size_t> num_main_ops;
|
||||
|
||||
ModelDescription model;
|
||||
ModelDescription model_ref;
|
||||
PassFactoryPtr transformation;
|
||||
};
|
||||
|
||||
class TransposeSinkingTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
|
||||
public:
|
||||
static string get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
|
||||
size_t num_main_ops_idx;
|
||||
size_t main_op_idx;
|
||||
TestCase test_case;
|
||||
tie(num_main_ops_idx, main_op_idx, test_case) = obj.param;
|
||||
|
||||
ostringstream test_name;
|
||||
test_name << "Factory=" << test_case.model.main_op[main_op_idx]->getTypeName() << "/";
|
||||
test_name << "NumOps=" << test_case.num_main_ops[num_main_ops_idx] << "/";
|
||||
test_name << "Transformation=" << test_case.transformation->getTypeName() << "/";
|
||||
return test_name.str();
|
||||
}
|
||||
};
|
||||
|
||||
vector<FactoryPtr> unary_factories = {
|
||||
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
|
||||
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
|
||||
@ -355,7 +295,16 @@ std::vector<FactoryPtr> reduction_factories = {
|
||||
CREATE_REDUCTION_FACTORY(ReduceL2),
|
||||
};
|
||||
|
||||
TEST_P(TransposeSinkingTestFixture, CompareFunctions) {
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
|
||||
"The number of main op (testing op) creator have to be the same for the testing model and for"
|
||||
"the reference model.");
|
||||
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
|
||||
::testing::Range<size_t>(0, test_case.model.main_op.size()),
|
||||
::testing::Values(test_case));
|
||||
};
|
||||
|
||||
TEST_P(TSTestFixture, CompareFunctions) {
|
||||
size_t num_main_ops_idx;
|
||||
size_t main_op_idx;
|
||||
TestCase test_case;
|
||||
@ -387,15 +336,6 @@ shared_ptr<ov::Model> create_model(size_t main_node_idx,
|
||||
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
|
||||
}
|
||||
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
|
||||
"The number of main op (testing op) creator have to be the same for the testing model and for"
|
||||
"the reference model.");
|
||||
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
|
||||
::testing::Range<size_t>(0, test_case.model.main_op.size()),
|
||||
::testing::Values(test_case));
|
||||
};
|
||||
|
||||
auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<size_t>& num_main_ops) {
|
||||
TestCase test_case;
|
||||
|
||||
@ -420,10 +360,10 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
TSTestFixture,
|
||||
test_forward_unary(unary_factories, {1, 10}));
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
TSTestFixture,
|
||||
test_forward_unary(logical_unary_factories, {1}));
|
||||
|
||||
auto test_forward_binary = []() {
|
||||
@ -451,7 +391,7 @@ auto test_forward_binary = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TransposeSinkingTestFixture, test_forward_binary());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary());
|
||||
|
||||
auto test_forward_concat = []() {
|
||||
TestCase test_case;
|
||||
@ -479,7 +419,7 @@ auto test_forward_concat = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TransposeSinkingTestFixture, test_forward_concat());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TSTestFixture, test_forward_concat());
|
||||
|
||||
auto test_forward_split = []() {
|
||||
TestCase test_case;
|
||||
@ -513,7 +453,7 @@ auto test_forward_split = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TransposeSinkingTestFixture, test_forward_split());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TSTestFixture, test_forward_split());
|
||||
|
||||
auto test_forward_pad = []() {
|
||||
TestCase test_case;
|
||||
@ -541,7 +481,7 @@ auto test_forward_pad = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TransposeSinkingTestFixture, test_forward_pad());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad());
|
||||
|
||||
auto test_forward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
@ -570,9 +510,7 @@ auto test_forward_batch_to_space = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_batch_to_space());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceForward, TSTestFixture, test_forward_batch_to_space());
|
||||
|
||||
auto test_forward_space_to_batch = []() {
|
||||
TestCase test_case;
|
||||
@ -601,9 +539,7 @@ auto test_forward_space_to_batch = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_space_to_batch());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchForward, TSTestFixture, test_forward_space_to_batch());
|
||||
|
||||
auto test_forward_reduction = []() {
|
||||
TestCase test_case;
|
||||
@ -637,7 +573,7 @@ auto test_forward_reduction = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TransposeSinkingTestFixture, test_forward_reduction());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TSTestFixture, test_forward_reduction());
|
||||
|
||||
auto test_forward_interpolate = []() {
|
||||
TestCase test_case;
|
||||
@ -680,9 +616,7 @@ auto test_forward_interpolate = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_interpolate());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TSTestFixture, test_forward_interpolate());
|
||||
|
||||
auto test_forward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
@ -716,7 +650,7 @@ auto test_forward_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeForward, TransposeSinkingTestFixture, test_forward_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeForward, TSTestFixture, test_forward_squeeze());
|
||||
|
||||
auto test_forward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -757,7 +691,7 @@ auto test_forward_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TransposeSinkingTestFixture, test_forward_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeForward, TSTestFixture, test_forward_unsqueeze());
|
||||
|
||||
auto test_forward_slice = []() {
|
||||
TestCase test_case;
|
||||
@ -801,7 +735,7 @@ auto test_forward_slice = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TransposeSinkingTestFixture, test_forward_slice());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceForward, TSTestFixture, test_forward_slice());
|
||||
|
||||
auto test_forward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
@ -835,9 +769,7 @@ auto test_forward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TSTestFixture, test_forward_reshape_squeeze());
|
||||
|
||||
auto test_forward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -879,8 +811,9 @@ auto test_forward_reshape_unsqueeze = []() {
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
TSTestFixture,
|
||||
test_forward_reshape_unsqueeze());
|
||||
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -906,7 +839,7 @@ auto test_backward_unary = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryBackward, TransposeSinkingTestFixture, test_backward_unary());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryBackward, TSTestFixture, test_backward_unary());
|
||||
|
||||
auto test_backward_binary = []() {
|
||||
TestCase test_case;
|
||||
@ -932,7 +865,7 @@ auto test_backward_binary = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TransposeSinkingTestFixture, test_backward_binary());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary());
|
||||
|
||||
auto test_backward_concat = []() {
|
||||
TestCase test_case;
|
||||
@ -959,7 +892,7 @@ auto test_backward_concat = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatBackward, TransposeSinkingTestFixture, test_backward_concat());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatBackward, TSTestFixture, test_backward_concat());
|
||||
|
||||
auto test_backward_split = []() {
|
||||
TestCase test_case;
|
||||
@ -991,7 +924,7 @@ auto test_backward_split = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitBackward, TransposeSinkingTestFixture, test_backward_split());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitBackward, TSTestFixture, test_backward_split());
|
||||
|
||||
auto test_backward_pad = []() {
|
||||
TestCase test_case;
|
||||
@ -1018,7 +951,7 @@ auto test_backward_pad = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadBackward, TransposeSinkingTestFixture, test_backward_pad());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadBackward, TSTestFixture, test_backward_pad());
|
||||
|
||||
auto test_backward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
@ -1046,9 +979,7 @@ auto test_backward_batch_to_space = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_batch_to_space());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceBackward, TSTestFixture, test_backward_batch_to_space());
|
||||
|
||||
auto test_backward_space_to_batch = []() {
|
||||
TestCase test_case;
|
||||
@ -1075,9 +1006,7 @@ auto test_backward_space_to_batch = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_space_to_batch());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchBackward, TSTestFixture, test_backward_space_to_batch());
|
||||
|
||||
auto test_backward_reduction = []() {
|
||||
TestCase test_case;
|
||||
@ -1110,9 +1039,7 @@ auto test_backward_reduction = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reduction());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward, TSTestFixture, test_backward_reduction());
|
||||
|
||||
auto test_backward_interpolate = []() {
|
||||
TestCase test_case;
|
||||
@ -1154,42 +1081,7 @@ auto test_backward_interpolate = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_interpolate());
|
||||
|
||||
auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{4}, std::vector<int64_t>{2, 1, 0, 3});
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
new_out_vec[1] = out_vec[1];
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Squeeze)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSqueezeBackward, TransposeSinkingTestFixture, test_backward_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TSTestFixture, test_backward_interpolate());
|
||||
|
||||
auto test_backward_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -1222,9 +1114,7 @@ auto test_backward_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TSTestFixture, test_backward_unsqueeze());
|
||||
|
||||
auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
@ -1266,7 +1156,7 @@ auto test_backward_slice = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TransposeSinkingTestFixture, test_backward_slice());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSliceBackward, TSTestFixture, test_backward_slice());
|
||||
|
||||
auto test_backward_reshape_squeeze = []() {
|
||||
TestCase test_case;
|
||||
@ -1306,9 +1196,7 @@ auto test_backward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TSTestFixture, test_backward_reshape_squeeze());
|
||||
|
||||
auto test_backward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -1343,8 +1231,9 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
TSTestFixture,
|
||||
test_backward_reshape_unsqueeze());
|
||||
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
|
@ -13,11 +13,13 @@
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "ts_test_case.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace {
|
||||
|
@ -0,0 +1,160 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "ts_test_case.hpp"
|
||||
#include "ts_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::element;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace transpose_sinking::testing::utils;
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
namespace squeeze {
|
||||
|
||||
class SqueezeFactory : public IFactory {
|
||||
public:
|
||||
explicit SqueezeFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
if (parent_nodes.size() == 2) {
|
||||
return std::make_shared<Squeeze>(parent_nodes[0], parent_nodes[1]);
|
||||
} else if (parent_nodes.size() == 1) {
|
||||
return std::make_shared<Squeeze>(parent_nodes[0]);
|
||||
}
|
||||
OPENVINO_ASSERT(false, "Unexpected number of inputs to Squeeze operation.");
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateSqueezeFactory(const std::string& type_name) {
|
||||
return std::make_shared<SqueezeFactory>(type_name);
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_SQUEEZE_FACTORY
|
||||
#define CREATE_SQUEEZE_FACTORY(type_name) CreateSqueezeFactory(#type_name)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
shared_ptr<ov::Model> create_model(size_t main_node_idx,
|
||||
const ModelDescription& model_desc,
|
||||
size_t num_ops,
|
||||
const OutputVector& inputs_to_main) {
|
||||
auto new_inputs = model_desc.preprocess_inputs_to_main.apply(inputs_to_main);
|
||||
auto main_node = create_main_node(new_inputs, num_ops, model_desc.main_op[main_node_idx]);
|
||||
auto outputs = model_desc.preprocess_outputs_of_main.apply(main_node->outputs());
|
||||
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
|
||||
}
|
||||
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
|
||||
"The number of main op (testing op) creator have to be the same for the testing model and for"
|
||||
"the reference model.");
|
||||
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
|
||||
::testing::Range<size_t>(0, test_case.model.main_op.size()),
|
||||
::testing::Values(test_case));
|
||||
};
|
||||
|
||||
struct SqueezeArguments {
|
||||
OutputVector inputs_to_main;
|
||||
Output<Node> new_constant;
|
||||
};
|
||||
|
||||
auto test_forward_squeeze = [](const SqueezeArguments& test_arguments) {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = test_arguments.inputs_to_main;
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
if (out_vec.size() > 1) {
|
||||
new_out_vec[1] = test_arguments.new_constant;
|
||||
}
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
vector<SqueezeArguments> tests_forward_arguments{
|
||||
{{parameter(f32, {1, 2}), constant<int64_t>(i32, {1}, {1})}, constant<int64_t>(i32, {1}, {0})},
|
||||
{{parameter(f32, {1, 2, 1}), constant<int64_t>(i32, {2}, {0, 2})}, constant<int64_t>(i32, {2}, {2, 0})},
|
||||
{{parameter(f32, {1, 1, 2, 1}), constant<int64_t>(i32, {3}, {0, 2, 3})}, constant<int64_t>(i32, {3}, {3, 1, 0})},
|
||||
{{constant<int64_t>(i32, {1, 2}, {1}), constant<int64_t>(i32, {1}, {1})}, constant<int64_t>(i32, {1}, {0})},
|
||||
{{parameter(f32, {1, 2})}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_0, TSTestFixture, test_forward_squeeze(tests_forward_arguments[0]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_1, TSTestFixture, test_forward_squeeze(tests_forward_arguments[1]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_2, TSTestFixture, test_forward_squeeze(tests_forward_arguments[2]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_3, TSTestFixture, test_forward_squeeze(tests_forward_arguments[3]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeForward_4, TSTestFixture, test_forward_squeeze(tests_forward_arguments[4]));
|
||||
|
||||
auto test_backward_squeeze = [](const SqueezeArguments& test_arguments) {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = test_arguments.inputs_to_main;
|
||||
|
||||
// Test model description:
|
||||
test_case.model.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
|
||||
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_transpose = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
|
||||
auto order = test_arguments.new_constant;
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
if (out_vec.size() > 1) {
|
||||
new_out_vec[1] = out_vec[1];
|
||||
}
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
|
||||
test_case.model_ref.main_op = {CREATE_SQUEEZE_FACTORY(Squeeze)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
vector<SqueezeArguments> tests_backward_arguments{
|
||||
{{parameter(f32, {1, 2}), constant<int64_t>(i32, {1}, {0})}, constant<int64_t>(i32, {2}, {0, 1})},
|
||||
{{parameter(f32, {1, 2, 1}), constant<int64_t>(i32, {2}, {0, 2})}, constant<int64_t>(i32, {3}, {0, 1, 2})},
|
||||
{{parameter(f32, {1, 1, 2, 1}), constant<int64_t>(i32, {3}, {0, 1, 3})}, constant<int64_t>(i32, {4}, {0, 1, 2, 3})},
|
||||
{{constant<int64_t>(i32, {1, 2}, {1}), constant<int64_t>(i32, {1}, {0})}, constant<int64_t>(i32, {2}, {0, 1})},
|
||||
{{parameter(f32, {1, 2})}, constant<int64_t>(i32, {2}, {0, 1})},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_0, TSTestFixture, test_backward_squeeze(tests_backward_arguments[0]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_1, TSTestFixture, test_backward_squeeze(tests_backward_arguments[1]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_2, TSTestFixture, test_backward_squeeze(tests_backward_arguments[2]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_3, TSTestFixture, test_backward_squeeze(tests_backward_arguments[3]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonSqueezeBackward_4, TSTestFixture, test_backward_squeeze(tests_backward_arguments[4]));
|
||||
} // namespace squeeze
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ts_test_case.hpp"
|
||||
|
||||
using namespace transpose_sinking::testing;
|
||||
|
||||
std::string TSTestFixture::get_test_name(const ::testing::TestParamInfo<TestParams>& obj) {
|
||||
size_t num_main_ops_idx;
|
||||
size_t main_op_idx;
|
||||
TestCase test_case;
|
||||
std::tie(num_main_ops_idx, main_op_idx, test_case) = obj.param;
|
||||
|
||||
std::ostringstream test_name;
|
||||
test_name << "Factory=" << test_case.model.main_op[main_op_idx]->getTypeName() << "/";
|
||||
test_name << "NumOps=" << test_case.num_main_ops[num_main_ops_idx] << "/";
|
||||
test_name << "Transformation=" << test_case.transformation->getTypeName() << "/";
|
||||
return test_name.str();
|
||||
}
|
@ -0,0 +1,109 @@
|
||||
// Copyright (C) 2022-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
|
||||
struct Preprocessing {
|
||||
std::vector<std::function<ov::OutputVector(std::vector<size_t>, ov::OutputVector)>> preprocessing;
|
||||
std::vector<std::vector<size_t>> indices;
|
||||
|
||||
ov::OutputVector apply(const ov::OutputVector& inputs) const {
|
||||
ov::OutputVector new_inputs = inputs;
|
||||
for (size_t i = 0; i < preprocessing.size(); ++i) {
|
||||
new_inputs = preprocessing[i](indices[i], new_inputs);
|
||||
}
|
||||
return new_inputs;
|
||||
}
|
||||
};
|
||||
|
||||
class IFactory {
|
||||
public:
|
||||
explicit IFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||
virtual ~IFactory() = default;
|
||||
virtual std::shared_ptr<ov::Node> create(const ov::OutputVector& parent_nodes) const = 0;
|
||||
|
||||
const std::string& getTypeName() const {
|
||||
return type_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_name_;
|
||||
};
|
||||
using FactoryPtr = std::shared_ptr<IFactory>;
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
explicit IPassFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||
virtual ~IPassFactory() = default;
|
||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||
const std::string& getTypeName() const {
|
||||
return type_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_name_;
|
||||
};
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
explicit PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
|
||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
}
|
||||
};
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
|
||||
|
||||
struct TestCase;
|
||||
struct ModelDescription;
|
||||
using TestParams = std::tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
|
||||
using CreateGraphF = std::function<
|
||||
std::shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const ov::OutputVector&)>;
|
||||
|
||||
// Describes a model to test.
|
||||
// Expects to be used in such a scenario:
|
||||
// 1st Preprocessing inserts Transpose/Gather to the inputs
|
||||
// of the main node.
|
||||
// Factory contains the rules how to create the main testing node.
|
||||
// 2nd Preprocessing inserts Transpose/Gather to the outputs
|
||||
// of the main node.
|
||||
// model_template is a function which uses the arguments above.
|
||||
// Examples of the scenarios:
|
||||
// ModelDescription model: Param -> (Transpose inserted by 1st Preprocessing) -> Abs (main_node) -> Result
|
||||
// ModelDescription reference: Param -> Abs (main_node) -> (Transpose inserted by 2nd Preprocessing) -> Result
|
||||
struct ModelDescription {
|
||||
Preprocessing preprocess_inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
std::vector<FactoryPtr> main_op;
|
||||
Preprocessing preprocess_outputs_of_main;
|
||||
CreateGraphF model_template;
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
ov::OutputVector inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
std::vector<size_t> num_main_ops;
|
||||
|
||||
ModelDescription model;
|
||||
ModelDescription model_ref;
|
||||
PassFactoryPtr transformation;
|
||||
};
|
||||
|
||||
class TSTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_name(const ::testing::TestParamInfo<TestParams>& obj);
|
||||
};
|
||||
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
@ -79,7 +79,7 @@ std::string to_string(const Shape& shape) {
|
||||
return result.str();
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
|
||||
Output<Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
|
||||
return std::make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "openvino/frontend/manager.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "ts_test_case.hpp"
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace testing {
|
||||
@ -16,54 +17,16 @@ namespace utils {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
class IFactory {
|
||||
public:
|
||||
explicit IFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||
virtual ~IFactory() = default;
|
||||
virtual NodePtr create(const ov::OutputVector& parent_nodes) const = 0;
|
||||
|
||||
const std::string& getTypeName() const {
|
||||
return type_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_name_;
|
||||
};
|
||||
using FactoryPtr = std::shared_ptr<IFactory>;
|
||||
|
||||
class IPassFactory {
|
||||
public:
|
||||
explicit IPassFactory(const std::string& type_name) : type_name_(type_name) {}
|
||||
virtual ~IPassFactory() = default;
|
||||
virtual void registerPass(ov::pass::Manager& pass_manager) const = 0;
|
||||
const std::string& getTypeName() const {
|
||||
return type_name_;
|
||||
}
|
||||
|
||||
private:
|
||||
const std::string type_name_;
|
||||
};
|
||||
|
||||
template <typename PassT>
|
||||
class PassFactory : public IPassFactory {
|
||||
public:
|
||||
explicit PassFactory(const std::string& type_name) : IPassFactory(type_name) {}
|
||||
void registerPass(ov::pass::Manager& pass_manager) const override {
|
||||
pass_manager.register_pass<PassT>();
|
||||
}
|
||||
};
|
||||
using PassFactoryPtr = std::shared_ptr<IPassFactory>;
|
||||
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::transpose_sinking::pass_name>>(#pass_name)
|
||||
|
||||
std::string to_string(const ov::Shape& shape);
|
||||
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
|
||||
|
||||
ov::OutputVector set_transpose_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
|
||||
ov::OutputVector set_gather_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
|
||||
std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_t num_ops, const FactoryPtr& creator);
|
||||
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
|
||||
ov::Output<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
|
||||
template <class T>
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
|
||||
ov::Output<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
|
||||
return ov::opset10::Constant::create<T>(el_type, shape, value);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user