Support new operations in TS: Selu, Swish, HSwish, Tile, CumSum, HardSigmoid (#19990)
* add new operations as unary
* get unary as input(0) instead of iterating pattern map
* add CumSum + unit tests
* add Tile + unit tests
* add tile
* fix ts_tile
* code review fix: use ADD_MATCHER
* fix bug in CI tests
parent a5b6606132
commit 4d9f2f3cd7
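For context, the new CumSum/Tile matchers below are registered inside TSGeneralForward/TSGeneralBackward, so they run as part of the usual transpose-sinking pipeline. A minimal sketch of invoking that pipeline on an already-built model (ov::pass::Manager, TSGeneral and run_passes are existing OpenVINO APIs; the model itself is assumed):

#include <memory>

#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_general.hpp"

// Sketch: run the full transpose-sinking pipeline, which after this change
// also covers CumSum and Tile (and Selu/Swish/HSwish/HardSigmoid via TSUnary*).
void run_transpose_sinking(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
    manager.run_passes(model);
}
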
@@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSCumSumForward;
class TRANSFORMATIONS_API TSCumSumBackward;

} // namespace transpose_sinking
} // namespace pass
} // namespace ov

/**
 * @ingroup ie_transformation_common_api
 * @brief TSCumSumForward transformation sinks Transpose through CumSum in the forward direction.
 */
class ov::pass::transpose_sinking::TSCumSumForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
    OPENVINO_RTTI("ov::pass::TSCumSumForward", "0");
    TSCumSumForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSCumSumBackward transformation sinks Transpose through CumSum in the backward direction.
 */
class ov::pass::transpose_sinking::TSCumSumBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSCumSumBackward", "0");
    TSCumSumBackward();
};

@@ -0,0 +1,41 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {
namespace transpose_sinking {

class TRANSFORMATIONS_API TSTileForward;
class TRANSFORMATIONS_API TSTileBackward;

} // namespace transpose_sinking
} // namespace pass
} // namespace ov

/**
 * @ingroup ie_transformation_common_api
 * @brief TSTileForward transformation sinks Transpose through Tile in the forward direction.
 */
class ov::pass::transpose_sinking::TSTileForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
    OPENVINO_RTTI("ov::pass::TSTileForward", "0");
    TSTileForward();
};

/**
 * @ingroup ie_transformation_common_api
 * @brief TSTileBackward transformation sinks Transpose through Tile in the backward direction.
 */
class ov::pass::transpose_sinking::TSTileBackward : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::pass::TSTileBackward", "0");
    TSTileBackward();
};

@@ -0,0 +1,92 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/transpose_sinking/ts_cumsum.hpp"

#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/cum_sum.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"

using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;

#undef CUMSUM_AXIS_INPUT_IDX
#define CUMSUM_AXIS_INPUT_IDX 1

TSCumSumForward::TSCumSumForward() {
    MATCHER_SCOPE(TSCumSumForward);

    create_pattern<ov::op::v0::CumSum>(true, {0});

    auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
                                      const TransposeInputsInfo& transpose_info) -> bool {
        if (transformation_callback(main_node)) {
            return false;
        }

        bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
        if (!res)
            return res;

        const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
        auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
        const auto& new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), transpose_axis_order, axis);
        main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);

        default_outputs_update(main_node, transpose_info);
        return true;
    };
    transpose_sinking(matcher_name, sinking_transformation);
}

TSCumSumBackward::TSCumSumBackward() {
    MATCHER_SCOPE(TSCumSumBackward);
    auto main_node_label = wrap_type<ov::op::v0::CumSum>([](const Output<Node>& output) -> bool {
        return has_static_rank()(output) && CheckTransposeConsumers(output);
    });

    auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

    auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
                                                            [](const Output<Node>& output) -> bool {
                                                                return has_static_rank()(output);
                                                            });
    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose_const =
            as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();

        if (transformation_callback(main_node)) {
            return false;
        }

        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
                                                                       transpose_const,
                                                                       /* input_indexes= */ {0})) {
            register_new_node(new_node);
        }

        RemoveTransposeConsumers(main_node);
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
        auto axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
        auto new_axes = ChangeAxes(main_node->input_value(CUMSUM_AXIS_INPUT_IDX), reversed_transpose_order, axis);
        main_node->input(CUMSUM_AXIS_INPUT_IDX).replace_source_output(new_axes);

        main_node->validate_and_infer_types();
        return true;
    };
    auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

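To illustrate the axis bookkeeping above: in the forward direction the Transpose originally feeds CumSum, so once the Transpose is moved below it, the CumSum axis must be re-expressed in the pre-transpose layout; ChangeAxes effectively gathers the transpose order by the axis value, while the backward direction uses the reversed (inverse) order, mirroring ReverseTransposeOrder. A minimal standalone sketch of that index arithmetic (illustrative helper names, not the library utilities):

#include <cstddef>
#include <vector>

// Illustration only: the forward pass rewrites the CumSum axis as order[axis].
std::size_t remap_axis_forward(const std::vector<std::size_t>& order, std::size_t axis) {
    return order[axis];  // e.g. order = {0, 2, 3, 1}, axis = 1 -> new axis = 2
}

// The backward pass indexes the inverse permutation instead.
std::size_t remap_axis_backward(const std::vector<std::size_t>& order, std::size_t axis) {
    std::vector<std::size_t> inverse(order.size());
    for (std::size_t i = 0; i < order.size(); ++i)
        inverse[order[i]] = i;
    return inverse[axis];  // e.g. order = {0, 2, 3, 1}, axis = 1 -> new axis = 3
}
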
@@ -13,6 +13,7 @@
#include "transformations/common_optimizations/enable_shapeof_constant_folding.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_gather.hpp"
@@ -23,6 +24,7 @@
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
@@ -31,35 +33,40 @@ using namespace ov::pass::transpose_sinking;

TSGeneralForward::TSGeneralForward() {
    MATCHER_SCOPE(TSGeneralForward);
    add_matcher<TSUnaryForward>();
    add_matcher<TSBinaryForward>();
    add_matcher<TSConcatForward>();
    add_matcher<TSSplitForward>();
    add_matcher<TSDataMovementForward>();
    add_matcher<TSReductionForward>();
    add_matcher<TSSqueezeForward>();
    add_matcher<TSUnsqueezeForward>();
    add_matcher<TSInterpolateForward>();
    add_matcher<TSSliceForward>();
    add_matcher<TSGatherForward>();
    add_matcher<TSShapeOfForward>();
    add_matcher<TSFuse>();
    ADD_MATCHER(this, TSUnaryForward);
    ADD_MATCHER(this, TSBinaryForward);
    ADD_MATCHER(this, TSConcatForward);
    ADD_MATCHER(this, TSSplitForward);
    ADD_MATCHER(this, TSDataMovementForward);
    ADD_MATCHER(this, TSReductionForward);
    ADD_MATCHER(this, TSSqueezeForward);
    ADD_MATCHER(this, TSUnsqueezeForward);
    ADD_MATCHER(this, TSInterpolateForward);
    ADD_MATCHER(this, TSSliceForward);
    ADD_MATCHER(this, TSGatherForward);
    ADD_MATCHER(this, TSShapeOfForward);
    ADD_MATCHER(this, TSCumSumForward);
    ADD_MATCHER(this, TSTileForward);
    ADD_MATCHER(this, TSFuse);
}

TSGeneralBackward::TSGeneralBackward() {
    MATCHER_SCOPE(TSGeneralBackward);
    add_matcher<TSUnaryBackward>();
    add_matcher<TSBinaryBackward>();
    add_matcher<TSConcatBackward>();
    add_matcher<TSSplitBackward>();
    add_matcher<TSDataMovementBackward>();
    add_matcher<TSReductionBackward>();
    add_matcher<TSSqueezeBackward>();
    add_matcher<TSUnsqueezeBackward>();
    add_matcher<TSInterpolateBackward>();
    add_matcher<TSSliceBackward>();
    add_matcher<TSGatherBackward>();
    add_matcher<TSFuse>();
    ADD_MATCHER(this, TSUnaryBackward);
    ADD_MATCHER(this, TSUnaryBackward);
    ADD_MATCHER(this, TSBinaryBackward);
    ADD_MATCHER(this, TSConcatBackward);
    ADD_MATCHER(this, TSSplitBackward);
    ADD_MATCHER(this, TSDataMovementBackward);
    ADD_MATCHER(this, TSReductionBackward);
    ADD_MATCHER(this, TSSqueezeBackward);
    ADD_MATCHER(this, TSUnsqueezeBackward);
    ADD_MATCHER(this, TSInterpolateBackward);
    ADD_MATCHER(this, TSSliceBackward);
    ADD_MATCHER(this, TSGatherBackward);
    ADD_MATCHER(this, TSCumSumBackward);
    ADD_MATCHER(this, TSTileBackward);
    ADD_MATCHER(this, TSFuse);
}

bool TSGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {

@@ -0,0 +1,93 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/transpose_sinking/ts_tile.hpp"

#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"

using namespace ov;
using namespace ov::pass::pattern;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;

#undef TILE_REPEATS_INPUT_IDX
#define TILE_REPEATS_INPUT_IDX 1

TSTileForward::TSTileForward() {
    MATCHER_SCOPE(TSTileForward);

    create_pattern<ov::op::v0::Tile>(true, {0});

    auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
                                      const TransposeInputsInfo& transpose_info) -> bool {
        if (transformation_callback(main_node)) {
            return false;
        }

        bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, /* input_indexes= */ {0});
        if (!res)
            return res;

        const auto transpose_axis_order = transpose_info.transpose_const->get_axis_vector_val();
        auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
        const auto& new_repeats =
            ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
        main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);

        default_outputs_update(main_node, transpose_info);
        return true;
    };
    transpose_sinking(matcher_name, sinking_transformation);
}

TSTileBackward::TSTileBackward() {
    MATCHER_SCOPE(TSTileBackward);
    auto main_node_label = wrap_type<ov::op::v0::Tile>([](const Output<Node>& output) -> bool {
        return has_static_rank()(output) && CheckTransposeConsumers(output);
    });

    auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

    auto transpose_label = wrap_type<ov::op::v1::Transpose>({main_node_label, transpose_const_label},
                                                            [](const Output<Node>& output) -> bool {
                                                                return has_static_rank()(output);
                                                            });
    matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose_const =
            as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();

        if (transformation_callback(main_node)) {
            return false;
        }

        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
                                                                       transpose_const,
                                                                       /* input_indexes= */ {0})) {
            register_new_node(new_node);
        }

        RemoveTransposeConsumers(main_node);
        const auto transpose_axis_order = transpose_const->get_axis_vector_val();
        auto repeats = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
        auto new_repeats =
            ChangeValuesOrder(main_node->input_value(TILE_REPEATS_INPUT_IDX), transpose_axis_order, repeats);
        main_node->input(TILE_REPEATS_INPUT_IDX).replace_source_output(new_repeats);

        main_node->validate_and_infer_types();
        return true;
    };
    auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

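As a usage illustration of the Tile pass above: when a Transpose feeds Tile, TSTileForward moves the Transpose after the Tile and re-gathers the repeats input by the transpose order (ChangeValuesOrder) so the repeats refer to the pre-transpose dimension layout. A minimal sketch of building such a subgraph and running the pass; the shapes and values mirror the unit tests and are purely illustrative:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"

std::shared_ptr<ov::Model> make_and_sink_tile() {
    auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 48, 80});
    auto order = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 3, 1});
    auto transpose = std::make_shared<ov::op::v1::Transpose>(data, order);
    auto repeats = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {1, 2, 3, 4});
    auto tile = std::make_shared<ov::op::v0::Tile>(transpose, repeats);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{tile}, ov::ParameterVector{data});

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::transpose_sinking::TSTileForward>();
    manager.run_passes(model);  // Tile now consumes `data`; its repeats are gathered by the transpose order.
    return model;
}
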
@@ -16,6 +16,7 @@
#include "openvino/op/logical_not.hpp"
#include "openvino/op/softplus.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
@@ -45,8 +46,21 @@ TSUnaryForward::TSUnaryForward() {
                   ov::op::v0::Convert,
                   ov::op::v10::IsInf,
                   ov::op::v10::IsNaN,
                   ov::op::v10::IsFinite>(true);
    transpose_sinking(matcher_name);
                   ov::op::v10::IsFinite,
                   ov::op::v0::Selu,
                   ov::op::v4::Swish,
                   ov::op::v0::HardSigmoid,
                   ov::op::v5::LogSoftmax,
                   ov::op::v1::ConvertLike>(true);
    auto ts_unary_sinking_function = [this](const std::shared_ptr<Node>& main_node,
                                            const utils::TransposeInputsInfo& transpose_info) -> bool {
        bool res = utils::sink_forward::UpdateInputTransposes(main_node, transpose_info, {0});
        if (!res)
            return res;
        default_outputs_update(main_node, transpose_info);
        return true;
    };
    transpose_sinking(matcher_name, ts_unary_sinking_function);
}

TSUnaryBackward::TSUnaryBackward() {
@@ -56,15 +70,25 @@ TSUnaryBackward::TSUnaryBackward() {
        return CheckTransposeConsumers(output);
    };

    auto unary_label = wrap_type<UnaryElementwiseArithmetic,
                                 ov::op::v0::Clamp,
                                 ov::op::v0::Elu,
                                 ov::op::v4::SoftPlus,
                                 ov::op::v1::LogicalNot,
                                 ov::op::v0::Convert,
                                 ov::op::v10::IsInf,
                                 ov::op::v10::IsNaN,
                                 ov::op::v10::IsFinite>({any_input()}, unary_restrictions);
    auto unary_with_1_input_label = wrap_type<UnaryElementwiseArithmetic,
                                              ov::op::v0::Clamp,
                                              ov::op::v0::Elu,
                                              ov::op::v4::SoftPlus,
                                              ov::op::v1::LogicalNot,
                                              ov::op::v0::Convert,
                                              ov::op::v10::IsInf,
                                              ov::op::v10::IsNaN,
                                              ov::op::v10::IsFinite,
                                              ov::op::v5::LogSoftmax>({any_input()}, unary_restrictions);

    auto unary_with_2_inputs_label =
        wrap_type<ov::op::v4::Swish, ov::op::v1::ConvertLike>({any_input(), any_input()}, unary_restrictions);
    auto unary_with_3_inputs_label =
        wrap_type<ov::op::v0::Selu, ov::op::v0::HardSigmoid>({any_input(), any_input(), any_input()},
                                                             unary_restrictions);

    auto unary_label = std::make_shared<pattern::op::Or>(
        ov::OutputVector{unary_with_1_input_label, unary_with_2_inputs_label, unary_with_3_inputs_label});

    auto transpose_const_label = wrap_type<ov::op::v0::Constant>();

@@ -75,12 +99,12 @@ TSUnaryBackward::TSUnaryBackward() {
        auto transpose_const =
            as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
        auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
        auto unary = pattern_to_output.at(unary_label).get_node_shared_ptr();
        auto unary = transpose->get_input_node_shared_ptr(0);
        if (transformation_callback(unary)) {
            return false;
        }

        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const)) {
        for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary, transpose_const, {0})) {
            register_new_node(new_node);
        }
        unary->validate_and_infer_types();

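The grouping above matters because Swish, ConvertLike, Selu and HardSigmoid are not strictly single-input ops: the backward pass now takes the unary node as the transpose's input(0) and inserts the Transpose only before input 0, leaving the parameter inputs (beta, alpha/lambda, the "like" tensor) untouched. A minimal sketch of the Swish case, mirroring the unit-test setup and purely illustrative:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"

std::shared_ptr<ov::Model> make_and_sink_swish() {
    auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 2, 48, 80});
    auto beta = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {0.9f});
    auto swish = std::make_shared<ov::op::v4::Swish>(data, beta);
    auto order = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 3, 1});
    auto transpose = std::make_shared<ov::op::v1::Transpose>(swish, order);
    auto model = std::make_shared<ov::Model>(ov::OutputVector{transpose}, ov::ParameterVector{data});

    ov::pass::Manager manager;
    manager.register_pass<ov::pass::transpose_sinking::TSUnaryBackward>();
    manager.run_passes(model);  // Transpose moves to Swish's data input only; beta stays as-is.
    return model;
}
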
@@ -59,6 +59,7 @@ Output<Node> ChangeAxes(const Output<Node>& indices,
    copy_runtime_info(indices.get_node_shared_ptr(), gather);
    return gather;
}

Output<Node> ChangeAxes(const Output<Node>& indices,
                        const AxisVector& transpose_axis_order,
                        const std::shared_ptr<ov::op::v0::Constant>& axis) {

@@ -8,12 +8,14 @@
#include "openvino/pass/manager.hpp"
#include "transformations/transpose_sinking/ts_binary.hpp"
#include "transformations/transpose_sinking/ts_concat.hpp"
#include "transformations/transpose_sinking/ts_cumsum.hpp"
#include "transformations/transpose_sinking/ts_data_movement.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_tile.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "ts_test_case.hpp"
@@ -206,6 +208,30 @@ FactoryPtr CreateInterpolateFactory(const std::string& type_name, bool is_refere
    return std::make_shared<InterpolateFactory>(type_name, is_reference);
}

class CumSumFactory : public IFactory {
public:
    explicit CumSumFactory(const std::string& type_name) : IFactory(type_name) {}
    NodePtr create(const OutputVector& parent_nodes) const override {
        return std::make_shared<CumSum>(parent_nodes[0], parent_nodes[1]);
    }
};

FactoryPtr CreateCumSumFactory(const std::string& type_name) {
    return std::make_shared<CumSumFactory>(type_name);
}

class TileFactory : public IFactory {
public:
    explicit TileFactory(const std::string& type_name) : IFactory(type_name) {}
    NodePtr create(const OutputVector& parent_nodes) const override {
        return std::make_shared<Tile>(parent_nodes[0], parent_nodes[1]);
    }
};

FactoryPtr CreateTileFactory(const std::string& type_name) {
    return std::make_shared<TileFactory>(type_name);
}

class SliceFactory : public IFactory {
public:
    explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
@@ -285,6 +311,12 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
#undef CREATE_INTERPOLATE_FACTORY
#define CREATE_INTERPOLATE_FACTORY(type_name, reference_flag) CreateInterpolateFactory(#type_name, reference_flag)

#undef CREATE_CUMSUM_FACTORY
#define CREATE_CUMSUM_FACTORY(type_name) CreateCumSumFactory(#type_name)

#undef CREATE_TILE_FACTORY
#define CREATE_TILE_FACTORY(type_name) CreateTileFactory(#type_name)

#undef CREATE_SLICE_FACTORY
#define CREATE_SLICE_FACTORY(type_name) CreateSliceFactory(#type_name)

@@ -761,6 +793,84 @@ auto test_forward_interpolate = []() {

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TSTestFixture, test_forward_interpolate());

auto test_forward_cumsum = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSCumSumForward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
                                constant<int64_t>(element::i64, {}, std::vector<int64_t>{0})};

    // Test model description:
    test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
    test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
    test_case.model.model_template = create_model;

    // Reference model description:
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(data, out, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}};
    test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
    test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model_ref.model_template = create_model;

    return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumForward, TSTestFixture, test_forward_cumsum());

auto test_forward_tile = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSTileForward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
                                constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 2, 3, 4})};

    // Test model description:
    test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
    test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model.model_template = create_model;

    // Reference model description:
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(out, data, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_specific_gather_for}, {{1}}};
    test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model_ref.model_template = create_model;

    return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileForward, TSTestFixture, test_forward_tile());

auto test_forward_squeeze = []() {
    TestCase test_case;

@@ -1262,6 +1372,120 @@ auto test_backward_interpolate = []() {

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TSTestFixture, test_backward_interpolate());

auto test_backward_cumsum = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSCumSumBackward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
                                constant<int64_t>(element::i64, {}, std::vector<int64_t>{0})};

    // Test model description:
    test_case.model.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(data, out, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
    test_case.model_ref.main_op = {CREATE_CUMSUM_FACTORY(CumSum)};
    test_case.model_ref.model_template = create_model;

    return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonCumSumBackward, TSTestFixture, test_backward_cumsum());

auto test_backward_tile = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {parameter(element::f32, {1, 2, 48, 80}),
                                constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 2, 3, 4})};

    // Test model description:
    test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(out, data, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
    test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model_ref.model_template = create_model;

    return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackward, TSTestFixture, test_backward_tile());

auto test_backward_tile_tf_case = []() {
    TestCase test_case;

    // Initialize common attributes
    test_case.transformation = CREATE_PASS_FACTORY(TSTileBackward);
    test_case.num_main_ops = {1};
    test_case.inputs_to_main = {parameter(element::f32, {2, 1, 1, 128}),
                                constant<int64_t>(element::i64, {4}, std::vector<int64_t>{1, 1, 88, 1})};

    // Test model description:
    test_case.model.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
    test_case.model.model_template = create_model;

    // Reference model description:
    auto set_specific_gather_for = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
        OutputVector result = out_vec;
        for (const auto& idx : idxs) {
            const auto& out = out_vec[idx];
            vector<int64_t> transpose_order(out_vec[0].get_shape().size());
            iota(transpose_order.begin(), transpose_order.end(), 0);
            reverse(transpose_order.begin(), transpose_order.end());
            auto data = make_shared<Constant>(element::i32, Shape{transpose_order.size()}, transpose_order);
            auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
            auto transpose = make_shared<Gather>(out, data, axis);
            result[idx] = transpose;
        }
        return result;
    };
    test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, set_specific_gather_for}, {{0}, {1}}};
    test_case.model_ref.main_op = {CREATE_TILE_FACTORY(Tile)};
    test_case.model_ref.model_template = create_model;

    return wrapper(test_case);
};

INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonTileBackwardTfCase, TSTestFixture, test_backward_tile_tf_case());

auto test_backward_unsqueeze = []() {
    TestCase test_case;

@@ -7,7 +7,7 @@
#include "common_test_utils/ov_test_utils.hpp"
#include "gtest/gtest.h"
#include "openvino/frontend/manager.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/pass/manager.hpp"
#include "ts_test_utils.hpp"

@@ -85,6 +85,37 @@ NodePtr UnaryFactory<Convert>::create(const OutputVector& inputs) const {
    return std::make_shared<Convert>(inputs[0], element::f64);
}

template <>
NodePtr UnaryFactory<Selu>::create(const OutputVector& inputs) const {
    auto alpha = std::make_shared<Constant>(element::f32, Shape{}, 2.0);
    auto lambda = std::make_shared<Constant>(element::f32, Shape{}, 3.0);
    return std::make_shared<Selu>(inputs[0], alpha, lambda);
}

template <>
NodePtr UnaryFactory<Swish>::create(const OutputVector& inputs) const {
    auto beta = std::make_shared<Constant>(element::f32, Shape{}, 0.9);
    return std::make_shared<Swish>(inputs[0], beta);
}

template <>
NodePtr UnaryFactory<HardSigmoid>::create(const OutputVector& inputs) const {
    auto alpha = std::make_shared<Constant>(element::f32, Shape{}, 2.0);
    auto beta = std::make_shared<Constant>(element::f32, Shape{}, 3.0);
    return std::make_shared<HardSigmoid>(inputs[0], alpha, beta);
}

template <>
NodePtr UnaryFactory<LogSoftmax>::create(const OutputVector& inputs) const {
    return std::make_shared<LogSoftmax>(inputs[0], 2);
}

template <>
NodePtr UnaryFactory<ConvertLike>::create(const OutputVector& inputs) const {
    auto like = std::make_shared<Constant>(element::f64, Shape{}, 1);
    return std::make_shared<ConvertLike>(inputs[0], like);
}

template <typename UnaryT>
FactoryPtr CreateUnaryFactory(const std::string& type_name) {
    return std::make_shared<UnaryFactory<UnaryT>>(type_name);
@@ -352,16 +383,18 @@ std::shared_ptr<ov::Model> CreateReferenceFunction(const FactoryPtr& unary_facto
} // namespace mult_consumers_first_node

std::vector<FactoryPtr> unary_factories = {
    CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
    CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
    CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
    CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
    CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
    CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
    CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
    CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
    CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
    CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
    CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
    CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
    CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
    CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
    CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
    CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
    CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
    CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
    CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
    CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh), CREATE_UNARY_FACTORY(Selu),
    CREATE_UNARY_FACTORY(Swish), CREATE_UNARY_FACTORY(HardSigmoid), CREATE_UNARY_FACTORY(LogSoftmax),
    CREATE_UNARY_FACTORY(ConvertLike)};

TEST_P(TransposeSinkingUnaryTestFixture, CompareFunctions) {
    FactoryPtr unary_factory;