codestyle
This commit is contained in:
parent
66d16ae45e
commit
88ddbb2437
@ -4,9 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../transformations_visibility.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "../../transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
@ -15,7 +15,7 @@ namespace transpose_sinking {
|
||||
class TRANSFORMATIONS_API TSSliceForward;
|
||||
class TRANSFORMATIONS_API TSSliceBackward;
|
||||
|
||||
}
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
|
@ -15,10 +15,10 @@
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
|
@ -24,7 +24,7 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
|
||||
@ -36,14 +36,14 @@ bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
|
||||
return keep_dims;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@ -72,7 +72,8 @@ TSReductionForward::TSReductionForward() {
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
@ -96,7 +97,8 @@ TSReductionBackward::TSReductionBackward() {
|
||||
MATCHER_SCOPE(TSReductionBackward);
|
||||
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
{any_input(), wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
@ -112,13 +114,14 @@ TSReductionBackward::TSReductionBackward() {
|
||||
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
if (!keep_dims) {
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
|
@ -4,15 +4,14 @@
|
||||
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
@ -28,8 +28,8 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Squeeze
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
@ -57,7 +57,7 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
|
||||
std::vector<size_t> to_shape;
|
||||
std::sort(squeeze_axes.begin(), squeeze_axes.end());
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
|
||||
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
|
||||
++j;
|
||||
@ -68,7 +68,7 @@ std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_nod
|
||||
return to_shape;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
@ -96,7 +96,8 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
|
||||
@ -172,7 +173,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
bool squeeze_all_dims = false;
|
||||
@ -210,7 +212,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
if (squeeze_all_dims) {
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
|
||||
} else {
|
||||
auto new_const = std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
|
||||
|
@ -23,13 +23,13 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
namespace {
|
||||
|
||||
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Unsqueeze
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
@ -37,7 +37,7 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() > input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size();j++) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
|
||||
if (input_shape[i] == new_shape[j]) {
|
||||
i++;
|
||||
} else if (input_shape[i] != new_shape[j] && new_shape[j] != 1) {
|
||||
@ -54,8 +54,9 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node,
|
||||
std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
|
||||
@ -98,20 +99,20 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
||||
/* std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(ts_order_values[axis]);
|
||||
}*/
|
||||
/* std::vector<size_t> new_values;
|
||||
new_values.reserve(non_negative_axes.size());
|
||||
for (const auto& axis : non_negative_axes) {
|
||||
new_values.push_back(ts_order_values[axis]);
|
||||
}*/
|
||||
|
||||
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{ts_order_values.size()},
|
||||
ts_order_values);
|
||||
auto new_transpose_order =
|
||||
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
|
||||
|
||||
/*if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
|
||||
@ -136,7 +137,8 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeBackward);
|
||||
|
||||
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto unsqueeze_label =
|
||||
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
@ -158,7 +160,8 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
@ -191,7 +191,11 @@ class SliceFactory : public IFactory {
|
||||
public:
|
||||
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
|
||||
return std::make_shared<Slice>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
parent_nodes[3],
|
||||
parent_nodes[4]);
|
||||
}
|
||||
};
|
||||
|
||||
@ -687,8 +691,8 @@ auto test_forward_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -701,7 +705,7 @@ auto test_forward_squeeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -721,8 +725,8 @@ auto test_forward_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -735,7 +739,7 @@ auto test_forward_unsqueeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -762,11 +766,11 @@ auto test_forward_slice = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -806,8 +810,8 @@ auto test_forward_reshape_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -820,7 +824,7 @@ auto test_forward_reshape_squeeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -831,7 +835,9 @@ auto test_forward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_squeeze());
|
||||
|
||||
auto test_forward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -840,8 +846,8 @@ auto test_forward_reshape_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -863,7 +869,9 @@ auto test_forward_reshape_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_unsqueeze());
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -1141,7 +1149,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_interpolate());
|
||||
|
||||
|
||||
auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
@ -1149,8 +1156,8 @@ auto test_backward_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1182,8 +1189,8 @@ auto test_backward_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1196,7 +1203,7 @@ auto test_backward_unsqueeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -1206,7 +1213,9 @@ auto test_backward_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_unsqueeze());
|
||||
|
||||
auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
@ -1215,11 +1224,11 @@ auto test_backward_slice = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1257,8 +1266,8 @@ auto test_backward_reshape_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1281,7 +1290,9 @@ auto test_backward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_squeeze());
|
||||
|
||||
auto test_backward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -1290,8 +1301,8 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1303,8 +1314,9 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
|
||||
out_vec[1].get_shape(),
|
||||
std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -1314,7 +1326,9 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_unsqueeze());
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
Loading…
Reference in New Issue
Block a user