Opset 1 transpose shape inference review (#12937)
* Test to interval shape propagated by transpose * Test to propagate labels by transpose * Add template transpose shape inference * Fixes to transpose shape inference * Update names for shapes: input -> input_shape order -> order_shape * Not fill output shape for dynamic range * Add constexpr to SeqGen and Between comparator * Correct StaticShape creation in test * Tests check partial value propagate in arg input * Add evaluate upper, lower, label to transpose - add test * Add common methods for inference and evaluate * Move helpers to shape_inference * Move transpose attribute to transpose op * Fix include in transpose operator * Correct label generation and type * Fix null conversion * Use uint64_t for labels tensor * Fix compare labels * Use order length as output rank * Update transpose transformation test * Move helpers to validation_util * Correct test assertion for expected shape * Transpose evaluate use common function for output calculation * Remove redundant helpers from transpose test
This commit is contained in:
parent
457f606812
commit
b4ad7033c9
@ -133,4 +133,24 @@ OPENVINO_API std::shared_ptr<op::v0::Constant> get_constant_from_source(const Ou
|
||||
/// \param output_labels Vector of TensorLabel objects representing resulting value labels
|
||||
/// \return boolean status if label evaluation was successful.
|
||||
OPENVINO_API bool default_label_evaluator(const Node* node, TensorLabelVector& output_labels);
|
||||
|
||||
/// \brief Generates transpose default axes order at end of input vector.
|
||||
///
|
||||
/// Default axes order is decreasing sequence numbers which start from `length - 1`.
|
||||
///
|
||||
/// \param axes_order Vector where default order will be generated.
|
||||
/// \param length Sequence length of axes order.
|
||||
///
|
||||
OPENVINO_API void generate_transpose_default_order(std::vector<int64_t>& axes_order, const size_t length);
|
||||
|
||||
/// \brief Check if vector of axes order has got valid values.
|
||||
///
|
||||
/// Axes order has to be unique numbers in range of [0, size).
|
||||
///
|
||||
/// \param axes_order Vector with axes order to check.
|
||||
/// \param size Input for transpose rank size.
|
||||
///
|
||||
/// \return true if axes order is valid otherwise false.
|
||||
///
|
||||
OPENVINO_API bool is_valid_axes_order(const std::vector<int64_t>& axes_order, const size_t size);
|
||||
} // namespace ov
|
||||
|
@ -25,7 +25,9 @@
|
||||
the same name and not all of them are overrided in Derived class, the only overrided methods \
|
||||
will be available from Derived class. We need to explicitly cast Derived to Base class to \
|
||||
have an access to remaining methods or use this using. */ \
|
||||
using ov::op::Op::evaluate;
|
||||
using ov::op::Op::evaluate; \
|
||||
using ov::op::Op::evaluate_lower; \
|
||||
using ov::op::Op::evaluate_upper;
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
@ -35,8 +35,17 @@ public:
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& output_values) const override;
|
||||
bool evaluate_lower(const HostTensorVector& output_values) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
bool has_evaluate() const override;
|
||||
bool evaluate_label(TensorLabelVector& output_labels) const override;
|
||||
|
||||
/// \brief Inputs indexes and count.
|
||||
enum Ins : size_t { ARG, ORDER, IN_COUNT };
|
||||
/// \brief Outputs indexes and count.
|
||||
enum Outs : size_t { ARG_T, OUT_COUNT };
|
||||
};
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
52
src/core/shape_inference/include/compare.hpp
Normal file
52
src/core/shape_inference/include/compare.hpp
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ov {
|
||||
namespace cmp {
|
||||
/** \brief Enumerate bounds to compare */
|
||||
enum Bound : uint8_t { NONE, LOWER, UPPER, BOTH };
|
||||
|
||||
/**
|
||||
* \brief Compare if value is between lower and upper bounds.
|
||||
*
|
||||
* The Between comparator has four modes to check value:
|
||||
* - Bound::None (lower, upper)
|
||||
* - Bound::LOWER [lower, upper)
|
||||
* - Bound::UPPER (lower, upper]
|
||||
* - Bound::BOTH [lower, upper]
|
||||
*
|
||||
* \tparam T Value type to compare.
|
||||
* \tparam BMode Compare bounds mode.
|
||||
*/
|
||||
template <class T, Bound BMode = Bound::NONE>
|
||||
class Between {
|
||||
T _lower_bound, _upper_bound;
|
||||
|
||||
public:
|
||||
constexpr Between(const T& lower, const T& upper) : _lower_bound{lower}, _upper_bound{upper} {}
|
||||
|
||||
template <Bound B = BMode, typename std::enable_if<B == Bound::NONE>::type* = nullptr>
|
||||
constexpr bool operator()(const T& value) const {
|
||||
return (_lower_bound < value) && (value < _upper_bound);
|
||||
}
|
||||
|
||||
template <Bound B = BMode, typename std::enable_if<B == Bound::LOWER>::type* = nullptr>
|
||||
constexpr bool operator()(const T& value) const {
|
||||
return (_lower_bound <= value) && (value < _upper_bound);
|
||||
}
|
||||
|
||||
template <Bound B = BMode, typename std::enable_if<B == Bound::UPPER>::type* = nullptr>
|
||||
constexpr bool operator()(const T& value) const {
|
||||
return (_lower_bound < value) && (value <= _upper_bound);
|
||||
}
|
||||
|
||||
template <Bound B = BMode, typename std::enable_if<B == Bound::BOTH>::type* = nullptr>
|
||||
constexpr bool operator()(const T& value) const {
|
||||
return (_lower_bound <= value) && (value <= _upper_bound);
|
||||
}
|
||||
};
|
||||
} // namespace cmp
|
||||
} // namespace ov
|
35
src/core/shape_inference/include/sequnce_generator.hpp
Normal file
35
src/core/shape_inference/include/sequnce_generator.hpp
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ov {
|
||||
/** \brief Enumerate directions */
|
||||
enum Direction : uint8_t { FORWARD, BACKWARD };
|
||||
|
||||
/**
|
||||
* \brief Infinite generator of sequence increasing values.
|
||||
*
|
||||
* Start value can be specified.
|
||||
*
|
||||
* \tparam T Type of sequence values (must support `++` or '--' operators).
|
||||
*/
|
||||
template <class T, Direction D = Direction::FORWARD>
|
||||
class SeqGen {
|
||||
T _counter;
|
||||
|
||||
public:
|
||||
constexpr SeqGen(const T& start) : _counter{start} {}
|
||||
|
||||
template <Direction Di = D, typename std::enable_if<Di == Direction::FORWARD>::type* = nullptr>
|
||||
T operator()() {
|
||||
return _counter++;
|
||||
}
|
||||
|
||||
template <Direction Di = D, typename std::enable_if<Di == Direction::BACKWARD>::type* = nullptr>
|
||||
T operator()() {
|
||||
return _counter--;
|
||||
}
|
||||
};
|
||||
} // namespace ov
|
@ -0,0 +1,78 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v1 {
|
||||
|
||||
/**
|
||||
* \brief Calculate transpose output shape.
|
||||
*
|
||||
* \tparam T Type of shape
|
||||
*
|
||||
* \param op Transpose operator pointer.
|
||||
* \param input_shape Transpose input shape.
|
||||
* \param axes_order Transpose axes order (modified if empty).
|
||||
*
|
||||
* \return Output shape
|
||||
*/
|
||||
template <class T>
|
||||
T calc_output_shape(const Transpose* const op, const T& input_shape, std::vector<int64_t>& axes_order) {
|
||||
const auto output_rank = input_shape.size();
|
||||
|
||||
if (axes_order.empty()) {
|
||||
generate_transpose_default_order(axes_order, output_rank);
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
is_valid_axes_order(axes_order, output_rank),
|
||||
"Permutation ",
|
||||
AxisVector(axes_order.begin(), axes_order.end()),
|
||||
" is not valid for input shape ",
|
||||
input_shape);
|
||||
}
|
||||
|
||||
T output_shape;
|
||||
for (auto&& axis : axes_order) {
|
||||
output_shape.push_back(input_shape[axis]);
|
||||
}
|
||||
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Do transpose inference on input and output shapes.
|
||||
*
|
||||
* \tparam T Type of inference shapes.
|
||||
*
|
||||
* \param op Transpose operator pointer.
|
||||
* \param input_shapes Input shapes of transpose.
|
||||
* \param output_shapes Output shapes of transpose which be modified by inference.
|
||||
* \param constant_data Map of constant data.
|
||||
*/
|
||||
template <class T>
|
||||
void shape_infer(const Transpose* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
const auto& input_shape = input_shapes[Transpose::ARG];
|
||||
auto& output_shape = output_shapes[Transpose::ARG_T];
|
||||
|
||||
std::vector<int64_t> axes;
|
||||
const auto has_order = get_data_as_int64<T>(Transpose::ORDER, op, axes, constant_data);
|
||||
|
||||
if (has_order && input_shape.rank().is_static()) {
|
||||
output_shape = calc_output_shape(op, input_shape, axes);
|
||||
} else if (has_order) {
|
||||
output_shape = ov::PartialShape::dynamic(axes.size());
|
||||
} else {
|
||||
output_shape = ov::PartialShape::dynamic(input_shape.rank());
|
||||
}
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,10 +4,10 @@
|
||||
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/runtime/reference/transpose.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "transpose_shape_inference.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -18,102 +18,81 @@ op::v1::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ngraph::op::v1::Transpose::visit_attributes(AttributeVisitor& visitor) {
|
||||
bool op::v1::Transpose::visit_attributes(AttributeVisitor& visitor) {
|
||||
OV_OP_SCOPE(v1_Transpose_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v1::Transpose::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v1_Transpose_validate_and_infer_types);
|
||||
const auto& input_order_et = get_input_element_type(1);
|
||||
const auto& input_order_et = get_input_element_type(ORDER);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_order_et.is_dynamic() || input_order_et.is_integral_number(),
|
||||
"Input order must have an integral number element type.");
|
||||
|
||||
const auto& input_order_shape = get_input_partial_shape(1);
|
||||
const auto& input_order_shape = get_input_partial_shape(ORDER);
|
||||
NODE_VALIDATION_CHECK(this, input_order_shape.rank().compatible(1), "Input order must be a vector.");
|
||||
|
||||
const auto& arg_shape = get_input_partial_shape(0);
|
||||
const auto& arg_shape = get_input_partial_shape(ARG);
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_order_shape.compatible(ov::PartialShape{arg_shape.rank()}) ||
|
||||
(input_order_shape.is_static() && input_order_shape.rank() == 1 && input_order_shape[0] == 0),
|
||||
"Input order must have shape [n], where n is the rank of arg.");
|
||||
|
||||
set_input_is_relevant_to_shape(1);
|
||||
set_input_is_relevant_to_shape(ORDER);
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
if (const auto& input_const = get_constant_from_source(input_value(1))) {
|
||||
auto permutation = input_const->get_axis_vector_val();
|
||||
if (permutation.empty()) {
|
||||
for (int64_t i = 1; i <= arg_shape.rank().get_length(); ++i)
|
||||
permutation.emplace_back(arg_shape.rank().get_length() - i);
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
is_valid_permutation(permutation, arg_shape.rank()),
|
||||
"Permutation ",
|
||||
permutation,
|
||||
" is not valid for input shape ",
|
||||
arg_shape);
|
||||
set_output_type(0, get_input_element_type(0), ngraph::apply_permutation(arg_shape, permutation));
|
||||
} else {
|
||||
Rank output_rank = arg_shape.rank();
|
||||
if (output_rank.is_dynamic() && input_order_shape.is_static() && input_order_shape[0].get_length())
|
||||
output_rank = input_order_shape[0];
|
||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic(output_rank));
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
std::vector<ov::PartialShape> input_shapes{arg_shape, input_order_shape};
|
||||
std::vector<ov::PartialShape> output_shapes(OUT_COUNT, ov::PartialShape{});
|
||||
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
set_output_size(output_shapes.size());
|
||||
set_output_type(ARG, get_input_element_type(ARG), output_shapes[ARG_T]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Transpose::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(v1_Transpose_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v1::Transpose>(new_args[0], new_args[1]);
|
||||
return make_shared<v1::Transpose>(new_args[ARG], new_args[ORDER]);
|
||||
}
|
||||
|
||||
namespace transpose {
|
||||
namespace {
|
||||
bool evaluate_transpose(const HostTensorPtr& arg1, const HostTensorPtr& arg2, const HostTensorPtr& out) {
|
||||
NGRAPH_CHECK(arg2->get_element_type().is_integral_number(),
|
||||
"Transpose axis element type has to be integral data type.");
|
||||
|
||||
std::vector<int64_t> axes_order = host_tensor_2_vector<int64_t>(arg2);
|
||||
ov::Shape in_shape = arg1->get_shape();
|
||||
if (shape_size(arg2->get_shape()) == 0) {
|
||||
axes_order.resize(in_shape.size());
|
||||
std::iota(axes_order.begin(), axes_order.end(), 0);
|
||||
std::reverse(axes_order.begin(), axes_order.end());
|
||||
} else {
|
||||
std::unordered_set<int64_t> axes_set(axes_order.begin(), axes_order.end());
|
||||
bool is_unique_order = axes_set.size() == axes_order.size();
|
||||
NGRAPH_CHECK(is_unique_order, "Transpose axes order values must be unique.");
|
||||
}
|
||||
|
||||
ov::Shape out_shape(in_shape.size());
|
||||
std::transform(axes_order.begin(), axes_order.end(), out_shape.begin(), [&](const int64_t& v) {
|
||||
NGRAPH_CHECK(v >= 0, "Negative values for transpose axes order are not supported.");
|
||||
NGRAPH_CHECK(v < int64_t(in_shape.size()), "Transpose axis ", v, " is out of shape range.");
|
||||
return in_shape[v];
|
||||
});
|
||||
|
||||
out->set_shape(out_shape);
|
||||
out->set_element_type(arg1->get_element_type());
|
||||
runtime::reference::transpose(arg1->get_data_ptr<char>(),
|
||||
out->get_data_ptr<char>(),
|
||||
arg1->get_shape(),
|
||||
arg1->get_element_type().size(),
|
||||
axes_order.data(),
|
||||
out_shape);
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
} // namespace transpose
|
||||
bool op::v1::Transpose::evaluate(const HostTensorVector& output_values, const HostTensorVector& input_values) const {
|
||||
OV_OP_SCOPE(v1_Transpose_evaluate);
|
||||
return transpose::evaluate_transpose(input_values[0], input_values[1], output_values[0]);
|
||||
|
||||
const auto& order = input_values[ORDER];
|
||||
OPENVINO_ASSERT(order->get_element_type().is_integral_number(),
|
||||
"Transpose axis element type has to be integral data type.");
|
||||
|
||||
const auto& arg = input_values[ARG];
|
||||
std::vector<int64_t> axes_order = host_tensor_2_vector<int64_t>(order);
|
||||
auto out_shape = calc_output_shape(this, arg->get_shape(), axes_order);
|
||||
|
||||
auto& out = output_values[ARG_T];
|
||||
out->set_shape(out_shape);
|
||||
out->set_element_type(arg->get_element_type());
|
||||
ngraph::runtime::reference::transpose(arg->get_data_ptr<char>(),
|
||||
out->get_data_ptr<char>(),
|
||||
arg->get_shape(),
|
||||
arg->get_element_type().size(),
|
||||
axes_order.data(),
|
||||
out_shape);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool op::v1::Transpose::has_evaluate() const {
|
||||
OV_OP_SCOPE(v1_Transpose_has_evaluate);
|
||||
return get_input_element_type(1).is_integral_number();
|
||||
}
|
||||
|
||||
bool op::v1::Transpose::evaluate_lower(const HostTensorVector& output_values) const {
|
||||
return get_input_tensor(ORDER).has_and_set_bound() && default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Transpose::evaluate_upper(const HostTensorVector& output_values) const {
|
||||
return get_input_tensor(ORDER).has_and_set_bound() && default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Transpose::evaluate_label(TensorLabelVector& output_labels) const {
|
||||
return get_input_tensor(ORDER).has_and_set_bound() && default_label_evaluator(this, output_labels);
|
||||
}
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <numeric>
|
||||
|
||||
#include "compare.hpp"
|
||||
#include "ngraph/evaluator.hpp"
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
@ -23,6 +24,7 @@
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/type/element_type_traits.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "sequnce_generator.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
using namespace std;
|
||||
@ -1642,3 +1644,13 @@ bool ngraph::validate_host_tensor_vector(const HostTensorVector& tensor_vector,
|
||||
return t != nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
void ov::generate_transpose_default_order(std::vector<int64_t>& axes_order, const size_t length) {
|
||||
axes_order.reserve(length);
|
||||
std::generate_n(std::back_inserter(axes_order), length, ov::SeqGen<size_t, ov::Direction::BACKWARD>(length - 1));
|
||||
}
|
||||
|
||||
bool ov::is_valid_axes_order(const std::vector<int64_t>& axes_order, const size_t size) {
|
||||
return (std::unordered_set<size_t>(axes_order.cbegin(), axes_order.cend()).size() == size) &&
|
||||
std::all_of(axes_order.cbegin(), axes_order.cend(), ov::cmp::Between<int64_t, ov::cmp::LOWER>(0, size));
|
||||
}
|
||||
|
@ -8,11 +8,13 @@
|
||||
#include <vector>
|
||||
|
||||
#include "engines_util/execute_tools.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/transpose.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "sequnce_generator.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
@ -154,7 +156,7 @@ TEST(op_eval, eval_duplicated_axes_transpose) {
|
||||
|
||||
FAIL() << "Duplicated axes values not detected";
|
||||
} catch (const ngraph_error& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("must be unique"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Permutation AxisVector{2, 1, 2} is not valid for input shape"));
|
||||
} catch (...) {
|
||||
FAIL() << "Failed for unexpected reason";
|
||||
}
|
||||
@ -179,7 +181,7 @@ TEST(op_eval, eval_out_of_shape_axes_transpose) {
|
||||
|
||||
FAIL() << "Out of shape axes not detected";
|
||||
} catch (const ngraph_error& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("out of shape"));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Permutation AxisVector{0, 1, 3} is not valid for input shape"));
|
||||
} catch (...) {
|
||||
FAIL() << "Failed for unexpected reason";
|
||||
}
|
||||
@ -208,8 +210,173 @@ TEST(op_eval, eval_negative_axes_transpose) {
|
||||
ASSERT_EQ(actual_results, expected_result);
|
||||
FAIL() << "Negative axes for Transpose were not supported before.";
|
||||
} catch (const ngraph_error& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("not supported"));
|
||||
std::stringstream exp_msg;
|
||||
exp_msg << "Permutation " << AxisVector(perm.begin(), perm.end()) << " is not valid for input shape";
|
||||
EXPECT_HAS_SUBSTRING(error.what(), exp_msg.str());
|
||||
} catch (...) {
|
||||
FAIL() << "Failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> tensor_to_vector(const ov::Tensor& tensor) {
|
||||
std::vector<T> rc(tensor.data<T>(), tensor.data<T>() + tensor.get_size());
|
||||
return rc;
|
||||
}
|
||||
|
||||
using namespace ov::opset9;
|
||||
using namespace testing;
|
||||
|
||||
using test_param = std::tuple<std::vector<int32_t>, PartialShape>;
|
||||
|
||||
class TransposeEvalTest : public TestWithParam<test_param> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::tie(axes_order, p_shape) = GetParam();
|
||||
|
||||
std::generate_n(std::back_inserter(lower_values),
|
||||
ov::shape_size(p_shape.get_min_shape()),
|
||||
ov::SeqGen<int32_t>(-10));
|
||||
std::generate_n(std::back_inserter(upper_values),
|
||||
ov::shape_size(p_shape.get_min_shape()),
|
||||
ov::SeqGen<int32_t>(20));
|
||||
|
||||
lower_v_tensor = std::make_shared<ov::HostTensor>(dtype, p_shape.get_min_shape(), lower_values.data());
|
||||
upper_v_tensor = std::make_shared<ov::HostTensor>(dtype, p_shape.get_min_shape(), upper_values.data());
|
||||
axes_v_tensor = std::make_shared<ov::HostTensor>(dtype, Shape{axes_order.size()}, axes_order.data());
|
||||
|
||||
arg = make_shared<Parameter>(dtype, p_shape);
|
||||
order = make_shared<Parameter>(dtype, Shape{axes_order.size()});
|
||||
transpose = make_shared<Transpose>(arg, order);
|
||||
|
||||
// prepare result tensors for evaluation
|
||||
result = exp_result = ov::TensorVector{ov::Tensor(dtype, {0})};
|
||||
}
|
||||
|
||||
void node_set_lower_and_upper(ov::Node* node, const HostTensorPtr& lower, const HostTensorPtr& upper) {
|
||||
if (lower != nullptr) {
|
||||
node->get_output_tensor(0).set_lower_value(lower);
|
||||
}
|
||||
|
||||
if (upper != nullptr) {
|
||||
node->get_output_tensor(0).set_upper_value(upper);
|
||||
}
|
||||
}
|
||||
|
||||
PartialShape p_shape;
|
||||
ov::element::Type dtype{ov::element::from<int32_t>()};
|
||||
ov::element::Type label_dtype{ov::element::u64};
|
||||
|
||||
std::vector<int32_t> axes_order, lower_values, upper_values;
|
||||
HostTensorPtr lower_v_tensor, upper_v_tensor, axes_v_tensor;
|
||||
ov::TensorVector result, exp_result;
|
||||
std::shared_ptr<Transpose> transpose;
|
||||
std::shared_ptr<Parameter> arg, order;
|
||||
|
||||
TensorLabel labels;
|
||||
TensorLabelVector out_labels = TensorLabelVector(Transpose::OUT_COUNT);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(op_eval,
|
||||
TransposeEvalTest,
|
||||
Values(make_tuple(std::vector<int32_t>{0}, PartialShape{4}),
|
||||
make_tuple(std::vector<int32_t>{0, 1}, PartialShape{2, 5}),
|
||||
make_tuple(std::vector<int32_t>{1, 0}, PartialShape{2, 5}),
|
||||
make_tuple(std::vector<int32_t>{0, 1, 2}, PartialShape{2, 3, 1}),
|
||||
make_tuple(std::vector<int32_t>{1, 2, 0}, PartialShape{2, 3, 1}),
|
||||
make_tuple(std::vector<int32_t>{1, 3, 2, 0}, PartialShape{2, 3, 1, 5})),
|
||||
PrintToStringParamName());
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_lower) {
|
||||
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
const auto inputs = ov::TensorVector{ov::Tensor(dtype, p_shape.get_min_shape(), lower_values.data()),
|
||||
ov::Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
|
||||
// evaluate expected values
|
||||
const auto exp_evaluate = transpose->evaluate(exp_result, inputs);
|
||||
|
||||
ASSERT_EQ(transpose->evaluate_lower(result), exp_evaluate);
|
||||
ASSERT_EQ(tensor_to_vector<int32_t>(result[Transpose::ARG_T]),
|
||||
tensor_to_vector<int32_t>(exp_result[Transpose::ARG_T]));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_lower_but_arg_lower_values_not_set) {
|
||||
node_set_lower_and_upper(arg.get(), nullptr, upper_v_tensor);
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_lower(result));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_lower_but_order_has_no_bounds_set) {
|
||||
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_lower(result));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_upper) {
|
||||
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
auto inputs = ov::TensorVector{ov::Tensor(dtype, p_shape.get_min_shape(), upper_values.data()),
|
||||
ov::Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
|
||||
// evaluate expected values
|
||||
transpose->evaluate(exp_result, inputs);
|
||||
|
||||
ASSERT_TRUE(transpose->evaluate_upper(result));
|
||||
ASSERT_EQ(tensor_to_vector<int32_t>(result[Transpose::ARG_T]),
|
||||
tensor_to_vector<int32_t>(exp_result[Transpose::ARG_T]));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_upper_but_arg_upper_values_not_set) {
|
||||
node_set_lower_and_upper(arg.get(), upper_v_tensor, nullptr);
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_upper(result));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_upper_but_order_has_no_bounds_set) {
|
||||
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_upper(result));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_label_but_empty_label_set) {
|
||||
exp_result = ov::TensorVector{ov::Tensor(label_dtype, {0})};
|
||||
|
||||
labels.resize(ov::shape_size(p_shape.get_shape()), 0);
|
||||
arg->get_default_output().get_tensor().set_value_label(labels);
|
||||
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_label(out_labels));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_label_but_order_has_no_bound_set) {
|
||||
exp_result = ov::TensorVector{ov::Tensor(label_dtype, {0})};
|
||||
|
||||
std::generate_n(std::back_inserter(labels), ov::shape_size(p_shape.get_shape()), ov::SeqGen<size_t>(30));
|
||||
arg->get_default_output().get_tensor().set_value_label(labels);
|
||||
|
||||
ASSERT_FALSE(transpose->evaluate_label(out_labels));
|
||||
}
|
||||
|
||||
TEST_P(TransposeEvalTest, evaluate_label) {
|
||||
exp_result = ov::TensorVector{ov::Tensor(label_dtype, {0})};
|
||||
|
||||
std::generate_n(std::back_inserter(labels), ov::shape_size(p_shape.get_shape()), ov::SeqGen<size_t>(5));
|
||||
arg->get_default_output().get_tensor().set_value_label(labels);
|
||||
|
||||
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
|
||||
|
||||
auto labels_u64 = std::vector<uint64_t>(labels.cbegin(), labels.cend());
|
||||
auto inputs = ov::TensorVector{ov::Tensor(label_dtype, p_shape.get_shape(), labels_u64.data()),
|
||||
ov::Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
|
||||
|
||||
auto exp_eval_result = transpose->evaluate(exp_result, inputs);
|
||||
|
||||
ASSERT_EQ(transpose->evaluate_label(out_labels), exp_eval_result);
|
||||
ASSERT_THAT(
|
||||
out_labels[Transpose::ARG_T],
|
||||
ElementsAreArray(exp_result[Transpose::ARG_T].data<uint64_t>(), exp_result[Transpose::ARG_T].get_size()));
|
||||
}
|
||||
|
@ -2,12 +2,15 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "sequnce_generator.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
using namespace ov::op;
|
||||
|
||||
TEST(type_prop, transpose_arg_static_input_order_static_ok) {
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
|
||||
@ -16,7 +19,7 @@ TEST(type_prop, transpose_arg_static_input_order_static_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_static_input_order_constant_ok) {
|
||||
@ -26,7 +29,7 @@ TEST(type_prop, transpose_arg_static_input_order_constant_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{6, 4, 2, 8}));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), (PartialShape{6, 4, 2, 8}));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_static_input_order_constant_invalid_perm) {
|
||||
@ -44,6 +47,21 @@ TEST(type_prop, transpose_arg_static_input_order_constant_invalid_perm) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_with_not_unique_order) {
|
||||
const auto order = std::vector<size_t>{1, 0, 1};
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 4, 300});
|
||||
auto input_order = make_shared<op::Constant>(element::i64, Shape{order.size()}, order);
|
||||
|
||||
try {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
FAIL() << "Did not detect invalid permutation";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Permutation AxisVector{1, 0, 1} is not valid for input shape"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_static_ok) {
|
||||
auto arg = make_shared<op::Parameter>(element::f32, PartialShape{2, Dimension::dynamic(), Dimension::dynamic(), 8});
|
||||
auto input_order = make_shared<op::Parameter>(element::i64, Shape{4});
|
||||
@ -51,7 +69,7 @@ TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_static_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_static_input_order_rank_static_dynamic_ok) {
|
||||
@ -61,7 +79,7 @@ TEST(type_prop, transpose_arg_static_input_order_rank_static_dynamic_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_rank_static_dynamic_ok) {
|
||||
@ -71,7 +89,7 @@ TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_rank_static_dynami
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_dynamic_input_order_rank_static_dynamic_ok) {
|
||||
@ -81,7 +99,7 @@ TEST(type_prop, transpose_arg_rank_dynamic_input_order_rank_static_dynamic_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_dynamic_input_order_rank_dynamic_ok) {
|
||||
@ -91,7 +109,7 @@ TEST(type_prop, transpose_arg_rank_dynamic_input_order_rank_dynamic_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic());
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_rank_dynamic_ok) {
|
||||
@ -101,7 +119,18 @@ TEST(type_prop, transpose_arg_rank_static_dynamic_input_order_rank_dynamic_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_rank_dynamic_input_order_const_ok) {
|
||||
const auto axes_order = std::vector<int64_t>{1, 3, 0, 2};
|
||||
auto arg = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto input_order = op::Constant::create(element::i64, Shape{axes_order.size()}, axes_order);
|
||||
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(axes_order.size()));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_dynamic_interval_input_data) {
|
||||
@ -111,7 +140,7 @@ TEST(type_prop, transpose_dynamic_interval_input_data) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(3)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(3));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_arg_static_input_order_static_input_order_not_vector) {
|
||||
@ -205,7 +234,7 @@ TEST(type_prop, transpose_input_order_et_dynamic_ok) {
|
||||
auto r = make_shared<op::Transpose>(arg, input_order);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(4));
|
||||
}
|
||||
|
||||
TEST(type_prop, transpose_input_order_et_wrong) {
|
||||
@ -230,4 +259,162 @@ TEST(type_prop, transpose_with_empty_order) {
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::f32);
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape({300, 1})));
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), (PartialShape{300, 1}));
|
||||
}
|
||||
|
||||
/** \brief Transpose with order as parameter shape dimensions. */
|
||||
TEST(type_prop, transpose_order_as_parameter_shape) {
|
||||
const auto arg = make_shared<v0::Parameter>(element::f32, PartialShape{Dimension(2, 8), Dimension(4, 16), 6});
|
||||
|
||||
const auto param = make_shared<v0::Parameter>(element::i64, PartialShape{2, 0, 1});
|
||||
const auto shape_of = make_shared<v3::ShapeOf>(param);
|
||||
// order after gather [1, 2, 0]
|
||||
const auto gather = make_shared<v1::Gather>(shape_of,
|
||||
op::Constant::create(element::i64, {3}, {2, 0, 1}),
|
||||
op::Constant::create(element::i64, {}, {0}));
|
||||
|
||||
const auto r = make_shared<v1::Transpose>(arg, gather);
|
||||
|
||||
ASSERT_EQ(r->get_output_element_type(v1::Transpose::ARG_T), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(v1::Transpose::ARG_T), PartialShape({Dimension(4, 16), 6, Dimension(2, 8)}));
|
||||
}
|
||||
|
||||
/** \brief Transpose with order as paramater shape dimensions after multiple transformations. */
|
||||
TEST(type_prop, transpose_order_as_parameter_shape_after_transformation) {
|
||||
const auto arg = make_shared<v0::Parameter>(element::f32, PartialShape{Dimension(2, 8), Dimension(4, 16), 6});
|
||||
|
||||
const auto param = make_shared<v0::Parameter>(element::i64, PartialShape{8, 20, 1});
|
||||
const auto shape_of = make_shared<v3::ShapeOf>(param);
|
||||
const auto cast_fp = make_shared<op::Convert>(shape_of, element::f32);
|
||||
const auto mul = make_shared<v1::Multiply>(cast_fp, op::Constant::create(element::f32, {3}, {-2, 1, -2}));
|
||||
const auto div = make_shared<v1::Divide>(mul, op::Constant::create(element::f32, {3}, {-10, 41, -1}));
|
||||
// order after convert [1, 0, 2]
|
||||
const auto cast_int = make_shared<op::Convert>(div, element::i32);
|
||||
// order after gather [2, 1, 0]
|
||||
const auto gather = make_shared<v1::Gather>(cast_int,
|
||||
op::Constant::create(element::i32, {3}, {2, 0, 1}),
|
||||
op::Constant::create(element::i32, {}, {0}));
|
||||
|
||||
const auto r = make_shared<v1::Transpose>(arg, gather);
|
||||
|
||||
ASSERT_EQ(r->get_output_element_type(v1::Transpose::ARG_T), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(v1::Transpose::ARG_T), PartialShape({6, Dimension(4, 16), Dimension(2, 8)}));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Transpose when order is dimensions from parameter shape.
|
||||
*
|
||||
* One dimension is dynamic, transposed output shape cannot be deduced and will be dynamic.
|
||||
*/
|
||||
TEST(type_prop, transpose_when_order_is_shape_of_dynamic_partial_shape) {
|
||||
const auto arg = make_shared<op::Parameter>(element::f32, PartialShape{Dimension(2, 8), Dimension(4, 16), 6});
|
||||
|
||||
const auto param = make_shared<op::Parameter>(element::i64, PartialShape{0, 2, Dimension(1, 2)});
|
||||
const auto shape_of = make_shared<v3::ShapeOf>(param);
|
||||
|
||||
const auto r = make_shared<v1::Transpose>(arg, shape_of);
|
||||
|
||||
ASSERT_EQ(r->get_output_element_type(v1::Transpose::ARG_T), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(v1::Transpose::ARG_T), PartialShape::dynamic(3));
|
||||
}
|
||||
|
||||
using transpose_prop_params = tuple<vector<int64_t>, // transpose order
|
||||
PartialShape, // Input partial shape
|
||||
PartialShape // Expected partial shape
|
||||
>;
|
||||
|
||||
// Test pre-defined constants.
|
||||
static constexpr auto exp_type = element::f32;
|
||||
static const auto interval_dim_1 = Dimension(3, 5);
|
||||
static const auto interval_dim_2 = Dimension(1, 8);
|
||||
|
||||
/** \brief Parametrize fixture to test transpose property. */
|
||||
class TransposeTest : public TestWithParam<transpose_prop_params> {
|
||||
protected:
|
||||
PartialShape input_p_shape, exp_p_shape;
|
||||
vector<int64_t> transpose_order;
|
||||
|
||||
void SetUp() override {
|
||||
std::tie(transpose_order, input_p_shape, exp_p_shape) = GetParam();
|
||||
}
|
||||
|
||||
vector<size_t> make_seq_labels(const size_t first, const size_t count) {
|
||||
vector<size_t> labels;
|
||||
|
||||
generate_n(std::back_inserter(labels), count, ov::SeqGen<size_t>(first));
|
||||
return labels;
|
||||
}
|
||||
|
||||
vector<size_t> make_seq_labels_by_order(const size_t first, const vector<int64_t> order) {
|
||||
vector<size_t> labels;
|
||||
transform(order.cbegin(), order.cend(), back_inserter(labels), [&first](const int64_t& dim) {
|
||||
return dim + first;
|
||||
});
|
||||
return labels;
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
type_prop,
|
||||
TransposeTest,
|
||||
Values(make_tuple(vector<int64_t>{2, 0, 1}, PartialShape{2, interval_dim_2, 4}, PartialShape{4, 2, interval_dim_2}),
|
||||
make_tuple(vector<int64_t>{0, 2, 1},
|
||||
PartialShape{interval_dim_1, interval_dim_2, 4},
|
||||
PartialShape{interval_dim_1, 4, interval_dim_2}),
|
||||
make_tuple(vector<int64_t>{1, 2, 3, 0},
|
||||
PartialShape{interval_dim_1, 2, 3, 4},
|
||||
PartialShape{2, 3, 4, interval_dim_1}),
|
||||
make_tuple(vector<int64_t>{3, 0, 2, 1},
|
||||
PartialShape{interval_dim_1, 2, interval_dim_2, 4},
|
||||
PartialShape{4, interval_dim_1, interval_dim_2, 2}),
|
||||
make_tuple(vector<int64_t>{1, 0, 3, 2},
|
||||
PartialShape{interval_dim_1, interval_dim_2, interval_dim_2, interval_dim_1},
|
||||
PartialShape{interval_dim_2, interval_dim_1, interval_dim_1, interval_dim_2})),
|
||||
PrintToStringParamName());
|
||||
|
||||
TEST_P(TransposeTest, use_default_ctor) {
|
||||
const auto input = make_shared<op::Parameter>(exp_type, input_p_shape);
|
||||
const auto order = op::Constant::create(element::i64, Shape{transpose_order.size()}, transpose_order);
|
||||
|
||||
const auto output = make_shared<op::Transpose>();
|
||||
output->set_arguments(NodeVector{input, order});
|
||||
output->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(output->get_output_element_type(op::Transpose::ARG_T), exp_type);
|
||||
EXPECT_EQ(output->get_output_partial_shape(op::Transpose::ARG_T), exp_p_shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Test interval dimension propagate in transpose.
|
||||
*
|
||||
* The interval dimensions should be moved accordingly to transpose order.
|
||||
*/
|
||||
TEST_P(TransposeTest, propagate_interval_shape) {
|
||||
const auto input = make_shared<op::Parameter>(exp_type, input_p_shape);
|
||||
const auto order = op::Constant::create(element::i64, Shape{transpose_order.size()}, transpose_order);
|
||||
|
||||
const auto output = make_shared<op::Transpose>(input, order);
|
||||
|
||||
EXPECT_EQ(output->get_output_element_type(op::Transpose::ARG_T), exp_type);
|
||||
EXPECT_EQ(output->get_output_partial_shape(op::Transpose::ARG_T), exp_p_shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Check labels propagation for all dimensions.
|
||||
*
|
||||
* The labels should be moved accordingly to transpose order.
|
||||
*/
|
||||
TEST_P(TransposeTest, propagate_labels) {
|
||||
constexpr size_t first_label = 33;
|
||||
|
||||
const auto labels = make_seq_labels(first_label, transpose_order.size());
|
||||
const auto exp_labels = make_seq_labels_by_order(first_label, transpose_order);
|
||||
|
||||
set_shape_labels(input_p_shape, labels);
|
||||
|
||||
const auto input = make_shared<op::Parameter>(exp_type, input_p_shape);
|
||||
const auto order = op::Constant::create(element::i64, Shape{transpose_order.size()}, transpose_order);
|
||||
const auto output = make_shared<op::Transpose>(input, order);
|
||||
|
||||
ASSERT_EQ(get_shape_labels(output->get_output_partial_shape(op::Transpose::ARG_T)), exp_labels);
|
||||
}
|
||||
|
@ -14,11 +14,14 @@
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
|
||||
#include "assign_shape_inference.hpp"
|
||||
#include "batch_to_space_shape_inference.hpp"
|
||||
#include "broadcast_shape_inference.hpp"
|
||||
#include "bucketize_shape_inference.hpp"
|
||||
#include "convolution_shape_inference.hpp"
|
||||
#include "ctc_greedy_decoder_seq_len_shape_inference.hpp"
|
||||
#include "ctc_greedy_decoder_shape_inference.hpp"
|
||||
#include "ctc_loss_shape_inference.hpp"
|
||||
#include "depth_to_space_shape_inference.hpp"
|
||||
#include "detection_output_shape_inference.hpp"
|
||||
#include "einsum_shape_inference.hpp"
|
||||
#include "embedding_segments_sum_shape_inference.hpp"
|
||||
@ -29,6 +32,7 @@
|
||||
#include "experimental_detectron_roi_feature_shape_inference.hpp"
|
||||
#include "experimental_detectron_topkrois_shape_inference.hpp"
|
||||
#include "extract_image_patches_shape_inference.hpp"
|
||||
#include "eye_shape_inference.hpp"
|
||||
#include "fake_quantize.hpp"
|
||||
#include "fft_base_shape_inference.hpp"
|
||||
#include "gather_elements_shape_inference.hpp"
|
||||
@ -36,25 +40,8 @@
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
#include "interpolate_shape_inference.hpp"
|
||||
#include "lstm_cell_shape_inference.hpp"
|
||||
#include "matmul_shape_inference.hpp"
|
||||
#include "one_hot_shape_inference.hpp"
|
||||
#include "read_value_shape_inference.hpp"
|
||||
#include "reduce_shape_inference.hpp"
|
||||
#include "reverse_sequence_shape_inference.hpp"
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
#include "scatter_nd_base_shape_inference.hpp"
|
||||
#include "ctc_loss_shape_inference.hpp"
|
||||
#include "fft_base_shape_inference.hpp"
|
||||
#include "shape_inference.hpp"
|
||||
#include "shape_nodes.hpp"
|
||||
#include "fake_quantize.hpp"
|
||||
#include "batch_to_space_shape_inference.hpp"
|
||||
#include "depth_to_space_shape_inference.hpp"
|
||||
#include "space_to_batch_shape_inference.hpp"
|
||||
#include "space_to_depth_shape_inference.hpp"
|
||||
#include "experimental_detectron_detection_output_shape_inference.hpp"
|
||||
#include "bucketize_shape_inference.hpp"
|
||||
#include "embedding_segments_sum_shape_inference.hpp"
|
||||
#include "embeddingbag_offsets_shape_inference.hpp"
|
||||
#include "pad_shape_inference.hpp"
|
||||
#include "proposal_shape_inference.hpp"
|
||||
#include "range_shape_inference.hpp"
|
||||
@ -68,19 +55,18 @@
|
||||
#include "scatter_elements_update_shape_inference.hpp"
|
||||
#include "scatter_nd_base_shape_inference.hpp"
|
||||
#include "select_shape_inference.hpp"
|
||||
#include "shape_inference.hpp"
|
||||
#include "shape_nodes.hpp"
|
||||
#include "shuffle_channels_shape_inference.hpp"
|
||||
#include "space_to_batch_shape_inference.hpp"
|
||||
#include "space_to_depth_shape_inference.hpp"
|
||||
#include "split_shape_inference.hpp"
|
||||
#include "broadcast_shape_inference.hpp"
|
||||
#include "static_shape.hpp"
|
||||
#include "strided_slice_shape_inference.hpp"
|
||||
#include "tile_shape_inference.hpp"
|
||||
#include "topk_shape_inference.hpp"
|
||||
#include "transpose_shape_inference.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "variadic_split_shape_inference.hpp"
|
||||
#include "matmul_shape_inference.hpp"
|
||||
#include "eye_shape_inference.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
@ -563,6 +549,8 @@ std::shared_ptr<IShapeInfer> make_shape_inference(const std::shared_ptr<ngraph::
|
||||
return std::make_shared<entryFallbackWithPadding<ov::op::v1::DeformableConvolution>>(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::op::v8::DeformableConvolution>(op)) {
|
||||
return std::make_shared<entryFallbackWithPadding<ov::op::v8::DeformableConvolution>>(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset8::Transpose>(op)) {
|
||||
return make_shared_entryIOC(node);
|
||||
} else {
|
||||
return std::make_shared<entryFallback>(op);
|
||||
}
|
||||
|
@ -0,0 +1,106 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "transpose_shape_inference.hpp"
|
||||
#include "utils/shape_inference/static_shape.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
template <class TInput, class TOrder>
|
||||
std::shared_ptr<op::v1::Transpose> make_transpose(const TInput& input_shape, const TOrder& transpose_order) {
|
||||
const auto input = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(input_shape.size()));
|
||||
const auto order =
|
||||
std::make_shared<op::v0::Constant>(element::i64, ov::Shape{transpose_order.size()}, transpose_order);
|
||||
return std::make_shared<op::v1::Transpose>(input, order);
|
||||
}
|
||||
|
||||
using transpose_params = std::tuple<std::vector<size_t>, // transpose order
|
||||
StaticShape, // Input shape
|
||||
StaticShape // Expected shape
|
||||
>;
|
||||
|
||||
class StaticShapeInferenceTest : public TestWithParam<transpose_params> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::tie(transpose_order, input_shape, exp_shape) = GetParam();
|
||||
|
||||
transpose = make_transpose(input_shape, transpose_order);
|
||||
}
|
||||
|
||||
StaticShape input_shape, exp_shape;
|
||||
std::vector<size_t> transpose_order;
|
||||
|
||||
std::shared_ptr<op::v1::Transpose> transpose;
|
||||
};
|
||||
|
||||
/** \brief Use transpose order -> output shape dimensions shall be as transpose order. */
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
transpose_by_order,
|
||||
StaticShapeInferenceTest,
|
||||
Values(make_tuple(std::vector<size_t>{0}, StaticShape({3}), StaticShape({3})),
|
||||
make_tuple(std::vector<size_t>{0, 1}, StaticShape({5, 2}), StaticShape({5, 2})),
|
||||
make_tuple(std::vector<size_t>{1, 0}, StaticShape({8, 3}), StaticShape({3, 8})),
|
||||
make_tuple(std::vector<size_t>{2, 0, 1}, StaticShape({1, 0, 2}), StaticShape({2, 1, 0})),
|
||||
make_tuple(std::vector<size_t>{2, 0, 3, 1}, StaticShape({10, 8, 9, 2}), StaticShape({9, 10, 2, 8})),
|
||||
make_tuple(std::vector<size_t>{1, 3, 2, 0}, StaticShape({1, 2, 3, 4}), StaticShape({2, 4, 3, 1}))),
|
||||
PrintToStringParamName());
|
||||
|
||||
/** \brief Empty transpose order -> output shape dimensions shall be in reverse order. */
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
transpose_reverse,
|
||||
StaticShapeInferenceTest,
|
||||
Values(make_tuple(std::vector<size_t>{}, StaticShape({1}), StaticShape({1})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({23}), StaticShape({23})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({3, 8}), StaticShape({8, 3})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({1, 0, 2}), StaticShape({2, 0, 1})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({21, 1, 5, 9}), StaticShape({9, 5, 1, 21})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({0, 0, 0}), StaticShape({0, 0, 0})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({0, 2, 0}), StaticShape({0, 2, 0})),
|
||||
make_tuple(std::vector<size_t>{}, StaticShape({0, 2, 0, 0}), StaticShape({0, 0, 2, 0}))),
|
||||
PrintToStringParamName());
|
||||
|
||||
/** \brief Check shape_infer for transpose on static shapes. */
|
||||
TEST_P(StaticShapeInferenceTest, transpose_static) {
|
||||
auto output_shapes = std::vector<StaticShape>{StaticShape{}};
|
||||
|
||||
shape_infer(transpose.get(), {input_shape, transpose_order}, output_shapes);
|
||||
|
||||
ASSERT_EQ(output_shapes[op::v1::Transpose::ARG_T], exp_shape);
|
||||
}
|
||||
|
||||
/** \brief Shape infer when transpose input got dynamic dimensions. */
|
||||
TEST(StaticShapeInferenceTest, transpose_input_shape_dim_dynamic) {
|
||||
const auto input_shape = PartialShape{-1, -1, -1};
|
||||
const auto order = std::vector<size_t>{1, 2, 0};
|
||||
const auto transpose = make_transpose(input_shape, order);
|
||||
|
||||
auto output_shapes = std::vector<StaticShape>{StaticShape{}};
|
||||
|
||||
shape_infer(transpose.get(), {StaticShape{2, 6, 3}, order}, output_shapes);
|
||||
ASSERT_EQ(output_shapes[op::v1::Transpose::ARG_T], StaticShape({6, 3, 2}));
|
||||
}
|
||||
|
||||
/** \brief Shape inference when transpose order stored in constant map. */
|
||||
TEST(StaticShapeInferenceTest, transpose_order_in_constant_map) {
|
||||
const auto input_shape = PartialShape{2, 4, 6, 8};
|
||||
const auto input = std::make_shared<op::v0::Parameter>(element::f32, input_shape);
|
||||
const auto order = std::make_shared<op::v0::Parameter>(element::i64, Shape{4});
|
||||
|
||||
const auto transpose = std::make_shared<op::v1::Transpose>(input, order);
|
||||
|
||||
const auto axes_order = std::vector<size_t>{1, 2, 0, 3};
|
||||
const auto axes = std::make_shared<op::v0::Constant>(element::i64, ov::Shape{axes_order.size()}, axes_order);
|
||||
const auto const_tensor = std::make_shared<ngraph::runtime::HostTensor>(axes);
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_map = {{1, const_tensor}};
|
||||
|
||||
auto output_shapes = std::vector<StaticShape>{StaticShape{}};
|
||||
shape_infer(transpose.get(), {StaticShape({2, 4, 6, 8}), StaticShape()}, output_shapes, const_map);
|
||||
|
||||
ASSERT_EQ(output_shapes[op::v1::Transpose::ARG_T], StaticShape({4, 6, 2, 8}));
|
||||
}
|
@ -306,14 +306,14 @@ const std::vector<TransposeTransformationTestValues> testValues = {
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{ {128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true },
|
||||
{0.1f}
|
||||
},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {}}
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user