Calculate TensorView strides on demand not in ctor (#17729)

* Optimize strides calculation using one loop

* Calculate strides in get_strides or set_shape
instead of in the TensorView ctor

* Call update_strides only once on get
This commit is contained in:
Pawel Raasz 2023-05-29 06:20:27 +02:00 committed by GitHub
parent 334114844d
commit 57e23ffc0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,10 +25,11 @@ public:
: m_element_type{element_type},
m_shape{shape},
m_capacity{shape},
m_strides{},
m_strides_once{},
m_ptr{ptr} {
OPENVINO_ASSERT(m_ptr != nullptr);
OPENVINO_ASSERT(m_element_type != element::undefined && m_element_type.is_static());
update_strides();
}
void* data(const element::Type& element_type) const override {
@ -53,6 +54,7 @@ public:
/// Replace the tensor's shape with `new_shape`.
///
/// The new shape must not describe more elements than the capacity shape
/// (`m_capacity`); otherwise the assertion below fails. The shape is taken
/// by value so it can be moved into the member.
void set_shape(ov::Shape new_shape) override {
    OPENVINO_ASSERT(shape_size(new_shape) <= ov::shape_size(m_capacity),
                    "Could not set new shape: ",
                    new_shape);
    m_shape = std::move(new_shape);
    // Discard the strides held for the previous shape before asking for them
    // to be recomputed for the new one.
    m_strides.clear();
    update_strides();
}
@ -60,27 +62,32 @@ public:
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
m_element_type);
std::call_once(m_strides_once, &ViewTensor::update_strides, this);
return m_strides;
}
protected:
void update_strides() {
void update_strides() const {
if (m_element_type.bitwidth() < 8)
return;
auto& shape = get_shape();
m_strides.clear();
if (!shape.empty()) {
if (m_strides.empty() && !shape.empty()) {
m_strides.resize(shape.size());
m_strides.back() = m_element_type.size();
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
std::transform(shape.crbegin(),
shape.crend() - 1,
m_strides.rbegin(),
m_strides.rbegin() + 1,
std::multiplies<size_t>());
}
}
element::Type m_element_type;
Shape m_shape;
Shape m_capacity;
Strides m_strides;
mutable Strides m_strides;
mutable std::once_flag m_strides_once;
void* m_ptr;
};
@ -96,7 +103,7 @@ public:
"Could not create strided access tensor for types with bitwidths less then 8 bit. Tensor type: ",
get_element_type());
// Save default strides
auto shape_strides = m_strides;
auto shape_strides = get_strides();
// Change strides
m_strides = strides;
OPENVINO_ASSERT(m_shape.size() == m_strides.size());
@ -183,6 +190,7 @@ public:
m_allocator.deallocate(m_ptr, old_byte_size);
m_ptr = m_allocator.allocate(get_byte_size());
}
m_strides.clear();
update_strides();
}