Support new operations in TS: Selu, Swish, HSwish, Tile, CumSum, HardSigmoid (#19990)

* add new operations as unary

* get unary as input(0) instead of iterating pattern map

* add CumSum + unit tests

* add Tile + unit tests

* add tile

* fix ts_tile

* code review fix: use ADD_MATCHER

* fix bug CI tests
This commit is contained in:
Evgeny Kotov 2023-10-10 13:06:10 +02:00 committed by GitHub
parent a5b6606132
commit 4d9f2f3cd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 605 additions and 49 deletions

View File

@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSCumSumForward;
class TRANSFORMATIONS_API TSCumSumBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSCumSumForward transformation sinks Transpose through CumSum in the forward direction.
*/
class ov::pass::transpose_sinking::TSCumSumForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
TSCumSumForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSCumSumBackward transformation sinks Transpose through CumSum in the backward direction.
*/
class ov::pass::transpose_sinking::TSCumSumBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
TSCumSumBackward();
};

View File

@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSTileForward;
class TRANSFORMATIONS_API TSTileBackward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSTileForward transformation sinks Transpose through Tile in the forward direction.
*/
class ov::pass::transpose_sinking::TSTileForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
OPENVINO_RTTI("ov::pass::TSBinaryForward", "0");
TSTileForward();
};
/**
* @ingroup ie_transformation_common_api
* @brief TSTileBackward transformation sinks Transpose through Tile in the backward direction.
*/
class ov::pass::transpose_sinking::TSTileBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::pass::TSBinaryBackward", "0");
TSTileBackward();
};

View File

@ -0,0 +1,92 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cum_sum.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
#undef CUMSUM_AXIS_INPUT_IDX
#define CUMSUM_AXIS_INPUT_IDX 1
TSCumSumForward::TSCumSumForward() {
MATCHER_SCOPE(TSCumSumForward);
create_pattern<ov::op::v0::CumSum>(true, {0});
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
if (transformation_callback(main_node)) {
return false;
}
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
if (!res)
return res;
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
const auto& new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), transpose_axis_order, axis);
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);
default_outputs_update(main_node, transpose_info);
return true;
};
transpose_sinking(matcher_name, sinking_transformation);
}
TSCumSumBackward::TSCumSumBackward() {
MATCHER_SCOPE(TSCumSumBackward);
auto main_node_label = wrap_type<ov::op::v0::CumSum>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});
auto transpose_const_label = wrap_type<ov::op::v0::Constant>();
auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
[](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const =
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
RemoveTransposeConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), reversed_transpose_order, axis);
main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);
main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -13,6 +13,7 @@
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_gather.hpp"
@ -23,6 +24,7 @@
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
@ -31,35 +33,40 @@ using namespace ov::pass::transpose_sinking;
TSGeneralForward::TSGeneralForward() {
MATCHER_SCOPE(TSGeneralForward);
add_matcher<TSUnaryForward>();
add_matcher<TSBinaryForward>();
add_matcher<TSConcatForward>();
add_matcher<TSSplitForward>();
add_matcher<TSDataMovementForward>();
add_matcher<TSReductionForward>();
add_matcher<TSSqueezeForward>();
add_matcher<TSUnsqueezeForward>();
add_matcher<TSInterpolateForward>();
add_matcher<TSSliceForward>();
add_matcher<TSGatherForward>();
add_matcher<TSShapeOfForward>();
add_matcher<TSFuse>();
ADD_MATCHER(this, TSUnaryForward);
ADD_MATCHER(this, TSBinaryForward);
ADD_MATCHER(this, TSConcatForward);
ADD_MATCHER(this, TSSplitForward);
ADD_MATCHER(this, TSDataMovementForward);
ADD_MATCHER(this, TSReductionForward);
ADD_MATCHER(this, TSSqueezeForward);
ADD_MATCHER(this, TSUnsqueezeForward);
ADD_MATCHER(this, TSInterpolateForward);
ADD_MATCHER(this, TSSliceForward);
ADD_MATCHER(this, TSGatherForward);
ADD_MATCHER(this, TSShapeOfForward);
ADD_MATCHER(this, TSCumSumForward);
ADD_MATCHER(this, TSTileForward);
ADD_MATCHER(this, TSFuse);
}
TSGeneralBackward::TSGeneralBackward() {
MATCHER_SCOPE(TSGeneralBackward);
add_matcher<TSUnaryBackward>();
add_matcher<TSBinaryBackward>();
add_matcher<TSConcatBackward>();
add_matcher<TSSplitBackward>();
add_matcher<TSDataMovementBackward>();
add_matcher<TSReductionBackward>();
add_matcher<TSSqueezeBackward>();
add_matcher<TSUnsqueezeBackward>();
add_matcher<TSInterpolateBackward>();
add_matcher<TSSliceBackward>();
add_matcher<TSGatherBackward>();
add_matcher<TSFuse>();
ADD_MATCHER(this, TSUnaryBackward);
ADD_MATCHER(this, TSUnaryBackward);
ADD_MATCHER(this, TSBinaryBackward);
ADD_MATCHER(this, TSConcatBackward);
ADD_MATCHER(this, TSSplitBackward);
ADD_MATCHER(this, TSDataMovementBackward);
ADD_MATCHER(this, TSReductionBackward);
ADD_MATCHER(this, TSSqueezeBackward);
ADD_MATCHER(this, TSUnsqueezeBackward);
ADD_MATCHER(this, TSInterpolateBackward);
ADD_MATCHER(this, TSSliceBackward);
ADD_MATCHER(this, TSGatherBackward);
ADD_MATCHER(this, TSCumSumBackward);
ADD_MATCHER(this, TSTileBackward);
ADD_MATCHER(this, TSFuse);
}
bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {

View File

@ -0,0 +1,93 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
#undef TILE_REPEATS_INPUT_IDX
#define TILE_REPEATS_INPUT_IDX 1
TSTileForward::TSTileForward() {
MATCHER_SCOPE(TSTileForward);
create_pattern<ov::op::v0::Tile>(true, {0});
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const TransposeInputsInfo& transpose_info) -> bool {
if (transformation_callback(main_node)) {
return false;
}
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
if (!res)
return res;
const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
const auto& new_repeats =
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);
default_outputs_update(main_node, transpose_info);
return true;
};
transpose_sinking(matcher_name, sinking_transformation);
}
TSTileBackward::TSTileBackward() {
MATCHER_SCOPE(TSTileBackward);
auto main_node_label = wrap_type<ov::op::v0::Tile>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});
auto transpose_const_label = wrap_type<ov::op::v0::Constant>();
auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
[](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const =
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (transformation_callback(main_node)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
RemoveTransposeConsumers(main_node);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
auto new_repeats =
ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);
main_node->validate_and_infer_types();
return true;
};
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -16,6 +16,7 @@
#include "openvino/op/logical_not.hpp"
#include "openvino/op/softplus.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
@ -45,8 +46,21 @@ TSUnaryForward::TSUnaryForward() {
ov::op::v0::Convert,
ov::op::v10::IsInf,
ov::op::v10::IsNaN,
ov::op::v10::IsFinite>(true);
transpose_sinking(matcher_name);
ov::op::v10::IsFinite,
ov::op::v0::Selu,
ov::op::v4::Swish,
ov::op::v0::HardSigmoid,
ov::op::v5::LogSoftmax,
ov::op::v1::ConvertLike>(true);
auto ts_unary_sinking_function = [this](const std::shared_ptr<Node>& main_node,
const utils::TransposeInputsInfo& transpose_info) -> bool {
bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, {0});
if (!res)
return res;
default_outputs_update(main_node, transpose_info);
return true;
};
transpose_sinking(matcher_name, ts_unary_sinking_function);
}
TSUnaryBackward::TSUnaryBackward() {
@ -56,15 +70,25 @@ TSUnaryBackward::TSUnaryBackward() {
return CheckTransposeConsumers(output);
};
auto unary_label = wrap_type<UnaryElementwiseArithmetic,
ov::op::v0::Clamp,
ov::op::v0::Elu,
ov::op::v4::SoftPlus,
ov::op::v1::LogicalNot,
ov::op::v0::Convert,
ov::op::v10::IsInf,
ov::op::v10::IsNaN,
ov::op::v10::IsFinite>({any_input()}, unary_restrictions);
auto unary_with_1_input_label = wrap_type<UnaryElementwiseArithmetic,
ov::op::v0::Clamp,
ov::op::v0::Elu,
ov::op::v4::SoftPlus,
ov::op::v1::LogicalNot,
ov::op::v0::Convert,
ov::op::v10::IsInf,
ov::op::v10::IsNaN,
ov::op::v10::IsFinite,
ov::op::v5::LogSoftmax>({any_input()}, unary_restrictions);
auto unary_with_2_inputs_label =
wrap_type<ov::op::v4::Swish, ov::op::v1::ConvertLike>({any_input(), any_input()}, unary_restrictions);
auto unary_with_3_inputs_label =
wrap_type<ov::op::v0::Selu, ov::op::v0::HardSigmoid>({any_input(), any_input(), any_input()},
unary_restrictions);
auto unary_label = std::make_shared<pattern::op::Or>(
ov::OutputVector{unary_with_1_input_label, unary_with_2_inputs_label, unary_with_3_inputs_label});
auto transpose_const_label = wrap_type<ov::op::v0::Constant>();
@ -75,12 +99,12 @@ TSUnaryBackward::TSUnaryBackward() {
auto transpose_const =
as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
auto unary = transpose->get_input_node_shared_ptr(0);
if (transformation_callback(unary)) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const, {0})) {
register_new_node(new_node);
}
unary->validate_and_infer_types();

View File

@ -59,6 +59,7 @@ Output<Node> ChangeAxes(const Output<Node>& indices,
copy_runtime_info(indices.get_node_shared_ptr(), gather);
return gather;
}
Output<Node> ChangeAxes(const Output<Node>& indices,
const AxisVector& transpose_axis_order,
const std::shared_ptr<ov::op::v0::Constant>& axis) {

View File

@ -8,12 +8,14 @@
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "ts_test_case.hpp"
@ -206,6 +208,30 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
return std::make_shared<InterpolateFactory>(type_name, is_reference);
}
class CumSumFactory : public IFactory {
public:
explicit CumSumFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<CumSum>(parent_nodes[0], parent_nodes[1]);
}
};
FactoryPtr CreateCumSumFactory(const std::string& type_name) {
return std::make_shared<CumSumFactory>(type_name);
}
class TileFactory : public IFactory {
public:
explicit TileFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Tile>(parent_nodes[0], parent_nodes[1]);
}
};
FactoryPtr CreateTileFactory(const std::string& type_name) {
return std::make_shared<TileFactory>(type_name);
}
class SliceFactory : public IFactory {
public:
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
@ -285,6 +311,12 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
#undef CREATE_INTERPOLATE_FACTORY
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)
#undef CREATE_CUMSUM_FACTORY
#define CREATE_CUMSUM_FACTORY(type_name) CreateCumSumFactory(#type_name)
#undef CREATE_TILE_FACTORY
#define CREATE_TILE_FACTORY(type_name) CreateTileFactory(#type_name)
#undef CREATE_SLICE_FACTORY
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)
@ -761,6 +793,84 @@ auto test_forward_interpolate = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TSTestFixture, test_forward_interpolate());
auto test_forward_cumsum = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSCumSumForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i64, {}, std::vector<int64_t>{0})};
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(data, out, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}};
test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumForward, TSTestFixture, test_forward_cumsum());
auto test_forward_tile = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSTileForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 2, 3, 4})};
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(out, data, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}};
test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileForward, TSTestFixture, test_forward_tile());
auto test_forward_squeeze = []() {
TestCase test_case;
@ -1262,6 +1372,120 @@ auto test_backward_interpolate = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TSTestFixture, test_backward_interpolate());
auto test_backward_cumsum = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSCumSumBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i64, {}, std::vector<int64_t>{0})};
// Test model description:
test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(data, out, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumBackward, TSTestFixture, test_backward_cumsum());
auto test_backward_tile = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 2, 3, 4})};
// Test model description:
test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(out, data, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackward, TSTestFixture, test_backward_tile());
auto test_backward_tile_tf_case = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {parameter(element::f32, {2, 1, 1, 128}),
constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 1, 88, 1})};
// Test model description:
test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model.model_template = create_model;
// Reference model description:
auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto& out = out_vec[idx];
vector<int64_t> transpose_order(out_vec[0].get_shape().size());
iota(transpose_order.begin(), transpose_order.end(), 0);
reverse(transpose_order.begin(), transpose_order.end());
auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
auto transpose = make_shared<Gather>(out, data, axis);
result[idx] = transpose;
}
return result;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackwardTfCase, TSTestFixture, test_backward_tile_tf_case());
auto test_backward_unsqueeze = []() {
TestCase test_case;

View File

@ -7,7 +7,7 @@
#include "common_test_utils/ov_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_utils.hpp"
@ -85,6 +85,37 @@ NodePtr UnaryFactory<Convert>::create(const OutputVector& inputs) const {
return std::make_shared<Convert>(inputs[0], element::f64);
}
template <>
NodePtr UnaryFactory<Selu>::create(const OutputVector& inputs) const {
auto alpha = std::make_shared<Constant>(element::f32, Shape{}, 2.0);
auto lambda = std::make_shared<Constant>(element::f32, Shape{}, 3.0);
return std::make_shared<Selu>(inputs[0], alpha, lambda);
}
template <>
NodePtr UnaryFactory<Swish>::create(const OutputVector& inputs) const {
auto beta = std::make_shared<Constant>(element::f32, Shape{}, 0.9);
return std::make_shared<Swish>(inputs[0], beta);
}
template <>
NodePtr UnaryFactory<HardSigmoid>::create(const OutputVector& inputs) const {
auto alpha = std::make_shared<Constant>(element::f32, Shape{}, 2.0);
auto beta = std::make_shared<Constant>(element::f32, Shape{}, 3.0);
return std::make_shared<HardSigmoid>(inputs[0], alpha, beta);
}
template <>
NodePtr UnaryFactory<LogSoftmax>::create(const OutputVector& inputs) const {
return std::make_shared<LogSoftmax>(inputs[0], 2);
}
template <>
NodePtr UnaryFactory<ConvertLike>::create(const OutputVector& inputs) const {
auto like = std::make_shared<Constant>(element::f64, Shape{}, 1);
return std::make_shared<ConvertLike>(inputs[0], like);
}
template <typename UnaryT>
FactoryPtr CreateUnaryFactory(const std::string& type_name) {
return std::make_shared<UnaryFactory<UnaryT>>(type_name);
@ -352,16 +383,18 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(const FactoryPtr& unary_facto
} // namespace mult_consumers_first_node
std::vector<FactoryPtr> unary_factories = {
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh), CREATE_UNARY_FACTORY(Selu),
CREATE_UNARY_FACTORY(Swish), CREATE_UNARY_FACTORY(HardSigmoid), CREATE_UNARY_FACTORY(LogSoftmax),
CREATE_UNARY_FACTORY(ConvertLike)};
TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
FactoryPtr unary_factory;