Remove obsoleted v0::OneHot operator (#2855)
This commit is contained in:
parent
186e00fa2a
commit
e3ed796b2e
@ -22,61 +22,6 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace op
|
namespace op
|
||||||
{
|
{
|
||||||
namespace v0
|
|
||||||
{
|
|
||||||
// clang-format off
|
|
||||||
/// \brief One-hot operator.
|
|
||||||
///
|
|
||||||
/// ## Parameters
|
|
||||||
///
|
|
||||||
/// | | Description |
|
|
||||||
/// | -------------- | ---------------------------------------------------------- |
|
|
||||||
/// | `shape` | The desired output shape, including the new one-hot axis. |
|
|
||||||
/// | `one_hot_axis` | The index within the output shape of the new one-hot axis. |
|
|
||||||
///
|
|
||||||
/// ## Inputs
|
|
||||||
///
|
|
||||||
/// | | Type | Description |
|
|
||||||
/// | ----- | ------------------------------------------------------- | -------------------------------------------------------------- |
|
|
||||||
/// | `arg` | \f$E[d_1,\dots,d_{m-1},d_{m+1},\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and any non-floating point element type. |
|
|
||||||
///
|
|
||||||
/// ## Output
|
|
||||||
///
|
|
||||||
/// | Type | Description |
|
|
||||||
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
|
|
||||||
// clang-format on
|
|
||||||
class NGRAPH_DEPRECATED(
|
|
||||||
"This operation is deprecated and will be removed soon. "
|
|
||||||
"Use v1::OneHot instead of it.") NGRAPH_API OneHot : public Op
|
|
||||||
{
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
public:
|
|
||||||
static constexpr NodeTypeInfo type_info{"OneHot", 0};
|
|
||||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
|
||||||
/// \brief Constructs a one-hot operation.
|
|
||||||
OneHot() = default;
|
|
||||||
/// \brief Constructs a one-hot operation.
|
|
||||||
///
|
|
||||||
/// \param arg Node that produces the input tensor to be one-hot encoded.
|
|
||||||
/// \param shape The shape of the output tensor, including the new one-hot
|
|
||||||
/// axis.
|
|
||||||
/// \param one_hot_axis The index within the output shape of the new one-hot axis.
|
|
||||||
OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
|
|
||||||
|
|
||||||
virtual std::shared_ptr<Node>
|
|
||||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
|
||||||
void validate_and_infer_types() override;
|
|
||||||
|
|
||||||
/// \return The index of the one-hot axis.
|
|
||||||
size_t get_one_hot_axis() const { return m_one_hot_axis; }
|
|
||||||
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
|
|
||||||
protected:
|
|
||||||
PartialShape m_shape;
|
|
||||||
size_t m_one_hot_axis;
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
};
|
|
||||||
}
|
|
||||||
namespace v1
|
namespace v1
|
||||||
{
|
{
|
||||||
class NGRAPH_API OneHot : public Op
|
class NGRAPH_API OneHot : public Op
|
||||||
@ -114,9 +59,5 @@ namespace ngraph
|
|||||||
int64_t m_axis;
|
int64_t m_axis;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
// default opset version
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
using v0::OneHot;
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,7 +122,7 @@ NGRAPH_OP(NormalizeL2, ngraph::op::v0, 0)
|
|||||||
NGRAPH_OP(Not, ngraph::op::v0, 0)
|
NGRAPH_OP(Not, ngraph::op::v0, 0)
|
||||||
NGRAPH_OP(NotEqual, ngraph::op::v0, 0)
|
NGRAPH_OP(NotEqual, ngraph::op::v0, 0)
|
||||||
NGRAPH_OP(NotEqual, ngraph::op::v1, 1)
|
NGRAPH_OP(NotEqual, ngraph::op::v1, 1)
|
||||||
NGRAPH_OP(OneHot, ngraph::op::v0, 0)
|
NGRAPH_OP(OneHot, ngraph::op::v1, 1)
|
||||||
NGRAPH_OP(Or, ngraph::op::v0, 0)
|
NGRAPH_OP(Or, ngraph::op::v0, 0)
|
||||||
NGRAPH_OP(PRelu, ngraph::op::v0, 0)
|
NGRAPH_OP(PRelu, ngraph::op::v0, 0)
|
||||||
NGRAPH_OP(PSROIPooling, ngraph::op::v0, 0)
|
NGRAPH_OP(PSROIPooling, ngraph::op::v0, 0)
|
||||||
|
@ -19,90 +19,9 @@
|
|||||||
#include "ngraph/op/util/op_types.hpp"
|
#include "ngraph/op/util/op_types.hpp"
|
||||||
#include "ngraph/validation_util.hpp"
|
#include "ngraph/validation_util.hpp"
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
constexpr NodeTypeInfo op::v0::OneHot::type_info;
|
|
||||||
|
|
||||||
op::v0::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
|
|
||||||
: Op({arg})
|
|
||||||
, m_shape(shape)
|
|
||||||
, m_one_hot_axis(one_hot_axis)
|
|
||||||
{
|
|
||||||
constructor_validate_and_infer_types();
|
|
||||||
}
|
|
||||||
|
|
||||||
void op::v0::OneHot::validate_and_infer_types()
|
|
||||||
{
|
|
||||||
element::Type arg_et = get_input_element_type(0);
|
|
||||||
PartialShape arg_shape = get_input_partial_shape(0);
|
|
||||||
Rank arg_rank = arg_shape.rank();
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
arg_et.is_dynamic() || arg_et.is_integral(),
|
|
||||||
"Argument does not have integral element type.");
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(
|
|
||||||
this, m_shape.rank().is_static(), "Requested result shape has dynamic rank.");
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
m_one_hot_axis < m_shape.rank().get_length(),
|
|
||||||
"One-hot axis (",
|
|
||||||
m_one_hot_axis,
|
|
||||||
") is out of bounds (requested result shape: ",
|
|
||||||
m_shape,
|
|
||||||
").");
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
m_shape[m_one_hot_axis].is_static(),
|
|
||||||
"Requested result shape (",
|
|
||||||
m_shape,
|
|
||||||
") has dynamic dimension at the one-hot axis ",
|
|
||||||
"(",
|
|
||||||
m_one_hot_axis,
|
|
||||||
").");
|
|
||||||
|
|
||||||
PartialShape result_shape{m_shape};
|
|
||||||
|
|
||||||
if (arg_rank.is_static())
|
|
||||||
{
|
|
||||||
std::vector<Dimension> expected_input_dims(m_shape.rank().get_length());
|
|
||||||
for (size_t i = 0; i < m_shape.rank().get_length(); i++)
|
|
||||||
{
|
|
||||||
expected_input_dims[i] = m_shape[i];
|
|
||||||
}
|
|
||||||
expected_input_dims.erase(expected_input_dims.begin() + m_one_hot_axis);
|
|
||||||
PartialShape expected_input_shape{expected_input_dims};
|
|
||||||
|
|
||||||
PartialShape merged_input_shape{expected_input_shape};
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
PartialShape::merge_into(merged_input_shape, arg_shape),
|
|
||||||
"Argument shape ",
|
|
||||||
arg_shape,
|
|
||||||
" does not match the expected shape of ",
|
|
||||||
expected_input_shape,
|
|
||||||
".");
|
|
||||||
|
|
||||||
std::vector<Dimension> output_dims(merged_input_shape.rank().get_length());
|
|
||||||
for (size_t i = 0; i < merged_input_shape.rank().get_length(); i++)
|
|
||||||
{
|
|
||||||
output_dims[i] = merged_input_shape[i];
|
|
||||||
}
|
|
||||||
output_dims.insert(output_dims.begin() + m_one_hot_axis, m_shape[m_one_hot_axis]);
|
|
||||||
result_shape = PartialShape{output_dims};
|
|
||||||
}
|
|
||||||
|
|
||||||
set_output_type(0, arg_et, result_shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
shared_ptr<Node> op::v0::OneHot::clone_with_new_inputs(const OutputVector& new_args) const
|
|
||||||
{
|
|
||||||
check_new_args_count(this, new_args);
|
|
||||||
return make_shared<v0::OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr NodeTypeInfo op::v1::OneHot::type_info;
|
constexpr NodeTypeInfo op::v1::OneHot::type_info;
|
||||||
|
|
||||||
op::v1::OneHot::OneHot(const Output<Node>& indices,
|
op::v1::OneHot::OneHot(const Output<Node>& indices,
|
||||||
|
@ -31,8 +31,6 @@
|
|||||||
#include "util/test_control.hpp"
|
#include "util/test_control.hpp"
|
||||||
#include "util/test_tools.hpp"
|
#include "util/test_tools.hpp"
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
@ -42,8 +40,12 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_2_in_3)
|
|||||||
{
|
{
|
||||||
Shape shape_a{};
|
Shape shape_a{};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
|
int axis = 0;
|
||||||
Shape shape_r{3};
|
Shape shape_r{3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{3}, 0);
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -62,8 +64,12 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_1_in_3)
|
|||||||
{
|
{
|
||||||
Shape shape_a{};
|
Shape shape_a{};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
|
int axis = 0;
|
||||||
Shape shape_r{3};
|
Shape shape_r{3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{3}, 0);
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -83,7 +89,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_scalar_0_in_3)
|
|||||||
Shape shape_a{};
|
Shape shape_a{};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{3};
|
Shape shape_r{3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{3}, 0);
|
int axis = 0;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -103,7 +113,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_0)
|
|||||||
Shape shape_a{8};
|
Shape shape_a{8};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{3, 8};
|
Shape shape_r{3, 8};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{3, 8}, 0);
|
int axis = 0;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -125,7 +139,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_1)
|
|||||||
Shape shape_a{8};
|
Shape shape_a{8};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{8, 3};
|
Shape shape_r{8, 3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{8, 3}, 1);
|
int axis = 1;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -147,7 +165,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_1_barely_oob)
|
|||||||
Shape shape_a{8};
|
Shape shape_a{8};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{8, 3};
|
Shape shape_r{8, 3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{8, 3}, 1);
|
int axis = 1;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -200,7 +222,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_matrix_0)
|
|||||||
Shape shape_a{3, 3};
|
Shape shape_a{3, 3};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{3, 3, 3};
|
Shape shape_r{3, 3, 3};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{3, 3, 3}, 0);
|
int axis = 0;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -230,7 +256,11 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_many_categories)
|
|||||||
Shape shape_a{6};
|
Shape shape_a{6};
|
||||||
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
Shape shape_r{6, category_count};
|
Shape shape_r{6, category_count};
|
||||||
auto r = make_shared<op::OneHot>(A, Shape{6, category_count}, 1);
|
int axis = 1;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::i32, {}, {1});
|
||||||
|
auto off_value = op::Constant::create(element::i32, {}, {0});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
auto f = make_shared<Function>(r, ParameterVector{A});
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
@ -255,3 +285,32 @@ NGRAPH_TEST(${BACKEND_NAME}, one_hot_vector_many_categories)
|
|||||||
}
|
}
|
||||||
EXPECT_EQ(bit_positions, input_data);
|
EXPECT_EQ(bit_positions, input_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, one_hot_on_off_float)
|
||||||
|
{
|
||||||
|
Shape shape_a{3, 3};
|
||||||
|
auto A = make_shared<op::Parameter>(element::i32, shape_a);
|
||||||
|
Shape shape_r{3, 3, 3};
|
||||||
|
int axis = 0;
|
||||||
|
auto depth = op::Constant::create(element::i32, {}, {shape_r[axis]});
|
||||||
|
auto on_value = op::Constant::create(element::f32, {}, {2.5});
|
||||||
|
auto off_value = op::Constant::create(element::f32, {}, {0.5});
|
||||||
|
auto r = make_shared<op::v1::OneHot>(A, depth, on_value, off_value, axis);
|
||||||
|
auto f = make_shared<Function>(r, ParameterVector{A});
|
||||||
|
|
||||||
|
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||||
|
|
||||||
|
// Create some tensors for input/output
|
||||||
|
auto a = backend->create_tensor(element::i32, shape_a);
|
||||||
|
copy_data(a,
|
||||||
|
vector<int32_t>{
|
||||||
|
0, 1, 1, 2, 1, 0, 0, 2, 1,
|
||||||
|
});
|
||||||
|
auto result = backend->create_tensor(element::f32, shape_r);
|
||||||
|
|
||||||
|
auto handle = backend->compile(f);
|
||||||
|
handle->call_with_validate({result}, {a});
|
||||||
|
EXPECT_EQ((vector<float>{2.5, 0.5, 0.5, 0.5, 0.5, 2.5, 2.5, 0.5, 0.5, 0.5, 2.5, 2.5, 0.5, 2.5,
|
||||||
|
0.5, 0.5, 0.5, 2.5, 0.5, 0.5, 0.5, 2.5, 0.5, 0.5, 0.5, 2.5, 0.5}),
|
||||||
|
read_vector<float>(result));
|
||||||
|
}
|
||||||
|
@ -562,7 +562,7 @@ namespace
|
|||||||
|
|
||||||
void op_is_OneHot()
|
void op_is_OneHot()
|
||||||
{
|
{
|
||||||
op::OneHot node;
|
op::v1::OneHot node;
|
||||||
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
|
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
|
||||||
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
|
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
|
||||||
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
||||||
|
@ -749,16 +749,6 @@ pad_negative_exterior_4d
|
|||||||
pad_2channel_2image_asym
|
pad_2channel_2image_asym
|
||||||
pad_symmetric
|
pad_symmetric
|
||||||
|
|
||||||
# Output 0 type 'i32' does not match Result type 'i64'" thrown in the test body.
|
|
||||||
one_hot_scalar_2_in_3
|
|
||||||
one_hot_scalar_1_in_3
|
|
||||||
one_hot_scalar_0_in_3
|
|
||||||
one_hot_vector_0
|
|
||||||
one_hot_vector_1
|
|
||||||
one_hot_vector_1_barely_oob
|
|
||||||
one_hot_matrix_0
|
|
||||||
one_hot_vector_many_categories
|
|
||||||
|
|
||||||
# LRN operation should be converted to LRN_IE
|
# LRN operation should be converted to LRN_IE
|
||||||
lrn_across_h
|
lrn_across_h
|
||||||
lrn_across_nw
|
lrn_across_nw
|
||||||
|
@ -869,14 +869,97 @@ protected:
|
|||||||
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
|
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OP_TYPEID::OneHot:
|
case OP_TYPEID::OneHot_v1:
|
||||||
{
|
{
|
||||||
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
|
const op::v1::OneHot* oh = static_cast<const op::v1::OneHot*>(&node);
|
||||||
reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
|
T on_value = args[2]->get_data_ptr<T>()[0];
|
||||||
out[0]->get_data_ptr<T>(),
|
T off_value = args[3]->get_data_ptr<T>()[0];
|
||||||
node.get_input_shape(0),
|
|
||||||
node.get_output_shape(0),
|
switch (args[0]->get_element_type())
|
||||||
oh->get_one_hot_axis());
|
{
|
||||||
|
case element::Type_t::i8:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const int8_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::i16:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const int16_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::i32:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const int32_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::i64:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const int64_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::u8:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const uint8_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::u16:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const uint16_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::u32:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const uint32_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::u64:
|
||||||
|
reference::one_hot(args[0]->get_data_ptr<const uint64_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_output_shape(0),
|
||||||
|
oh->get_axis(),
|
||||||
|
on_value,
|
||||||
|
off_value);
|
||||||
|
break;
|
||||||
|
case element::Type_t::undefined:
|
||||||
|
case element::Type_t::dynamic:
|
||||||
|
case element::Type_t::u1:
|
||||||
|
case element::Type_t::boolean:
|
||||||
|
case element::Type_t::bf16:
|
||||||
|
case element::Type_t::f16:
|
||||||
|
case element::Type_t::f32:
|
||||||
|
case element::Type_t::f64:
|
||||||
|
default: NGRAPH_CHECK(false, "Indices input element type must be integer");
|
||||||
|
}
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case OP_TYPEID::Parameter: break;
|
case OP_TYPEID::Parameter: break;
|
||||||
|
@ -33,6 +33,7 @@ NGRAPH_OP(LogicalOr, op::v1)
|
|||||||
NGRAPH_OP(LogicalXor, op::v1)
|
NGRAPH_OP(LogicalXor, op::v1)
|
||||||
NGRAPH_OP(LogicalNot, op::v1)
|
NGRAPH_OP(LogicalNot, op::v1)
|
||||||
NGRAPH_OP(GatherTree, op::v1)
|
NGRAPH_OP(GatherTree, op::v1)
|
||||||
|
NGRAPH_OP(OneHot, op::v1)
|
||||||
#undef ID_SUFFIX
|
#undef ID_SUFFIX
|
||||||
|
|
||||||
#define ID_SUFFIX(NAME) NAME##_v3
|
#define ID_SUFFIX(NAME) NAME##_v3
|
||||||
|
@ -102,7 +102,6 @@ NGRAPH_OP(MVN, ngraph::op)
|
|||||||
NGRAPH_OP(Negative, ngraph::op)
|
NGRAPH_OP(Negative, ngraph::op)
|
||||||
NGRAPH_OP(Not, ngraph::op)
|
NGRAPH_OP(Not, ngraph::op)
|
||||||
NGRAPH_OP(NotEqual, ngraph::op)
|
NGRAPH_OP(NotEqual, ngraph::op)
|
||||||
NGRAPH_OP(OneHot, ngraph::op)
|
|
||||||
NGRAPH_OP(Or, ngraph::op)
|
NGRAPH_OP(Or, ngraph::op)
|
||||||
NGRAPH_OP(Parameter, ngraph::op)
|
NGRAPH_OP(Parameter, ngraph::op)
|
||||||
NGRAPH_OP(Power, ngraph::op)
|
NGRAPH_OP(Power, ngraph::op)
|
||||||
|
@ -338,33 +338,6 @@ namespace opset0_downgrade
|
|||||||
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
|
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
|
|
||||||
{
|
|
||||||
const auto indices = node->input_value(0);
|
|
||||||
const auto depth = node->input_value(1).get_node();
|
|
||||||
auto on_value = node->input_value(2);
|
|
||||||
auto off_value = node->input_value(3);
|
|
||||||
const auto axis = node->get_axis();
|
|
||||||
|
|
||||||
NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
|
|
||||||
const auto output_pshape = node->get_output_partial_shape(0);
|
|
||||||
NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
|
|
||||||
const auto output_shape = output_pshape.to_shape();
|
|
||||||
|
|
||||||
auto one_hot = std::make_shared<ngraph::op::Convert>(
|
|
||||||
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
|
|
||||||
on_value.get_element_type());
|
|
||||||
|
|
||||||
auto broadcasted_values = builder::numpy_broadcast_outputs({one_hot, on_value, off_value});
|
|
||||||
on_value = broadcasted_values[1];
|
|
||||||
off_value = broadcasted_values[2];
|
|
||||||
|
|
||||||
auto replacement_node = one_hot * (on_value - off_value) + off_value;
|
|
||||||
|
|
||||||
replace_node(node, replacement_node);
|
|
||||||
return replacement_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
|
shared_ptr<Node> op_cast(shared_ptr<op::v1::Power> node)
|
||||||
{
|
{
|
||||||
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
|
return op_cast_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
|
||||||
|
@ -302,27 +302,6 @@ namespace opset1_upgrade
|
|||||||
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
|
return op_cast_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op_cast(shared_ptr<op::OneHot> node)
|
|
||||||
{
|
|
||||||
const auto indices = node->input_value(0).get_node_shared_ptr();
|
|
||||||
const auto one_hot_axis = node->get_one_hot_axis();
|
|
||||||
|
|
||||||
const auto output_pshape = node->get_output_partial_shape(0);
|
|
||||||
NGRAPH_CHECK(output_pshape[one_hot_axis].is_static(),
|
|
||||||
"OneHot:v0 one hot axis dimension must be static ",
|
|
||||||
*node);
|
|
||||||
const auto depth = output_pshape[one_hot_axis].get_length();
|
|
||||||
const auto depth_node = op::Constant::create(element::i64, Shape{}, {depth});
|
|
||||||
|
|
||||||
const auto on_value = op::Constant::create(element::i64, Shape{}, {1});
|
|
||||||
const auto off_value = op::Constant::create(element::i64, Shape{}, {0});
|
|
||||||
|
|
||||||
auto replacement_node =
|
|
||||||
make_shared<op::v1::OneHot>(indices, depth_node, on_value, off_value, one_hot_axis);
|
|
||||||
replace_node(node, replacement_node);
|
|
||||||
return replacement_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
|
shared_ptr<Node> op_cast(shared_ptr<op::Or> node)
|
||||||
{
|
{
|
||||||
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
|
return op_cast_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
|
||||||
|
@ -18,363 +18,9 @@
|
|||||||
#include "ngraph/ngraph.hpp"
|
#include "ngraph/ngraph.hpp"
|
||||||
#include "util/type_prop.hpp"
|
#include "util/type_prop.hpp"
|
||||||
|
|
||||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_scalar)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{9}, 0);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{9}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_vector_0)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{8});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{9, 8}, 0);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{9, 8}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_vector_1)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{8});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{8, 9}, 1);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{8, 9}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_matrix_0)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{2, 12, 24}, 0);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{2, 12, 24}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_matrix_1)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 2, 24}, 1);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{12, 2, 24}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_matrix_2)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::i32);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_et_dynamic)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::dynamic, Shape{12, 24});
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 2}, 2);
|
|
||||||
ASSERT_EQ(oh->get_element_type(), element::dynamic);
|
|
||||||
ASSERT_EQ(oh->get_shape(), (Shape{12, 24, 2}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_floating_point)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::f32, Shape{12, 24});
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Invalid floating-point element type not detected.";
|
|
||||||
}
|
|
||||||
catch (const NodeValidationFailure& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
|
||||||
std::string("Argument does not have integral element type."));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_axis_oob)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 24, 8}, 3);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "One-hot axis out of bounds not detected.";
|
|
||||||
}
|
|
||||||
catch (const NodeValidationFailure& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("One-hot axis (3) is out of bounds"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_deduce_shape_incompatible)
|
|
||||||
{
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, Shape{12, 24});
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, Shape{12, 22, 8}, 2);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Incompatible one-hot output shape not detected.";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(
|
|
||||||
error.what(), std::string("Argument shape {12,24} does not match the expected shape"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_dynamic_rank_dynamic)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{PartialShape::dynamic()};
|
|
||||||
PartialShape requested_shape{PartialShape::dynamic()};
|
|
||||||
size_t one_hot_axis{3000};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Dynamic rank for requested result shape not detected";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Requested result shape has dynamic rank"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_ok)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{PartialShape::dynamic()};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
|
|
||||||
ASSERT_EQ(oh->get_output_element_type(0), element::i32);
|
|
||||||
ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
|
|
||||||
PartialShape{Dimension::dynamic(), 2, 3, Dimension::dynamic()}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{PartialShape::dynamic()};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
|
|
||||||
size_t one_hot_axis{3};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Dynamic one-hot dimension not detected";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
|
||||||
std::string("Requested result shape ({?,2,3,?}) has dynamic dimension "
|
|
||||||
"at the one-hot axis (3)"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_dynamic_rank_static_dynamic_one_hot_axis_oob)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{PartialShape::dynamic()};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic()};
|
|
||||||
size_t one_hot_axis{4};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "One-hot axis out of bounds not detected (rank-dynamic argument, rank-static "
|
|
||||||
"dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(
|
|
||||||
error.what(),
|
|
||||||
std::string("One-hot axis (4) is out of bounds (requested result shape: {?,2,3,?})"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_ok)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
|
|
||||||
ASSERT_EQ(oh->get_output_element_type(0), element::i32);
|
|
||||||
ASSERT_TRUE(oh->get_output_partial_shape(0).same_scheme(
|
|
||||||
PartialShape{3, 2, 3, Dimension::dynamic(), 4}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop,
|
|
||||||
one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_short)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic()};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
|
|
||||||
"rank-static dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(
|
|
||||||
error.what(),
|
|
||||||
std::string("Argument shape {3,?,?} does not match the expected shape of {?,2,?,4}"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop,
|
|
||||||
one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_rank_input_long)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4, 5};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Incompatible input/output ranks not detected (rank-static dynamic argument, "
|
|
||||||
"rank-static dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(
|
|
||||||
error.what(),
|
|
||||||
std::string(
|
|
||||||
"Argument shape {3,?,?,4,5} does not match the expected shape of {?,2,?,4}"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_incompatible_dim)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 5};
|
|
||||||
PartialShape requested_shape{Dimension::dynamic(), 2, 3, Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Incompatible input/output dimensions not detected (rank-static dynamic "
|
|
||||||
"argument, rank-static dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(
|
|
||||||
error.what(),
|
|
||||||
std::string("Argument shape {3,?,?,5} does not match the expected shape of {?,2,?,4}"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_dim_dynamic)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
|
|
||||||
PartialShape requested_shape{
|
|
||||||
Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Dynamic one-hot dimension not detected (rank-static dynamic argument, "
|
|
||||||
"rank-static dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
|
||||||
std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
|
|
||||||
"dimension at the one-hot axis (2)"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_partial_rank_static_dynamic_rank_static_dynamic_one_hot_axis_oob)
|
|
||||||
{
|
|
||||||
PartialShape input_shape{3, Dimension::dynamic(), Dimension::dynamic(), 4};
|
|
||||||
PartialShape requested_shape{
|
|
||||||
Dimension::dynamic(), 2, Dimension::dynamic(), Dimension::dynamic(), 4};
|
|
||||||
size_t one_hot_axis{2};
|
|
||||||
|
|
||||||
auto param = make_shared<op::Parameter>(element::i32, input_shape);
|
|
||||||
try
|
|
||||||
{
|
|
||||||
auto oh = make_shared<op::OneHot>(param, requested_shape, one_hot_axis);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "One-hot axis out of bounds not detected (rank-static dynamic argument, "
|
|
||||||
"rank-static dynamic result shape)";
|
|
||||||
}
|
|
||||||
catch (const ngraph_error& error)
|
|
||||||
{
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
|
||||||
std::string("Requested result shape ({?,2,?,?,4}) has dynamic "
|
|
||||||
"dimension at the one-hot axis (2)"));
|
|
||||||
}
|
|
||||||
catch (...)
|
|
||||||
{
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, one_hot_v1_output_shape)
|
TEST(type_prop, one_hot_v1_output_shape)
|
||||||
{
|
{
|
||||||
auto indices = make_shared<op::Parameter>(element::i64, Shape{3});
|
auto indices = make_shared<op::Parameter>(element::i64, Shape{3});
|
||||||
|
Loading…
Reference in New Issue
Block a user