Transpose sinking tests refactoring: part 3. + Revert changes in MOC.
This commit is contained in:
parent
d71949fd09
commit
ef0e89551d
@ -7,8 +7,6 @@
|
||||
#include <memory>
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <openvino/pass/pattern/matcher.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_fuse.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_reduction.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <vector>
|
||||
|
||||
@ -18,11 +16,23 @@ namespace pass {
|
||||
class TRANSFORMATIONS_API TransposeSinking;
|
||||
class TRANSFORMATIONS_API TransposeConvert;
|
||||
class TRANSFORMATIONS_API TransposeEltwise;
|
||||
class TRANSFORMATIONS_API TransposeReduction;
|
||||
class TRANSFORMATIONS_API TransposeFQReduction;
|
||||
class TRANSFORMATIONS_API TransposeFuse;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeReduction transformation sinks Transpose through Reduce operations
|
||||
*/
|
||||
class ov::pass::TransposeReduction : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeReduction", "0");
|
||||
TransposeReduction();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeFQReduction transformation sinks Transpose through FakeQuantize in case it is followed by reduction
|
||||
@ -54,6 +64,17 @@ public:
|
||||
TransposeEltwise();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeFuse transformation eliminates 2 consequtive Transposes if they result in no changes to input or
|
||||
* fuses them to single Transpose if input gets changed
|
||||
*/
|
||||
class ov::pass::TransposeFuse : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeFuse", "0");
|
||||
TransposeFuse();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeSinking transformation sinks Transposes through known operations
|
||||
@ -63,7 +84,7 @@ public:
|
||||
OPENVINO_RTTI("TransposeSinking", "0");
|
||||
TransposeSinking() {
|
||||
add_matcher<ov::pass::TransposeFQReduction>();
|
||||
add_matcher<ov::pass::TransposeSinkingReductionForward>();
|
||||
add_matcher<ov::pass::TransposeReduction>();
|
||||
add_matcher<ov::pass::TransposeConvert>();
|
||||
add_matcher<ov::pass::TransposeEltwise>();
|
||||
add_matcher<ov::pass::TransposeFuse>();
|
||||
|
@ -11,18 +11,18 @@
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API TransposeFuse;
|
||||
class TRANSFORMATIONS_API TransposeSinkingFuse;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TransposeFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
|
||||
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
|
||||
* fuses them to single Transpose if input gets changed
|
||||
*/
|
||||
class ov::pass::TransposeFuse : public ov::pass::MatcherPass {
|
||||
class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("TransposeFuse", "0");
|
||||
TransposeFuse();
|
||||
OPENVINO_RTTI("TransposeSinkingFuse", "0");
|
||||
TransposeSinkingFuse();
|
||||
};
|
@ -21,6 +21,27 @@ using namespace ov;
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<opset6::Constant> get_reduced_order_constant(const std::shared_ptr<opset6::Constant>& axes_const,
|
||||
const std::shared_ptr<opset6::Constant>& order_const) {
|
||||
auto order = order_const->cast_vector<int64_t>();
|
||||
|
||||
auto axes = axes_const->cast_vector<int64_t>();
|
||||
std::sort(axes.rbegin(), axes.rend());
|
||||
for (const auto& i : axes)
|
||||
order.erase(order.begin() + i);
|
||||
|
||||
const auto& updated_order_size = static_cast<int64_t>(order.size());
|
||||
|
||||
auto order_sorted = order;
|
||||
sort(order_sorted.begin(), order_sorted.end());
|
||||
for (int64_t i = 0; i < updated_order_size; ++i) {
|
||||
auto lowest_greater_eq_i = std::lower_bound(order_sorted.begin(), order_sorted.end(), i);
|
||||
std::replace(order.begin(), order.end(), *lowest_greater_eq_i, i);
|
||||
std::replace(order_sorted.begin(), order_sorted.end(), *lowest_greater_eq_i, i);
|
||||
}
|
||||
return std::make_shared<opset6::Constant>(ngraph::element::i64, ngraph::Shape{order.size()}, order);
|
||||
}
|
||||
|
||||
std::shared_ptr<opset6::Constant> get_reversed_order_constant(const std::shared_ptr<opset6::Constant>& order_const) {
|
||||
const auto& order = order_const->cast_vector<size_t>();
|
||||
const auto& rank = order.size();
|
||||
@ -109,6 +130,71 @@ ov::pass::TransposeConvert::TransposeConvert() {
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeReduction::TransposeReduction() {
|
||||
MATCHER_SCOPE(TransposeReduction);
|
||||
|
||||
auto transpose_label =
|
||||
pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
auto reduce_or_squeeze_label =
|
||||
pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
{transpose_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto reduction = pattern_to_output.at(reduce_or_squeeze_label).get_node_shared_ptr();
|
||||
auto arithmetic_reduce = std::dynamic_pointer_cast<op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = std::dynamic_pointer_cast<op::util::LogicalReductionKeepDims>(reduction);
|
||||
auto squeeze = std::dynamic_pointer_cast<opset6::Squeeze>(reduction);
|
||||
if (!transpose || !(arithmetic_reduce || logical_reduce || squeeze))
|
||||
return false;
|
||||
|
||||
bool keep_dims = false; // squeeze always reduces number of output dimensions
|
||||
if (logical_reduce)
|
||||
keep_dims = logical_reduce->get_keep_dims();
|
||||
else if (arithmetic_reduce)
|
||||
keep_dims = arithmetic_reduce->get_keep_dims();
|
||||
|
||||
auto transpose_order = std::dynamic_pointer_cast<opset6::Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
auto reduction_axes = std::dynamic_pointer_cast<opset6::Constant>(reduction->get_input_node_shared_ptr(1));
|
||||
if (!transpose_order || !reduction_axes)
|
||||
return false;
|
||||
|
||||
const auto& non_negative_axes = normalize_axes(reduction->get_friendly_name(),
|
||||
reduction_axes->cast_vector<int64_t>(),
|
||||
reduction->get_input_partial_shape(0).rank());
|
||||
reduction_axes = opset6::Constant::create(ngraph::element::i64, {non_negative_axes.size()}, non_negative_axes);
|
||||
|
||||
ngraph::NodeVector new_ops;
|
||||
auto new_axes =
|
||||
ov::op::util::make_try_fold<opset6::Gather>(transpose_order,
|
||||
reduction_axes,
|
||||
opset6::Constant::create(ngraph::element::i64, {}, {0}));
|
||||
new_ops.push_back(new_axes);
|
||||
auto new_reduce = reduction->clone_with_new_inputs({transpose->input_value(0), new_axes});
|
||||
new_ops.push_back(new_reduce);
|
||||
|
||||
auto updated_order = transpose_order;
|
||||
if (!keep_dims) {
|
||||
updated_order = get_reduced_order_constant(reduction_axes, transpose_order);
|
||||
new_ops.push_back(updated_order);
|
||||
}
|
||||
auto new_transpose = register_new_node<opset6::Transpose>(new_reduce, updated_order);
|
||||
new_ops.push_back(new_transpose);
|
||||
new_transpose->set_friendly_name(reduction->get_friendly_name());
|
||||
|
||||
ngraph::copy_runtime_info({reduction, transpose}, new_ops);
|
||||
ngraph::replace_node(reduction, new_transpose);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeFQReduction::TransposeFQReduction() {
|
||||
MATCHER_SCOPE(TransposeFQReduction);
|
||||
|
||||
@ -176,3 +262,59 @@ ov::pass::TransposeFQReduction::TransposeFQReduction() {
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reduce_or_squeeze_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::TransposeFuse::TransposeFuse() {
|
||||
MATCHER_SCOPE(TransposeFuse);
|
||||
|
||||
auto transpose_1 =
|
||||
pattern::wrap_type<opset7::Transpose>({pattern::any_input(), pattern::wrap_type<opset7::Constant>()},
|
||||
pattern::consumers_count(1));
|
||||
auto transpose_2 = pattern::wrap_type<opset7::Transpose>({transpose_1, pattern::wrap_type<opset7::Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose1 = pattern_to_output.at(transpose_1).get_node_shared_ptr();
|
||||
auto transpose2 = pattern_to_output.at(transpose_2).get_node_shared_ptr();
|
||||
auto input = transpose1->input_value(0);
|
||||
|
||||
auto transpose1_order = std::dynamic_pointer_cast<opset7::Constant>(transpose1->get_input_node_shared_ptr(1));
|
||||
auto transpose2_order = std::dynamic_pointer_cast<opset7::Constant>(transpose2->get_input_node_shared_ptr(1));
|
||||
if (!transpose1_order || !transpose2_order)
|
||||
return false;
|
||||
|
||||
auto order1 = transpose1_order->cast_vector<int64_t>();
|
||||
auto order2 = transpose2_order->cast_vector<int64_t>();
|
||||
if (order1.size() != order2.size())
|
||||
return false;
|
||||
|
||||
bool is_ordered = true;
|
||||
for (size_t i = 0; i < order1.size(); i++) {
|
||||
order2[i] = order1[order2[i]];
|
||||
if (order2[i] != (int64_t)i)
|
||||
is_ordered = false;
|
||||
}
|
||||
|
||||
auto transpose_order_type = transpose1_order->get_element_type();
|
||||
if (transpose_order_type != transpose2_order->get_element_type())
|
||||
transpose_order_type = element::i64;
|
||||
|
||||
if (is_ordered) {
|
||||
return ngraph::replace_output_update_name(transpose2->output(0), input);
|
||||
} else {
|
||||
auto new_order = opset7::Constant::create(transpose_order_type, {order2.size()}, order2);
|
||||
auto new_transpose = register_new_node<opset7::Transpose>(input, new_order);
|
||||
|
||||
new_transpose->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({transpose1, transpose2}, new_transpose);
|
||||
ngraph::replace_node(m.get_match_root(), new_transpose);
|
||||
|
||||
transpose_sinking::UpdateForwardSinkingAbility(new_transpose);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_2, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
@ -17,7 +17,7 @@
|
||||
using namespace ov;
|
||||
using namespace opset10;
|
||||
|
||||
ov::pass::TransposeFuse::TransposeFuse() {
|
||||
ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
MATCHER_SCOPE(TransposeFuse);
|
||||
auto transpose_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
|
@ -10,13 +10,14 @@
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
@ -28,7 +29,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
add_matcher<ov::pass::TransposeSinkingDataMovementForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingReductionForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingInterpolateForward>();
|
||||
add_matcher<ov::pass::TransposeFuse>();
|
||||
add_matcher<ov::pass::TransposeSinkingFuse>();
|
||||
}
|
||||
|
||||
ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
@ -40,7 +41,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
add_matcher<ov::pass::TransposeSinkingDataMovementBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingReductionBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingInterpolateBackward>();
|
||||
add_matcher<ov::pass::TransposeFuse>();
|
||||
add_matcher<ov::pass::TransposeSinkingFuse>();
|
||||
}
|
||||
|
||||
bool ov::pass::TransposeSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
|
@ -6,6 +6,8 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
@ -98,6 +100,54 @@ public:
|
||||
FactoryPtr CreateSplitFactory(const std::string& type_name) {
|
||||
return std::make_shared<SplitFactory>(type_name);
|
||||
}
|
||||
|
||||
class PadFactory : public IFactory {
|
||||
public:
|
||||
explicit PadFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Pad>(parent_nodes[0], parent_nodes[1], parent_nodes[2], ov::op::PadMode::CONSTANT);
|
||||
}
|
||||
};
|
||||
FactoryPtr CreatePadFactory(const std::string& type_name) {
|
||||
return std::make_shared<PadFactory>(type_name);
|
||||
}
|
||||
|
||||
class BatchToSpaceFactory : public IFactory {
|
||||
public:
|
||||
explicit BatchToSpaceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<BatchToSpace>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3]);
|
||||
}
|
||||
};
|
||||
|
||||
FactoryPtr CreateBatchToSpaceFactory(const std::string& type_name) {
|
||||
return std::make_shared<BatchToSpaceFactory>(type_name);
|
||||
}
|
||||
|
||||
class SpaceToBatchFactory : public IFactory {
|
||||
public:
|
||||
explicit SpaceToBatchFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<SpaceToBatch>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3]);
|
||||
}
|
||||
};
|
||||
FactoryPtr CreateSpaceToBatchFactory(const std::string& type_name) {
|
||||
return std::make_shared<SpaceToBatchFactory>(type_name);
|
||||
}
|
||||
|
||||
template <typename ReductionT>
|
||||
class ReductionFactory : public IFactory {
|
||||
public:
|
||||
explicit ReductionFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<ReductionT>(parent_nodes[0], parent_nodes[1]);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ReductionT>
|
||||
FactoryPtr CreateReductionFactory(const std::string& type_name) {
|
||||
return std::make_shared<ReductionFactory<ReductionT>>(type_name);
|
||||
}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
#undef CREATE_UNARY_FACTORY
|
||||
@ -114,6 +164,18 @@ FactoryPtr CreateSplitFactory(const std::string& type_name) {
|
||||
|
||||
#undef CREATE_SPLIT_FACTORY
|
||||
#define CREATE_SPLIT_FACTORY(type_name) CreateSplitFactory(#type_name)
|
||||
|
||||
#undef CREATE_PAD_FACTORY
|
||||
#define CREATE_PAD_FACTORY(type_name) CreatePadFactory(#type_name)
|
||||
|
||||
#undef CREATE_BATCH_TO_SPACE_FACTORY
|
||||
#define CREATE_BATCH_TO_SPACE_FACTORY(type_name) CreateBatchToSpaceFactory(#type_name)
|
||||
|
||||
#undef CREATE_SPACE_TO_BATCH_FACTORY
|
||||
#define CREATE_SPACE_TO_BATCH_FACTORY(type_name) CreateSpaceToBatchFactory(#type_name)
|
||||
|
||||
#undef CREATE_REDUCTION_FACTORY
|
||||
#define CREATE_REDUCTION_FACTORY(type_name) CreateReductionFactory<type_name>(#type_name)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
struct Preprocessing {
|
||||
@ -128,81 +190,57 @@ struct Preprocessing {
|
||||
return new_inputs;
|
||||
}
|
||||
};
|
||||
using CreateGraphF = function<shared_ptr<ov::Model>(Preprocessing, FactoryPtr, Preprocessing, size_t, OutputVector)>;
|
||||
|
||||
using TestParams = tuple<FactoryPtr,
|
||||
FactoryPtr,
|
||||
PassFactoryPtr,
|
||||
size_t, /* num_unary_ops */
|
||||
Preprocessing,
|
||||
CreateGraphF, /* model_factory */
|
||||
Preprocessing,
|
||||
Preprocessing,
|
||||
CreateGraphF, /* reference_model_factory */
|
||||
Preprocessing,
|
||||
OutputVector>;
|
||||
struct TestCase;
|
||||
struct ModelDescription;
|
||||
using TestParams = tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
|
||||
using CreateGraphF = function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
|
||||
|
||||
struct TestCases {
|
||||
vector<FactoryPtr> main_node;
|
||||
vector<FactoryPtr> main_node_ref;
|
||||
PassFactoryPtr transformation;
|
||||
vector<size_t> num_main_ops;
|
||||
Preprocessing preprocess_before;
|
||||
CreateGraphF test_model;
|
||||
Preprocessing preprocess_after;
|
||||
Preprocessing preprocess_before_ref;
|
||||
CreateGraphF ref_model;
|
||||
Preprocessing preprocess_after_ref;
|
||||
OutputVector inputs_to_main;
|
||||
// Describes a model to test.
|
||||
// Expects to be used in such a scenario:
|
||||
// 1st Preprocessing inserts Transpose/Gather to the inputs
|
||||
// of the main node.
|
||||
// Factory contains the rules how to create the main testing node.
|
||||
// 2nd Preprocessing inserts Transpose/Gather to the outputs
|
||||
// of the main node.
|
||||
// model_template is a function which uses the arguments above.
|
||||
// Examples of the scenarios:
|
||||
// ModelDescription model: Param -> (Transpose inserted by 1st Preprocessing) -> Abs (main_node) -> Result
|
||||
// ModelDescription reference: Param -> Abs (main_node) -> (Transpose inserted by 2nd Preprocessing) -> Result
|
||||
struct ModelDescription {
|
||||
Preprocessing preprocess_inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
vector<FactoryPtr> main_op;
|
||||
Preprocessing preprocess_outputs_of_main;
|
||||
CreateGraphF model_template;
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
FactoryPtr main_node;
|
||||
FactoryPtr main_node_ref;
|
||||
PassFactoryPtr transformation;
|
||||
size_t num_main_ops = 0;
|
||||
Preprocessing preprocess_before;
|
||||
CreateGraphF test_model;
|
||||
Preprocessing preprocess_after;
|
||||
Preprocessing preprocess_before_ref;
|
||||
CreateGraphF ref_model;
|
||||
Preprocessing preprocess_after_ref;
|
||||
OutputVector inputs_to_main;
|
||||
// @parameterized with multiple values
|
||||
vector<size_t> num_main_ops;
|
||||
|
||||
explicit TestCase(const TestParams& params) {
|
||||
tie(main_node,
|
||||
main_node_ref,
|
||||
transformation,
|
||||
num_main_ops,
|
||||
preprocess_before,
|
||||
test_model,
|
||||
preprocess_after,
|
||||
preprocess_before_ref,
|
||||
ref_model,
|
||||
preprocess_after_ref,
|
||||
inputs_to_main) = params;
|
||||
}
|
||||
ModelDescription model;
|
||||
ModelDescription model_ref;
|
||||
PassFactoryPtr transformation;
|
||||
};
|
||||
|
||||
class TransposeSinkingTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
|
||||
public:
|
||||
/* static string get_test_name(const testing::TestParamInfo<TestParams>& obj) {
|
||||
auto test_case = TestCase(obj.param);
|
||||
static string get_test_name(testing::TestParamInfo<TestParams> obj) {
|
||||
size_t num_main_ops_idx;
|
||||
size_t main_op_idx;
|
||||
TestCase test_case;
|
||||
tie(num_main_ops_idx, main_op_idx, test_case) = obj.param;
|
||||
|
||||
ostringstream test_name;
|
||||
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
|
||||
test_name << "numUnaryOps=" << num_unary_ops << "/";
|
||||
//test_name << "inputShape=" << to_string(input_shape) << "/";
|
||||
test_name << "unaryFactory=" << unary_factory->getTypeName() << "/";
|
||||
test_name << "passFactory=" << pass_factory->getTypeName() << "/";
|
||||
//test_name << "inputType=" << input_type;
|
||||
|
||||
test_name << "Factory=" << test_case.model.main_op[main_op_idx]->getTypeName() << "/";
|
||||
test_name << "NumOps=" << test_case.num_main_ops[num_main_ops_idx] << "/";
|
||||
test_name << "Transformation=" << test_case.transformation->getTypeName() << "/";
|
||||
return test_name.str();
|
||||
}*/
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
vector<FactoryPtr> unary_factories = {
|
||||
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
|
||||
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
|
||||
@ -226,152 +264,310 @@ std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Subtract),
|
||||
CREATE_BINARY_FACTORY(PRelu)};
|
||||
|
||||
std::vector<FactoryPtr> reduction_factories = {CREATE_BINARY_FACTORY(ReduceMax),
|
||||
CREATE_BINARY_FACTORY(ReduceMin),
|
||||
CREATE_BINARY_FACTORY(ReduceMean),
|
||||
CREATE_BINARY_FACTORY(ReduceSum),
|
||||
CREATE_BINARY_FACTORY(ReduceProd),
|
||||
//CREATE_BINARY_FACTORY(ReduceLogicalOr),
|
||||
//CREATE_BINARY_FACTORY(ReduceLogicalAnd),
|
||||
CREATE_BINARY_FACTORY(ReduceL1),
|
||||
CREATE_BINARY_FACTORY(ReduceL2),
|
||||
//CREATE_BINARY_FACTORY(Squeeze),
|
||||
//CREATE_BINARY_FACTORY(Unsqueeze),
|
||||
};
|
||||
|
||||
|
||||
TEST_P(TransposeSinkingTestFixture, CompareFunctions) {
|
||||
auto test_case = TestCase(this->GetParam());
|
||||
model = test_case.test_model(test_case.preprocess_before,
|
||||
test_case.main_node,
|
||||
test_case.preprocess_after,
|
||||
test_case.num_main_ops,
|
||||
test_case.inputs_to_main);
|
||||
model_ref = test_case.ref_model(test_case.preprocess_before_ref,
|
||||
test_case.main_node_ref,
|
||||
test_case.preprocess_after_ref,
|
||||
test_case.num_main_ops,
|
||||
test_case.inputs_to_main);
|
||||
int num_main_ops_idx;
|
||||
int main_op_idx;
|
||||
TestCase test_case;
|
||||
tie(num_main_ops_idx, main_op_idx, test_case) = this->GetParam();
|
||||
model = test_case.model.model_template(main_op_idx,
|
||||
test_case.model,
|
||||
test_case.num_main_ops[num_main_ops_idx],
|
||||
test_case.inputs_to_main);
|
||||
|
||||
model_ref = test_case.model_ref.model_template(main_op_idx,
|
||||
test_case.model_ref,
|
||||
test_case.num_main_ops[num_main_ops_idx],
|
||||
test_case.inputs_to_main);
|
||||
test_case.transformation->registerPass(manager);
|
||||
// TODO: enable accuracy testing. The current issues: div by 0
|
||||
// comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
namespace transpose_sinking {
|
||||
namespace common {
|
||||
|
||||
shared_ptr<ov::Model> create_model(const Preprocessing& preprocess_before,
|
||||
const FactoryPtr& main_op,
|
||||
const Preprocessing& preprocess_after,
|
||||
size_t num_ops,
|
||||
const OutputVector& inputs_to_main) {
|
||||
auto new_inputs = preprocess_before.apply(inputs_to_main);
|
||||
auto main_node = create_main_node(new_inputs, num_ops, main_op);
|
||||
auto outputs = preprocess_after.apply(main_node->outputs());
|
||||
shared_ptr<ov::Model> create_model(size_t main_node_idx,
|
||||
const ModelDescription& model_desc,
|
||||
size_t num_ops,
|
||||
const OutputVector& inputs_to_main) {
|
||||
auto new_inputs = model_desc.preprocess_inputs_to_main.apply(inputs_to_main);
|
||||
auto main_node = create_main_node(new_inputs, num_ops, model_desc.main_op[main_node_idx]);
|
||||
auto outputs = model_desc.preprocess_outputs_of_main.apply(main_node->outputs());
|
||||
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
auto wrapper = [](const TestCases& test_cases) {
|
||||
return ::testing::Combine(::testing::ValuesIn(test_cases.main_node),
|
||||
::testing::ValuesIn(test_cases.main_node_ref),
|
||||
::testing::Values(test_cases.transformation),
|
||||
::testing::ValuesIn(test_cases.num_main_ops),
|
||||
::testing::Values(test_cases.preprocess_before),
|
||||
::testing::Values(test_cases.test_model),
|
||||
::testing::Values(test_cases.preprocess_after),
|
||||
::testing::Values(test_cases.preprocess_before_ref),
|
||||
::testing::Values(test_cases.ref_model),
|
||||
::testing::Values(test_cases.preprocess_after_ref),
|
||||
::testing::Values(test_cases.inputs_to_main));
|
||||
auto wrapper = [](const TestCase& test_case) {
|
||||
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
|
||||
"The number of main op (testing op) creator have to be the same for the testing model and for"
|
||||
"the reference model.");
|
||||
return ::testing::Combine(::testing::Range<size_t>(0, test_case.num_main_ops.size()),
|
||||
::testing::Range<size_t>(0, test_case.model.main_op.size()),
|
||||
::testing::Values(test_case));
|
||||
};
|
||||
|
||||
shared_ptr<Node> parameter(element::Type el_type, const PartialShape& ps) {
|
||||
return make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
shared_ptr<Node> constant(element::Type el_type, const Shape& shape, const vector<int64_t>& value) {
|
||||
return make_shared<Constant>(el_type, shape, value);
|
||||
}
|
||||
|
||||
auto test_forward_unary = []() {
|
||||
TestCases test_cases;
|
||||
test_cases.main_node = unary_factories;
|
||||
test_cases.main_node_ref = unary_factories;
|
||||
test_cases.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_cases.num_main_ops = {1, 10};
|
||||
test_cases.preprocess_before = {{set_transpose_for}, {{0}}};
|
||||
test_cases.test_model = transpose_sinking::common::create_model;
|
||||
test_cases.ref_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_after_ref = {{set_transpose_for}, {{0}}};
|
||||
test_cases.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
};
|
||||
return wrapper(test_cases);
|
||||
};
|
||||
TestCase test_case;
|
||||
|
||||
auto test_forward_binary = []() {
|
||||
TestCases test_cases;
|
||||
test_cases.main_node = binary_factories;
|
||||
test_cases.main_node_ref = binary_factories;
|
||||
test_cases.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
|
||||
test_cases.num_main_ops = {1, 10};
|
||||
test_cases.preprocess_before = {{set_transpose_for}, {{0}}};
|
||||
test_cases.test_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_before_ref = {{set_transpose_for}, {{1}}};
|
||||
test_cases.ref_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_after_ref = {{set_transpose_for}, {{0}}};
|
||||
test_cases.inputs_to_main = {
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
return wrapper(test_cases);
|
||||
};
|
||||
|
||||
auto test_forward_concat = []() {
|
||||
TestCases test_cases;
|
||||
test_cases.main_node = {CREATE_CONCAT_FACTORY(Concat)};
|
||||
test_cases.main_node_ref = {CREATE_CONCAT_REF_FACTORY(Concat)};
|
||||
test_cases.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
|
||||
test_cases.num_main_ops = {1, 10};
|
||||
test_cases.preprocess_before = {{set_transpose_for}, {{0}}};
|
||||
test_cases.test_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_before_ref = {{set_transpose_for}, {{1, 2}}};
|
||||
test_cases.ref_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_after_ref = {{set_transpose_for}, {{0}}};
|
||||
test_cases.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
return wrapper(test_cases);
|
||||
};
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = unary_factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
auto test_forward_split = []() {
|
||||
TestCases test_cases;
|
||||
test_cases.main_node = {CREATE_SPLIT_FACTORY(Concat)};
|
||||
test_cases.main_node_ref = {CREATE_SPLIT_FACTORY(Concat)};
|
||||
test_cases.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
|
||||
test_cases.num_main_ops = {1};
|
||||
test_cases.preprocess_before = {{set_transpose_for}, {{0}}};
|
||||
test_cases.test_model = transpose_sinking::common::create_model;
|
||||
// Reference model description:
|
||||
test_case.model_ref.main_op = unary_factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_cases.preprocess_before_ref = {{new_constant}, {{1}}};
|
||||
test_cases.ref_model = transpose_sinking::common::create_model;
|
||||
test_cases.preprocess_after_ref = {{set_transpose_for}, {{0, 1, 2}}};
|
||||
test_cases.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant(element::i32, {}, {2}),
|
||||
};
|
||||
return wrapper(test_cases);
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_unary());
|
||||
|
||||
auto test_forward_binary = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = binary_factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1}}};
|
||||
test_case.model_ref.main_op = binary_factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_binary());
|
||||
|
||||
|
||||
auto test_forward_concat = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
|
||||
test_case.num_main_ops = {1, 3};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
parameter(element::f32, {55, 55, 96, 1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_CONCAT_FACTORY(Concat)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for}, {{1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_CONCAT_REF_FACTORY(Concat)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_concat());
|
||||
|
||||
auto test_forward_split = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
constant(element::i32, {}, {2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPLIT_FACTORY(Split)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0, 1, 2}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_split());
|
||||
|
||||
auto test_forward_pad = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_PAD_FACTORY(Pad)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_pad());
|
||||
|
||||
auto test_forward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant(element::i32, {4}, {1, 2, 2, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_BATCH_TO_SPACE_FACTORY(BatchToSpace)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBatchToSpaceForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_batch_to_space());
|
||||
|
||||
auto test_forward_space_to_batch = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {64, 9, 8, 1}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2, 3}}};
|
||||
test_case.model_ref.main_op = {CREATE_SPACE_TO_BATCH_FACTORY(SpaceToBatch)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSpaceToBatchForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_space_to_batch());
|
||||
|
||||
auto test_forward_reduction = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
constant(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = reduction_factories;
|
||||
test_case.model.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
// Reference model description:
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
test_case.model_ref.main_op = reduction_factories;
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = transpose_sinking::common::create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reduction());
|
||||
|
@ -101,7 +101,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
|
||||
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::TransposeFQReduction>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
|
||||
manager.register_pass<ov::pass::TransposeReduction>();
|
||||
manager.register_pass<ngraph::pass::CheckUniqueNames>(unh);
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -219,7 +219,7 @@ TEST_P(TransposeSinking, TransposeReduction) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
|
||||
manager.register_pass<ov::pass::TransposeReduction>();
|
||||
manager.register_pass<ngraph::pass::CheckUniqueNames>(unh);
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -340,7 +340,7 @@ TEST_F(TransformationTestsF, TransposeReduceNegative) {
|
||||
auto sub = std::make_shared<opset6::Subtract>(transpose, reduce_mean);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{sub}, ngraph::ParameterVector{input});
|
||||
manager.register_pass<ov::pass::TransposeSinkingReductionForward>();
|
||||
manager.register_pass<ov::pass::TransposeReduction>();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -35,17 +35,30 @@ ParameterVector filter_parameters(const OutputVector &out_vec) {
|
||||
|
||||
OutputVector set_transpose_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
|
||||
OutputVector result = out_vec;
|
||||
for (size_t i = 0; i < out_vec.size(); ++i) {
|
||||
if (find(idxs.begin(), idxs.end(), i) != idxs.end()) {
|
||||
const auto &out = out_vec[i];
|
||||
auto rank = out.get_partial_shape().rank().get_length();
|
||||
vector<int64_t> axes(rank);
|
||||
iota(axes.begin(), axes.end(), 0);
|
||||
reverse(axes.begin(), axes.end());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{axes.size()}, axes);
|
||||
auto transpose = make_shared<Transpose>(out, order);
|
||||
result[i] = transpose;
|
||||
}
|
||||
for (const auto& idx : idxs) {
|
||||
const auto &out = out_vec[idx];
|
||||
auto rank = out.get_partial_shape().rank().get_length();
|
||||
vector<int64_t> axes(rank);
|
||||
iota(axes.begin(), axes.end(), 0);
|
||||
reverse(axes.begin(), axes.end());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{axes.size()}, axes);
|
||||
auto transpose = make_shared<Transpose>(out, order);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
OutputVector set_gather_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
|
||||
OutputVector result = out_vec;
|
||||
for (const auto& idx : idxs) {
|
||||
const auto &out = out_vec[idx];
|
||||
vector<int64_t> axes(out.get_shape()[0]);
|
||||
iota(axes.begin(), axes.end(), 0);
|
||||
reverse(axes.begin(), axes.end());
|
||||
auto order = make_shared<Constant>(element::i32, Shape{axes.size()}, axes);
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{}, 0);
|
||||
auto transpose = make_shared<Gather>(out, order, axis);
|
||||
result[idx] = transpose;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -61,3 +74,11 @@ std::string to_string(const Shape &shape) {
|
||||
result << "}";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape &ps) {
|
||||
return std::make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
shared_ptr<ov::Node> constant(ov::element::Type el_type, const Shape &shape, const vector<int64_t> &value) {
|
||||
return make_shared<Constant>(el_type, shape, value);
|
||||
}
|
||||
|
@ -58,7 +58,13 @@ public:
|
||||
#define CREATE_PASS_FACTORY(pass_name) std::make_shared<PassFactory<ov::pass::pass_name>>(#pass_name)
|
||||
|
||||
ov::OutputVector set_transpose_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
|
||||
ov::OutputVector set_gather_for(const std::vector<size_t>& idxs, const ov::OutputVector& out_vec);
|
||||
|
||||
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
|
||||
|
||||
std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_t num_ops, const FactoryPtr& creator);
|
||||
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
|
||||
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<int64_t>& value);
|
@ -166,7 +166,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model) {
|
||||
pass_config->disable<ov::pass::AddFakeQuantizeFusion>();
|
||||
// TransposeReduction can be enabled when Transpose-Conv-Transpose patterns will be handled in ngraph
|
||||
// transformations
|
||||
pass_config->disable<ov::pass::TransposeSinkingReductionForward>();
|
||||
pass_config->disable<ov::pass::TransposeReduction>();
|
||||
// Operations Max and Min aren't supported
|
||||
pass_config->disable<ov::pass::ConcatReduceFusion>();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user