Reshape should support reshape to zero shapes (#5828)
* Reshape should support reshape to zero shapes * Fixed comments * Fixed backward compatible check * Fixed myriad tests * Removed header * Fixed myriad tests * Disabled Myriad tests * Fix tests * Fixed evaluate * Fixed comments * FIxed tests * Fixed tests * Fixed code style * Fixed Myriad tests * Added more tests
This commit is contained in:
parent
ac1803c3ad
commit
d56cf51c81
@ -82,7 +82,7 @@ INSTANTIATE_TEST_CASE_P(smoke_NGraph, DynamicToStaticTopKPropagationConcatBased,
|
||||
class DynamicToStaticTopKPropagationConcatReshape : public DynamicToStaticTopKPropagationConcatBased {
|
||||
protected:
|
||||
std::shared_ptr<ngraph::Node> buildSubgraph(std::shared_ptr<ngraph::Node> node) const override {
|
||||
return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}), false);
|
||||
return std::make_shared<ngraph::opset5::Reshape>(node, ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {1}), false);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -64,6 +64,12 @@ namespace ngraph
|
||||
bool m_special_zero;
|
||||
bool evaluate_reshape(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const;
|
||||
|
||||
private:
|
||||
void calculate_output_shape(std::vector<Dimension>& reshape_pattern,
|
||||
const int64_t& minus_one_idx,
|
||||
const PartialShape& input_pshape,
|
||||
std::vector<Dimension>& output_shape) const;
|
||||
};
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
@ -35,158 +35,12 @@ namespace reshapeop
|
||||
{
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
T* shape_pattern_ptr = shape_pattern->get_data_ptr<ET>();
|
||||
size_t output_rank = shape_pattern->get_shape()[0];
|
||||
size_t output_rank = shape_pattern->get_shape().empty() ? 0 : shape_pattern->get_shape()[0];
|
||||
for (size_t i = 0; i < output_rank; i++)
|
||||
{
|
||||
output_shape.push_back(shape_pattern_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void calculate_output_shape(const ngraph::op::v1::Reshape* reshape_node,
|
||||
vector<Dimension>& reshape_pattern,
|
||||
const int64_t& minus_one_idx,
|
||||
const PartialShape& input_pshape,
|
||||
vector<Dimension>& output_shape)
|
||||
{
|
||||
if (reshape_pattern == std::vector<Dimension>{0} && !reshape_node->get_special_zero())
|
||||
{ // legacy check introduced by PR #1206
|
||||
reshape_pattern = std::vector<Dimension>{};
|
||||
output_shape = {};
|
||||
return;
|
||||
}
|
||||
Dimension output_product(1);
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
|
||||
{
|
||||
if (i == minus_one_idx) // resolving everything except -1
|
||||
continue;
|
||||
|
||||
auto pattern_dim = reshape_pattern[i];
|
||||
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
|
||||
reshape_node->get_special_zero())
|
||||
{
|
||||
if (input_pshape.rank().is_dynamic())
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
output_product *= Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
i < input_pshape.rank().get_length(),
|
||||
"'0' dimension is out of range");
|
||||
output_shape[i] = input_pshape[i];
|
||||
// we do not include dimension to output product here and won't include in input
|
||||
// product later because we will divide output_product by input_product. This
|
||||
// dimension contributes to both products equally, but in case this dimension
|
||||
// is dynamic and others are not we could fully define output dimension that
|
||||
// is masked by -1
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
output_shape[i] = pattern_dim;
|
||||
output_product *= pattern_dim;
|
||||
}
|
||||
}
|
||||
Dimension input_product(1);
|
||||
if (input_pshape.rank().is_static())
|
||||
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
|
||||
{
|
||||
if (i < static_cast<int64_t>(reshape_pattern.size()) &&
|
||||
reshape_pattern[i].get_min_length() == 0 &&
|
||||
reshape_pattern[i].get_max_length() == 0)
|
||||
continue;
|
||||
input_product *= input_pshape[i];
|
||||
}
|
||||
else
|
||||
input_product = Dimension::dynamic();
|
||||
|
||||
if (minus_one_idx != -1) // resolving -1 masked dimension
|
||||
{
|
||||
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
|
||||
{
|
||||
// TODO: Decide if this is desired behavior here. (NumPy seems
|
||||
// to fail.)
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
input_product.get_min_length() == 0 &&
|
||||
input_product.get_max_length() == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
output_shape[minus_one_idx] = Dimension(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (input_product.is_static() && output_product.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
reshape_node,
|
||||
input_product.get_length() % output_product.get_length() == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
}
|
||||
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
|
||||
input_product == Dimension())
|
||||
{
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
Dimension::value_type lower;
|
||||
if (input_product.get_min_length() == 0)
|
||||
lower = 0;
|
||||
else if (input_product.get_min_length() == -1 ||
|
||||
output_product.get_max_length() == 0 ||
|
||||
output_product.get_max_length() == -1)
|
||||
lower = -1; // dynamic
|
||||
else
|
||||
lower = static_cast<Dimension::value_type>(
|
||||
ceil(static_cast<double>(input_product.get_min_length()) /
|
||||
output_product.get_max_length()));
|
||||
|
||||
Dimension::value_type upper;
|
||||
if (input_product.get_max_length() == 0)
|
||||
upper = 0;
|
||||
else if (input_product.get_max_length() == -1 ||
|
||||
output_product.get_min_length() == 0 ||
|
||||
output_product.get_min_length() == -1)
|
||||
upper = -1; // dynamic
|
||||
else
|
||||
upper = static_cast<Dimension::value_type>(
|
||||
floor(static_cast<double>(input_product.get_max_length()) /
|
||||
output_product.get_min_length()));
|
||||
|
||||
if (lower == -1)
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else if (upper == -1)
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
else if (lower > upper) // empty intersection
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
}
|
||||
}
|
||||
}
|
||||
PartialShape output_pshape(output_shape);
|
||||
if (input_pshape.is_static() && output_pshape.is_static())
|
||||
{
|
||||
size_t zero_dims =
|
||||
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
||||
});
|
||||
|
||||
bool backward_compatible_check =
|
||||
(zero_dims && reshape_node->get_special_zero()) || minus_one_idx != -1;
|
||||
bool in_out_elements_equal = shape_size(reshape_node->get_input_shape(0)) ==
|
||||
shape_size(output_pshape.to_shape());
|
||||
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
backward_compatible_check || in_out_elements_equal,
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
reshape_node->get_input_shape(0));
|
||||
}
|
||||
}
|
||||
} // namespace reshapeop
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Reshape, "Reshape", 1);
|
||||
@ -204,7 +58,6 @@ bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
|
||||
visitor.on_attribute("special_zero", m_special_zero);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v1::Reshape::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_Reshape_validate_and_infer_types);
|
||||
@ -217,16 +70,21 @@ void op::v1::Reshape::validate_and_infer_types()
|
||||
const PartialShape& input_pshape = get_input_partial_shape(0);
|
||||
const PartialShape& shape_pattern_shape = get_input_partial_shape(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_pattern_shape.rank().compatible(1),
|
||||
"Pattern shape must have rank 1, got ",
|
||||
shape_pattern_shape.rank().compatible(1) ||
|
||||
(shape_pattern_shape.rank().is_static() &&
|
||||
shape_pattern_shape.rank().get_length() == 0),
|
||||
"Pattern shape must have rank 1 or be empty, got ",
|
||||
shape_pattern_shape.rank(),
|
||||
".");
|
||||
Rank output_rank =
|
||||
shape_pattern_shape.rank().is_dynamic() ? Rank::dynamic() : shape_pattern_shape[0];
|
||||
shape_pattern_shape.rank().is_dynamic()
|
||||
? Rank::dynamic()
|
||||
: shape_pattern_shape.rank().get_length() == 0 ? 0 : shape_pattern_shape[0];
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(output_rank));
|
||||
set_input_is_relevant_to_shape(1);
|
||||
|
||||
std::vector<Dimension> reshape_pattern;
|
||||
bool shape_can_be_calculated = false;
|
||||
int64_t minus_one_idx = -1;
|
||||
|
||||
HostTensorPtr lb, ub;
|
||||
@ -235,6 +93,7 @@ void op::v1::Reshape::validate_and_infer_types()
|
||||
{
|
||||
const auto lower_bound = std::make_shared<op::Constant>(lb)->cast_vector<int64_t>();
|
||||
const auto upper_bound = std::make_shared<op::Constant>(ub)->cast_vector<int64_t>();
|
||||
shape_can_be_calculated = true;
|
||||
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
|
||||
for (size_t i = 0; i < lower_bound.size(); ++i)
|
||||
{
|
||||
@ -250,13 +109,22 @@ void op::v1::Reshape::validate_and_infer_types()
|
||||
}
|
||||
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
|
||||
}
|
||||
// For scalar case reshape_patter should be empty but scalar reshape pattern should be empty
|
||||
// or equal to 1
|
||||
if (output_rank.is_static() && output_rank.get_length() == 0 && !lower_bound.empty())
|
||||
{
|
||||
reshape_pattern.clear();
|
||||
NGRAPH_CHECK(lower_bound.size() == 1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
lower_bound[0] == 1 && upper_bound[0] == 1,
|
||||
"The value of scalar shape pattern should be equal to 1!");
|
||||
}
|
||||
}
|
||||
|
||||
if (!reshape_pattern.empty())
|
||||
if (shape_can_be_calculated)
|
||||
{
|
||||
std::vector<Dimension> output_shape(output_rank.get_length());
|
||||
reshapeop::calculate_output_shape(
|
||||
this, reshape_pattern, minus_one_idx, input_pshape, output_shape);
|
||||
calculate_output_shape(reshape_pattern, minus_one_idx, input_pshape, output_shape);
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
||||
}
|
||||
@ -311,8 +179,8 @@ bool op::v1::Reshape::evaluate_reshape(const HostTensorVector& outputs,
|
||||
}
|
||||
|
||||
std::vector<Dimension> output_shape(out_shape_val.size());
|
||||
reshapeop::calculate_output_shape(
|
||||
this, reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
|
||||
calculate_output_shape(
|
||||
reshape_pattern, minus_one_idx, inputs[0]->get_partial_shape(), output_shape);
|
||||
NGRAPH_CHECK(PartialShape(output_shape).is_static());
|
||||
outputs[0]->set_shape(PartialShape(output_shape).to_shape());
|
||||
|
||||
@ -390,3 +258,140 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
|
||||
const int64_t& minus_one_idx,
|
||||
const PartialShape& input_pshape,
|
||||
vector<Dimension>& output_shape) const
|
||||
{
|
||||
Dimension output_product(1);
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i)
|
||||
{
|
||||
if (i == minus_one_idx) // resolving everything except -1
|
||||
continue;
|
||||
|
||||
auto pattern_dim = reshape_pattern[i];
|
||||
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 &&
|
||||
get_special_zero())
|
||||
{
|
||||
if (input_pshape.rank().is_dynamic())
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
output_product *= Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
|
||||
output_shape[i] = input_pshape[i];
|
||||
// we do not include dimension to output product here and won't include in input
|
||||
// product later because we will divide output_product by input_product. This
|
||||
// dimension contributes to both products equally, but in case this dimension
|
||||
// is dynamic and others are not we could fully define output dimension that
|
||||
// is masked by -1
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
output_shape[i] = pattern_dim;
|
||||
output_product *= pattern_dim;
|
||||
}
|
||||
}
|
||||
Dimension input_product(1);
|
||||
if (input_pshape.rank().is_static())
|
||||
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i)
|
||||
{
|
||||
if (i < static_cast<int64_t>(reshape_pattern.size()) &&
|
||||
reshape_pattern[i].get_min_length() == 0 &&
|
||||
reshape_pattern[i].get_max_length() == 0)
|
||||
continue;
|
||||
input_product *= input_pshape[i];
|
||||
}
|
||||
else
|
||||
input_product = Dimension::dynamic();
|
||||
|
||||
if (minus_one_idx != -1) // resolving -1 masked dimension
|
||||
{
|
||||
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0)
|
||||
{
|
||||
// TODO: Decide if this is desired behavior here. (NumPy seems
|
||||
// to fail.)
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_product.get_min_length() == 0 &&
|
||||
input_product.get_max_length() == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
output_shape[minus_one_idx] = Dimension(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (input_product.is_static() && output_product.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_product.get_length() % output_product.get_length() == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
}
|
||||
if (output_product.get_min_length() == 0 || output_product == Dimension() ||
|
||||
input_product == Dimension())
|
||||
{
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
}
|
||||
else
|
||||
{
|
||||
Dimension::value_type lower;
|
||||
if (input_product.get_min_length() == 0)
|
||||
lower = 0;
|
||||
else if (input_product.get_min_length() == -1 ||
|
||||
output_product.get_max_length() == 0 ||
|
||||
output_product.get_max_length() == -1)
|
||||
lower = -1; // dynamic
|
||||
else
|
||||
lower = static_cast<Dimension::value_type>(
|
||||
ceil(static_cast<double>(input_product.get_min_length()) /
|
||||
output_product.get_max_length()));
|
||||
|
||||
Dimension::value_type upper;
|
||||
if (input_product.get_max_length() == 0)
|
||||
upper = 0;
|
||||
else if (input_product.get_max_length() == -1 ||
|
||||
output_product.get_min_length() == 0 ||
|
||||
output_product.get_min_length() == -1)
|
||||
upper = -1; // dynamic
|
||||
else
|
||||
upper = static_cast<Dimension::value_type>(
|
||||
floor(static_cast<double>(input_product.get_max_length()) /
|
||||
output_product.get_min_length()));
|
||||
|
||||
if (lower == -1)
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else if (upper == -1)
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
else if (lower > upper) // empty intersection
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
}
|
||||
}
|
||||
}
|
||||
PartialShape output_pshape(output_shape);
|
||||
if (input_pshape.is_static() && output_pshape.is_static())
|
||||
{
|
||||
size_t zero_dims =
|
||||
std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
||||
});
|
||||
|
||||
bool backward_compatible_check = (zero_dims && get_special_zero()) || minus_one_idx != -1;
|
||||
bool in_out_elements_equal =
|
||||
shape_size(get_input_shape(0)) == shape_size(output_pshape.to_shape());
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
backward_compatible_check || in_out_elements_equal,
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
get_input_shape(0));
|
||||
}
|
||||
}
|
||||
|
@ -331,6 +331,37 @@ NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_3D_to_scalar)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_1d_to_same_shape)
|
||||
{
|
||||
const Shape input_shape{1};
|
||||
auto param = make_shared<op::Parameter>(element::f32, input_shape);
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||
auto function = make_shared<Function>(r, ParameterVector{param});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
vector<float> input_values(shape_size(input_shape), 1.f);
|
||||
test_case.add_input<float>(input_shape, input_values);
|
||||
test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
NGRAPH_TEST(${BACKEND_NAME}, builder_reshape_to_same_shape)
|
||||
{
|
||||
const Shape input_shape{};
|
||||
auto param = make_shared<op::Parameter>(element::f32, input_shape);
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||
auto function = make_shared<Function>(r, ParameterVector{param});
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
vector<float> input_values(shape_size(input_shape), 1.f);
|
||||
test_case.add_input<float>(input_shape, input_values);
|
||||
test_case.add_expected_output<float>(Shape{}, vector<float>{1.f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
#if NGRAPH_INTERPRETER_ENABLE
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, reshape_shufflenet_5d)
|
||||
|
@ -579,3 +579,66 @@ TEST(type_prop, reshape_multiply_intervals_by_interval_zero_included)
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_partial_shape(0), (PartialShape{Dimension(0, 24)}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_zero_shape)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{0, 1});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_zero_shape_dynamic)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_shape(0), (Shape{0}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_zero_shape_incorrect)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||
ASSERT_THROW(
|
||||
make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), false),
|
||||
std::exception);
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_zero)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{2, 1});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {1}, std::vector<int64_t>{0}), true);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_shape(0), (Shape{2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_scalar)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_shape(0), (Shape{}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_scalar_2)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{});
|
||||
auto r = make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{1}), false);
|
||||
ASSERT_EQ(r->get_element_type(), element::f32);
|
||||
ASSERT_EQ(r->get_output_shape(0), (Shape{}));
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_to_scalar_3)
|
||||
{
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
ASSERT_THROW(
|
||||
make_shared<op::v1::Reshape>(
|
||||
param, op::Constant::create(element::i64, {}, std::vector<int64_t>{100}), false),
|
||||
std::exception);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user