codestyle

This commit is contained in:
Tikhonov Ivan 2023-03-16 16:16:59 +00:00
parent 66d16ae45e
commit 88ddbb2437
7 changed files with 100 additions and 78 deletions

View File

@ -4,9 +4,9 @@
#pragma once
#include "../../transformations_visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "../../transformations_visibility.hpp"
namespace ov {
namespace pass {
@ -15,7 +15,7 @@ namespace transpose_sinking {
class TRANSFORMATIONS_API TSSliceForward;
class TRANSFORMATIONS_API TSSliceBackward;
}
} // namespace transpose_sinking
} // namespace pass
} // namespace ov

View File

@ -15,10 +15,10 @@
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::pass::transpose_sinking;

View File

@ -24,7 +24,7 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
@ -36,14 +36,14 @@ bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
return keep_dims;
}
}
} // namespace
TSReductionForward::TSReductionForward() {
MATCHER_SCOPE(TSReductionForward);
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{transpose_label, wrap_type<Constant>()});
{transpose_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@ -72,7 +72,8 @@ TSReductionForward::TSReductionForward() {
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
}
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
@ -96,7 +97,8 @@ TSReductionBackward::TSReductionBackward() {
MATCHER_SCOPE(TSReductionBackward);
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
{any_input(), wrap_type<Constant>()},
HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
@ -112,13 +114,14 @@ TSReductionBackward::TSReductionBackward() {
auto rank = reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>();
if (!keep_dims) {
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
}
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
std::vector<size_t> new_values;

View File

@ -4,15 +4,14 @@
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::opset10;

View File

@ -28,8 +28,8 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
result_axes.clear();
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
// supported the case if Reshape is equal to Squeeze
const auto &new_shape = reduction_axes_values;
const auto &input_pshape = reshape->get_input_partial_shape(0);
const auto& new_shape = reduction_axes_values;
const auto& input_pshape = reshape->get_input_partial_shape(0);
// todo: support dynamic case
if (input_pshape.is_dynamic()) {
return false;
@ -57,7 +57,7 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
std::vector<size_t> to_shape;
std::sort(squeeze_axes.begin(), squeeze_axes.end());
const auto& input_shape = input_node->input(0).get_shape(); // check is static
const auto& input_shape = input_node->input(0).get_shape(); // check is static
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
++j;
@ -68,7 +68,7 @@ std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_nod
return to_shape;
}
}
} // namespace
TSSqueezeForward::TSSqueezeForward() {
MATCHER_SCOPE(TSSqueezeForward);
@ -96,7 +96,8 @@ TSSqueezeForward::TSSqueezeForward() {
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
}
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
@ -172,7 +173,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
}
bool squeeze_all_dims = false;
@ -210,7 +212,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
if (squeeze_all_dims) {
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
} else {
auto new_const = std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
auto new_const =
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
}

View File

@ -23,13 +23,13 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
result_axes.clear();
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
// supported the case if Reshape is equal to Unsqueeze
const auto &new_shape = reduction_axes_values;
const auto &input_pshape = reshape->get_input_partial_shape(0);
const auto& new_shape = reduction_axes_values;
const auto& input_pshape = reshape->get_input_partial_shape(0);
// todo: support dynamic case
if (input_pshape.is_dynamic()) {
return false;
@ -37,7 +37,7 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const auto input_shape = input_pshape.to_shape();
if (new_shape.size() > input_shape.size()) {
for (size_t i = 0, j = 0; i < input_shape.size();j++) {
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
if (input_shape[i] == new_shape[j]) {
i++;
} else if (input_shape[i] != new_shape[j] && new_shape[j] != 1) {
@ -54,8 +54,9 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node,
std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
@ -98,20 +99,20 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
}
auto ts_order_values = transpose_order->cast_vector<size_t>();
/* std::vector<size_t> new_values;
new_values.reserve(non_negative_axes.size());
for (const auto& axis : non_negative_axes) {
new_values.push_back(ts_order_values[axis]);
}*/
/* std::vector<size_t> new_values;
new_values.reserve(non_negative_axes.size());
for (const auto& axis : non_negative_axes) {
new_values.push_back(ts_order_values[axis]);
}*/
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{ts_order_values.size()},
ts_order_values);
auto new_transpose_order =
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
/*if (as_type_ptr<Reshape>(unsqueeze)) {
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
@ -136,7 +137,8 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
TSUnsqueezeBackward::TSUnsqueezeBackward() {
MATCHER_SCOPE(TSUnsqueezeBackward);
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto unsqueeze_label =
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
@ -158,7 +160,8 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
}
auto transpose_order_values = transpose_order->cast_vector<size_t>();

View File

@ -191,7 +191,11 @@ class SliceFactory : public IFactory {
public:
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
return std::make_shared<Slice>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
parent_nodes[3],
parent_nodes[4]);
}
};
@ -687,8 +691,8 @@ auto test_forward_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -701,7 +705,7 @@ auto test_forward_squeeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -721,8 +725,8 @@ auto test_forward_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -735,7 +739,7 @@ auto test_forward_unsqueeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -762,11 +766,11 @@ auto test_forward_slice = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
@ -806,8 +810,8 @@ auto test_forward_reshape_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 1, 5, 1, 4}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
parameter(element::f32, {6, 1, 5, 1, 4}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
};
// Test model description:
@ -820,7 +824,7 @@ auto test_forward_reshape_squeeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -831,7 +835,9 @@ auto test_forward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_squeeze());
auto test_forward_reshape_unsqueeze = []() {
TestCase test_case;
@ -840,8 +846,8 @@ auto test_forward_reshape_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 5, 4}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
parameter(element::f32, {6, 5, 4}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
};
// Test model description:
@ -863,7 +869,9 @@ auto test_forward_reshape_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_unsqueeze());
// ------------------ BACKWARD --------------------
auto test_backward_unary = []() {
@ -1141,7 +1149,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
TransposeSinkingTestFixture,
test_backward_interpolate());
auto test_backward_squeeze = []() {
TestCase test_case;
@ -1149,8 +1156,8 @@ auto test_backward_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
};
// Test model description:
@ -1182,8 +1189,8 @@ auto test_backward_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -1196,7 +1203,7 @@ auto test_backward_unsqueeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -1206,7 +1213,9 @@ auto test_backward_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_unsqueeze());
auto test_backward_slice = []() {
TestCase test_case;
@ -1215,11 +1224,11 @@ auto test_backward_slice = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
@ -1257,8 +1266,8 @@ auto test_backward_reshape_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {4, 1, 5, 1, 6}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
parameter(element::f32, {4, 1, 5, 1, 6}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
};
// Test model description:
@ -1281,7 +1290,9 @@ auto test_backward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_squeeze());
auto test_backward_reshape_unsqueeze = []() {
TestCase test_case;
@ -1290,8 +1301,8 @@ auto test_backward_reshape_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {4, 5, 6}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
parameter(element::f32, {4, 5, 6}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
};
// Test model description:
@ -1303,8 +1314,9 @@ auto test_backward_reshape_unsqueeze = []() {
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
out_vec[1].get_shape(),
std::vector<int64_t>{6, 1, 5, 1, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -1314,7 +1326,9 @@ auto test_backward_reshape_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_unsqueeze());
} // namespace common
} // namespace testing
} // namespace transpose_sinking