Merge branch 'itikhono/ts/slice' of https://github.com/itikhono/openvino into itikhono/ts/slice
This commit is contained in:
commit
83ab2cc5f6
@ -4,9 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../transformations_visibility.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "../../transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
@ -15,7 +15,7 @@ namespace transpose_sinking {
|
||||
class TRANSFORMATIONS_API TSSliceForward;
|
||||
class TRANSFORMATIONS_API TSSliceBackward;
|
||||
|
||||
}
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
|
@ -15,10 +15,10 @@
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unary.hpp"
|
||||
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
|
@ -24,7 +24,7 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
|
||||
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
|
||||
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
|
||||
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
|
||||
|
||||
@ -36,14 +36,14 @@ bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
|
||||
return keep_dims;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSReductionForward::TSReductionForward() {
|
||||
MATCHER_SCOPE(TSReductionForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
{transpose_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@ -72,7 +72,8 @@ TSReductionForward::TSReductionForward() {
|
||||
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
|
||||
@ -96,7 +97,8 @@ TSReductionBackward::TSReductionBackward() {
|
||||
MATCHER_SCOPE(TSReductionBackward);
|
||||
|
||||
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
|
||||
{any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
{any_input(), wrap_type<Constant>()},
|
||||
HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
@ -112,13 +114,14 @@ TSReductionBackward::TSReductionBackward() {
|
||||
|
||||
auto rank = reduction->get_input_partial_shape(0).rank();
|
||||
auto non_negative_axes =
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
if (!keep_dims) {
|
||||
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
|
||||
}
|
||||
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
std::vector<size_t> new_values;
|
||||
|
@ -4,15 +4,14 @@
|
||||
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
|
@ -28,8 +28,8 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Squeeze
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
@ -58,7 +58,7 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
|
||||
std::vector<size_t> to_shape;
|
||||
std::sort(squeeze_axes.begin(), squeeze_axes.end());
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
|
||||
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
|
||||
++j;
|
||||
@ -69,7 +69,7 @@ std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_nod
|
||||
return to_shape;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSSqueezeForward::TSSqueezeForward() {
|
||||
MATCHER_SCOPE(TSSqueezeForward);
|
||||
@ -97,7 +97,8 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
|
||||
@ -173,7 +174,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = squeeze->get_input_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
bool squeeze_all_dims = false;
|
||||
@ -211,7 +213,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
if (squeeze_all_dims) {
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
|
||||
} else {
|
||||
auto new_const = std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
auto new_const =
|
||||
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
}
|
||||
|
||||
|
@ -191,7 +191,11 @@ class SliceFactory : public IFactory {
|
||||
public:
|
||||
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
|
||||
return std::make_shared<Slice>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
parent_nodes[3],
|
||||
parent_nodes[4]);
|
||||
}
|
||||
};
|
||||
|
||||
@ -687,8 +691,8 @@ auto test_forward_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -701,7 +705,7 @@ auto test_forward_squeeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -721,8 +725,8 @@ auto test_forward_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -735,7 +739,7 @@ auto test_forward_unsqueeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -762,11 +766,11 @@ auto test_forward_slice = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -806,8 +810,8 @@ auto test_forward_reshape_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
parameter(element::f32, {6, 1, 5, 1, 4}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -820,7 +824,7 @@ auto test_forward_reshape_squeeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
|
||||
@ -831,7 +835,9 @@ auto test_forward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_squeeze());
|
||||
|
||||
auto test_forward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -840,8 +846,8 @@ auto test_forward_reshape_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
parameter(element::f32, {6, 5, 4}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -863,7 +869,9 @@ auto test_forward_reshape_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_reshape_unsqueeze());
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
auto test_backward_unary = []() {
|
||||
@ -1141,7 +1149,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_interpolate());
|
||||
|
||||
|
||||
auto test_backward_squeeze = []() {
|
||||
TestCase test_case;
|
||||
|
||||
@ -1149,8 +1156,8 @@ auto test_backward_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
parameter(element::f32, {32, 1, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1182,8 +1189,8 @@ auto test_backward_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
parameter(element::f32, {32, 3, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {0, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1196,7 +1203,7 @@ auto test_backward_unsqueeze = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -1206,7 +1213,9 @@ auto test_backward_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_unsqueeze());
|
||||
|
||||
auto test_backward_slice = []() {
|
||||
TestCase test_case;
|
||||
@ -1215,11 +1224,11 @@ auto test_backward_slice = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
parameter(element::f32, {6, 4, 5, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
|
||||
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
|
||||
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1257,8 +1266,8 @@ auto test_backward_reshape_squeeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
parameter(element::f32, {4, 1, 5, 1, 6}),
|
||||
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1281,7 +1290,9 @@ auto test_backward_reshape_squeeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_squeeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_squeeze());
|
||||
|
||||
auto test_backward_reshape_unsqueeze = []() {
|
||||
TestCase test_case;
|
||||
@ -1290,8 +1301,8 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
parameter(element::f32, {4, 5, 6}),
|
||||
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -1303,8 +1314,9 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
|
||||
out_vec[1].get_shape(),
|
||||
std::vector<int64_t>{6, 1, 5, 1, 4});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -1314,7 +1326,9 @@ auto test_backward_reshape_unsqueeze = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_unsqueeze());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reshape_unsqueeze());
|
||||
} // namespace common
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
Loading…
Reference in New Issue
Block a user