Fix reshape evaluate to use tensor input shape (#14307)
This commit is contained in:
parent
541045b0ba
commit
78f95ddea4
@ -8,6 +8,7 @@
|
||||
#include <dimension_tracker.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "compare.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/runtime/opt_kernel/reshape.hpp"
|
||||
@ -417,18 +418,16 @@ void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
|
||||
|
||||
ov::PartialShape output_pshape(output_shape);
|
||||
if (input_pshape.is_static() && output_pshape.is_static()) {
|
||||
size_t zero_dims = std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||
return dim.get_max_length() == 0 && dim.get_min_length() == 0;
|
||||
});
|
||||
size_t zero_dims = std::count_if(reshape_pattern.begin(), reshape_pattern.end(), cmp::Equal<Dimension>(0));
|
||||
|
||||
bool backward_compatible_check = (zero_dims && get_special_zero()) || minus_one_idx != -1;
|
||||
bool in_out_elements_equal = shape_size(get_input_shape(0)) == shape_size(output_pshape.to_shape());
|
||||
bool in_out_elements_equal = shape_size(input_pshape.get_shape()) == shape_size(output_pshape.to_shape());
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
backward_compatible_check || in_out_elements_equal,
|
||||
"Requested output shape ",
|
||||
output_shape,
|
||||
" is incompatible with input shape ",
|
||||
get_input_shape(0));
|
||||
input_pshape);
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "engines_util/execute_tools.hpp"
|
||||
#include "engines_util/test_case.hpp"
|
||||
#include "gmock/gmock.h"
|
||||
@ -70,6 +71,7 @@ NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
#define ASSERT_FLOAT_VECTORS_EQ(expected, result) \
|
||||
ASSERT_EQ(expected.size(), result.size()) << "Array sizes differ."; \
|
||||
@ -633,6 +635,39 @@ TEST(eval, evaluate_reshape_v1_pattern_int16) {
|
||||
ASSERT_EQ(computed_val, expected_val);
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_reshape_v1_data_dynamic_shape) {
|
||||
constexpr auto exp_dtype = element::i32;
|
||||
|
||||
auto data = make_shared<op::Parameter>(exp_dtype, PartialShape::dynamic());
|
||||
auto pattern = make_shared<op::Parameter>(element::i64, Shape{6});
|
||||
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
|
||||
auto f = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
|
||||
ASSERT_TRUE(f->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::i32>(Shape{2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}),
|
||||
make_host_tensor<element::Type_t::i64>(pattern->get_shape(), {2, 0, 1, -1, 1, 1})}));
|
||||
|
||||
EXPECT_EQ(result_tensor->get_element_type(), exp_dtype);
|
||||
EXPECT_EQ(result_tensor->get_partial_shape(), PartialShape({2, 2, 1, 2, 1, 1}));
|
||||
EXPECT_THAT(read_vector<int32_t>(result_tensor), ElementsAre(0, 1, 2, 3, 4, 5, 6, 7));
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_reshape_v1_not_backward_compatible_and_in_out_size_not_eq) {
|
||||
constexpr auto exp_dtype = element::i32;
|
||||
auto data = make_shared<op::Parameter>(exp_dtype, PartialShape::dynamic());
|
||||
auto pattern = make_shared<op::Parameter>(element::i16, Shape{5});
|
||||
auto dyn_reshape = make_shared<op::v1::Reshape>(data, pattern, true);
|
||||
auto f = make_shared<Function>(OutputVector{dyn_reshape}, ParameterVector{data, pattern});
|
||||
auto result_tensor = make_shared<HostTensor>();
|
||||
|
||||
OV_EXPECT_THROW(f->evaluate({result_tensor},
|
||||
{make_host_tensor<element::Type_t::i32>(Shape{2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}),
|
||||
make_host_tensor<element::Type_t::i16>(pattern->get_shape(), {2, 1, 1, 1, 1})}),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Requested output shape [2,1,1,1,1] is incompatible with input shape [2,2,2]"));
|
||||
}
|
||||
|
||||
TEST(eval, evaluate_convert) {
|
||||
auto p = make_shared<op::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto convert = make_shared<op::v0::Convert>(p, element::i64);
|
||||
|
Loading…
Reference in New Issue
Block a user