Calculate TensorView strides on demand not in ctor (#17729)
* Optimize strides calculation using one loop * Calculate strides on get_strides or set_shape instead in ctor in TensorView * Call once update strides on get
This commit is contained in:
parent
334114844d
commit
57e23ffc0a
@ -25,10 +25,11 @@ public:
|
|||||||
: m_element_type{element_type},
|
: m_element_type{element_type},
|
||||||
m_shape{shape},
|
m_shape{shape},
|
||||||
m_capacity{shape},
|
m_capacity{shape},
|
||||||
|
m_strides{},
|
||||||
|
m_strides_once{},
|
||||||
m_ptr{ptr} {
|
m_ptr{ptr} {
|
||||||
OPENVINO_ASSERT(m_ptr != nullptr);
|
OPENVINO_ASSERT(m_ptr != nullptr);
|
||||||
OPENVINO_ASSERT(m_element_type != element::undefined && m_element_type.is_static());
|
OPENVINO_ASSERT(m_element_type != element::undefined && m_element_type.is_static());
|
||||||
update_strides();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void* data(const element::Type& element_type) const override {
|
void* data(const element::Type& element_type) const override {
|
||||||
@ -53,6 +54,7 @@ public:
|
|||||||
void set_shape(ov::Shape new_shape) override {
|
void set_shape(ov::Shape new_shape) override {
|
||||||
OPENVINO_ASSERT(shape_size(new_shape) <= ov::shape_size(m_capacity), "Could set new shape: ", new_shape);
|
OPENVINO_ASSERT(shape_size(new_shape) <= ov::shape_size(m_capacity), "Could set new shape: ", new_shape);
|
||||||
m_shape = std::move(new_shape);
|
m_shape = std::move(new_shape);
|
||||||
|
m_strides.clear();
|
||||||
update_strides();
|
update_strides();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,27 +62,32 @@ public:
|
|||||||
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
|
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
|
||||||
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
|
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
|
||||||
m_element_type);
|
m_element_type);
|
||||||
|
std::call_once(m_strides_once, &ViewTensor::update_strides, this);
|
||||||
return m_strides;
|
return m_strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void update_strides() {
|
void update_strides() const {
|
||||||
if (m_element_type.bitwidth() < 8)
|
if (m_element_type.bitwidth() < 8)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto& shape = get_shape();
|
auto& shape = get_shape();
|
||||||
m_strides.clear();
|
if (m_strides.empty() && !shape.empty()) {
|
||||||
if (!shape.empty()) {
|
|
||||||
m_strides.resize(shape.size());
|
m_strides.resize(shape.size());
|
||||||
m_strides.back() = m_element_type.size();
|
m_strides.back() = m_element_type.size();
|
||||||
std::copy(shape.rbegin(), shape.rend() - 1, m_strides.rbegin() + 1);
|
std::transform(shape.crbegin(),
|
||||||
std::partial_sum(m_strides.rbegin(), m_strides.rend(), m_strides.rbegin(), std::multiplies<size_t>());
|
shape.crend() - 1,
|
||||||
|
m_strides.rbegin(),
|
||||||
|
m_strides.rbegin() + 1,
|
||||||
|
std::multiplies<size_t>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
element::Type m_element_type;
|
element::Type m_element_type;
|
||||||
Shape m_shape;
|
Shape m_shape;
|
||||||
Shape m_capacity;
|
Shape m_capacity;
|
||||||
Strides m_strides;
|
mutable Strides m_strides;
|
||||||
|
mutable std::once_flag m_strides_once;
|
||||||
void* m_ptr;
|
void* m_ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -96,7 +103,7 @@ public:
|
|||||||
"Could not create strided access tensor for types with bitwidths less then 8 bit. Tensor type: ",
|
"Could not create strided access tensor for types with bitwidths less then 8 bit. Tensor type: ",
|
||||||
get_element_type());
|
get_element_type());
|
||||||
// Save default strides
|
// Save default strides
|
||||||
auto shape_strides = m_strides;
|
auto shape_strides = get_strides();
|
||||||
// Change strides
|
// Change strides
|
||||||
m_strides = strides;
|
m_strides = strides;
|
||||||
OPENVINO_ASSERT(m_shape.size() == m_strides.size());
|
OPENVINO_ASSERT(m_shape.size() == m_strides.size());
|
||||||
@ -183,6 +190,7 @@ public:
|
|||||||
m_allocator.deallocate(m_ptr, old_byte_size);
|
m_allocator.deallocate(m_ptr, old_byte_size);
|
||||||
m_ptr = m_allocator.allocate(get_byte_size());
|
m_ptr = m_allocator.allocate(get_byte_size());
|
||||||
}
|
}
|
||||||
|
m_strides.clear();
|
||||||
update_strides();
|
update_strides();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user