Fix stride issue for ZeroDims (#20686)

* Fix stride issue for ZeroDims

* Add test case

* Fix ITensor::is_continuous() issue

* Fix the same issue in gpu plugin and template plugin
This commit is contained in:
River Li 2023-10-27 13:27:53 +08:00 committed by GitHub
parent 14d51de93c
commit be25d9038e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 4 deletions

View File

@ -25,9 +25,10 @@ size_t ITensor::get_byte_size() const {
}
bool ITensor::is_continuous() const {
if (get_element_type().bitwidth() < 8)
if ((get_element_type().bitwidth() < 8) || get_size() == 0) {
// OpenVINO doesn't support strides for lp types
return true;
}
const auto& shape = get_shape();
const auto& type = get_element_type();
std::vector<size_t> strides(shape.size());

View File

@ -52,3 +52,13 @@ TEST(tensor, wrap_tensor_with_unspecified_type_from_host_tensor) {
// !tensor means that the tensor is not initialized
EXPECT_EQ(!tensor, true);
}
TEST(tensor, create_tensor_with_zero_dims_check_stride) {
ov::Shape shape = {0, 0, 0, 0};
auto tensor = ov::Tensor(element::f32, shape);
EXPECT_EQ(!!tensor, true);
auto stride = tensor.get_strides();
EXPECT_EQ(stride.size(), shape.size());
EXPECT_EQ(stride.back(), 0);
EXPECT_EQ(tensor.is_continuous(), true);
}

View File

@ -77,7 +77,7 @@ protected:
auto& shape = get_shape();
if (m_strides.empty() && !shape.empty()) {
m_strides.resize(shape.size());
m_strides.back() = m_element_type.size();
m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size();
std::transform(shape.crbegin(),
shape.crend() - 1,
m_strides.rbegin(),

View File

@ -63,7 +63,7 @@ void RemoteTensorImpl::update_strides() {
m_strides.clear();
if (!shape.empty()) {
m_strides.resize(shape.size());
m_strides.back() = m_element_type.size();
m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size();
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
}

View File

@ -26,7 +26,7 @@ class VectorTensorImpl : public ov::IRemoteTensor {
m_strides.clear();
if (!shape.empty()) {
m_strides.resize(shape.size());
m_strides.back() = m_element_type.size();
m_strides.back() = shape.back() == 0 ? 0 : m_element_type.size();
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
}