Remove deprecated legacy transpose ref function (#5765)

This commit is contained in:
Katarzyna Mitrus 2021-05-24 19:05:29 +02:00 committed by GitHub
parent e54a3882ee
commit eba2410411
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 54 deletions

View File

@ -31,50 +31,6 @@ namespace ngraph
runtime::opt_kernel::reshape( runtime::opt_kernel::reshape(
data, out, data_shape, axis_vector, out_shape, element_size); data, out, data_shape, axis_vector, out_shape, element_size);
} }
// Legacy function template to ensure backward compatibility
// Can be removed after ARM plugin start using evaluate or no template function
template <typename T, typename U>
NGRAPH_DEPRECATED(
"Traspose function with template types is deprecated, use function with char* "
"args.")
void transpose(const T* arg, T* out, Shape arg_shape, const U* axes_order = nullptr)
{
std::vector<std::int64_t> converted_axes_order(arg_shape.size());
if (axes_order == nullptr)
{
std::iota(converted_axes_order.begin(), converted_axes_order.end(), 0);
std::reverse(converted_axes_order.begin(), converted_axes_order.end());
}
else
{
for (size_t i = 0; i < converted_axes_order.size(); ++i)
{
converted_axes_order[i] = static_cast<std::int64_t>(axes_order[i]);
}
}
Shape output_shape(arg_shape.size());
std::transform(
converted_axes_order.begin(),
converted_axes_order.end(),
output_shape.begin(),
[&](const int64_t& v) {
NGRAPH_CHECK(v >= 0,
"Negative values for transpose axes order are not supported.");
NGRAPH_CHECK(v < int64_t(arg_shape.size()),
"Transpose axis ",
v,
" is out of shape range.");
return arg_shape[v];
});
transpose(reinterpret_cast<const char*>(arg),
reinterpret_cast<char*>(out),
arg_shape,
sizeof(T),
converted_axes_order.data(),
output_shape);
}
} // namespace reference } // namespace reference
} // namespace runtime } // namespace runtime
} // namespace ngraph } // namespace ngraph

View File

@ -34,7 +34,7 @@ void test_tranpose_eval(shared_ptr<Function> fun)
std::vector<std::vector<T>> expected_results{ std::vector<std::vector<T>> expected_results{
{1, 2, 3, 4, 5, 6}, {1, 4, 2, 5, 3, 6}, {1, 4, 2, 5, 3, 6}, {1, 4, 2, 5, 3, 6}, {1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}}; {1, 2, 3, 4, 5, 6}, {1, 4, 2, 5, 3, 6}, {1, 4, 2, 5, 3, 6}, {1, 4, 2, 5, 3, 6}, {1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}};
std::vector<Shape> expected_result_shapes{{2, 3}, {3, 2}, {3, 2}, {3, 1, 2}, {3, 2, 2}}; std::vector<Shape> expected_result_shapes{{2, 3}, {3, 2}, {3, 2}, {3, 1, 2}, {3, 2, 2}};
for (size_t i = 0; i < data_shapes.size(); i++) for (size_t i = 0; i < data_shapes.size(); i++)
{ {
auto result_tensor = make_shared<HostTensor>(element::dynamic, PartialShape::dynamic()); auto result_tensor = make_shared<HostTensor>(element::dynamic, PartialShape::dynamic());
@ -44,14 +44,6 @@ void test_tranpose_eval(shared_ptr<Function> fun)
auto actual_results = read_vector<T>(result_tensor); auto actual_results = read_vector<T>(result_tensor);
ASSERT_EQ(actual_results, expected_results[i]); ASSERT_EQ(actual_results, expected_results[i]);
{ // Temporary test for legacy reference function template
NGRAPH_SUPPRESS_DEPRECATED_START
std::vector<T> ref_results(input_data[i].size());
runtime::reference::transpose<T, T_AXIS>(input_data[i].data(), ref_results.data(), data_shapes[i], axes_order[i].data());
ASSERT_EQ(ref_results, expected_results[i]);
NGRAPH_SUPPRESS_DEPRECATED_END
}
} }
} }
@ -77,7 +69,7 @@ TEST(op_eval, eval_transpose)
const auto input_floating = make_shared<op::Parameter>(element::f32, PartialShape::dynamic()); const auto input_floating = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
const auto transpose_floating = make_shared<op::v1::Transpose>(input_floating, axis); const auto transpose_floating = make_shared<op::v1::Transpose>(input_floating, axis);
const auto function_floating = make_shared<Function>(OutputVector{transpose_floating}, ParameterVector{input_floating, axis}); const auto function_floating = make_shared<Function>(OutputVector{transpose_floating}, ParameterVector{input_floating, axis});
switch (axis->get_element_type()) switch (axis->get_element_type())
{ {
case element::Type_t::i8: case element::Type_t::i8: