Throw an exception if somebody tries to reallocate tensor

This commit is contained in:
Ilya Churaev 2023-10-06 07:45:39 +04:00
parent 7df8041367
commit 8e06d6d576
3 changed files with 14 additions and 14 deletions

View File

@ -289,7 +289,7 @@ def test_evaluate_invalid_input_shape():
[Tensor("float32", Shape([2, 1]))],
[Tensor("float32", Shape([3, 1])), Tensor("float32", Shape([3, 1]))],
)
assert "must be compatible with the partial shape: [2,1]" in str(e.value)
assert "Could set new shape: [1,3]" in str(e.value)
def test_get_batch():

View File

@ -15,6 +15,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/validation_util.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/model.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/core/type/element_type.hpp"
@ -2664,3 +2665,15 @@ TEST(eval, evaluate_cum_sum_v0_exclusive_reversed) {
EXPECT_EQ(outputs[0].get_shape(), data->get_shape());
EXPECT_EQ(memcmp(outputs[0].data(), out_expected, sizeof(out_expected)), 0);
}
TEST(eval, invalid_shape) {
auto p1 = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2});
auto p2 = make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{1, 2});
auto add = make_shared<op::v1::Add>(p1, p2);
auto model = make_shared<Model>(OutputVector{add}, ParameterVector{p1, p2});
auto result_tensor = ov::Tensor(element::f32, {1, 2});
auto out_vector = ov::TensorVector{result_tensor};
auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f32>({1, 3}, {1.0f, 1.0f, 1.0f}),
make_tensor<element::Type_t::f32>({1, 3}, {7.0f, 6.0f, 1.0f})};
ASSERT_THROW(model->evaluate(out_vector, in_vector), ov::Exception);
}

View File

@ -186,19 +186,6 @@ public:
m_allocator.deallocate(m_ptr, get_byte_size());
}
void set_shape(ov::Shape new_shape) override {
if (m_shape == new_shape)
return;
auto old_byte_size = get_byte_size();
m_shape = std::move(new_shape);
if (get_byte_size() > old_byte_size) {
m_allocator.deallocate(m_ptr, old_byte_size);
m_ptr = m_allocator.allocate(get_byte_size());
}
m_strides.clear();
update_strides();
}
private:
Allocator m_allocator;
};