Calculate TensorView strides on demand not in ctor (#17729)
* Optimize strides calculation using one loop * Calculate strides on get_strides or set_shape instead in ctor in TensorView * Call once update strides on get
This commit is contained in:
parent
334114844d
commit
57e23ffc0a
@ -25,10 +25,11 @@ public:
|
||||
: m_element_type{element_type},
|
||||
m_shape{shape},
|
||||
m_capacity{shape},
|
||||
m_strides{},
|
||||
m_strides_once{},
|
||||
m_ptr{ptr} {
|
||||
OPENVINO_ASSERT(m_ptr != nullptr);
|
||||
OPENVINO_ASSERT(m_element_type != element::undefined && m_element_type.is_static());
|
||||
update_strides();
|
||||
}
|
||||
|
||||
void* data(const element::Type& element_type) const override {
|
||||
@ -53,6 +54,7 @@ public:
|
||||
void set_shape(ov::Shape new_shape) override {
|
||||
OPENVINO_ASSERT(shape_size(new_shape) <= ov::shape_size(m_capacity), "Could set new shape: ", new_shape);
|
||||
m_shape = std::move(new_shape);
|
||||
m_strides.clear();
|
||||
update_strides();
|
||||
}
|
||||
|
||||
@ -60,27 +62,32 @@ public:
|
||||
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
|
||||
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
|
||||
m_element_type);
|
||||
std::call_once(m_strides_once, &ViewTensor::update_strides, this);
|
||||
return m_strides;
|
||||
}
|
||||
|
||||
protected:
|
||||
void update_strides() {
|
||||
void update_strides() const {
|
||||
if (m_element_type.bitwidth() < 8)
|
||||
return;
|
||||
|
||||
auto& shape = get_shape();
|
||||
m_strides.clear();
|
||||
if (!shape.empty()) {
|
||||
if (m_strides.empty() && !shape.empty()) {
|
||||
m_strides.resize(shape.size());
|
||||
m_strides.back() = m_element_type.size();
|
||||
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
|
||||
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
|
||||
std::transform(shape.crbegin(),
|
||||
shape.crend() - 1,
|
||||
m_strides.rbegin(),
|
||||
m_strides.rbegin() + 1,
|
||||
std::multiplies<size_t>());
|
||||
}
|
||||
}
|
||||
|
||||
element::Type m_element_type;
|
||||
Shape m_shape;
|
||||
Shape m_capacity;
|
||||
Strides m_strides;
|
||||
mutable Strides m_strides;
|
||||
mutable std::once_flag m_strides_once;
|
||||
void* m_ptr;
|
||||
};
|
||||
|
||||
@ -96,7 +103,7 @@ public:
|
||||
"Could not create strided access tensor for types with bitwidths less then 8 bit. Tensor type: ",
|
||||
get_element_type());
|
||||
// Save default strides
|
||||
auto shape_strides = m_strides;
|
||||
auto shape_strides = get_strides();
|
||||
// Change strides
|
||||
m_strides = strides;
|
||||
OPENVINO_ASSERT(m_shape.size() == m_strides.size());
|
||||
@ -183,6 +190,7 @@ public:
|
||||
m_allocator.deallocate(m_ptr, old_byte_size);
|
||||
m_ptr = m_allocator.allocate(get_byte_size());
|
||||
}
|
||||
m_strides.clear();
|
||||
update_strides();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user