Added set_element_type for tensor

This commit is contained in:
Ilya Churaev
2023-09-14 14:24:12 +04:00
parent 4ca3d51a40
commit 27608d2ea0
5 changed files with 46 additions and 0 deletions

View File

@@ -23,6 +23,13 @@ public:
*/
virtual void set_shape(ov::Shape shape) = 0;
/**
* @brief Set new element type for tensor
* @note Memory allocation may happen
* @param shape A new shape
*/
virtual void set_element_type(ov::element::Type element_type);
/**
* @return A tensor element type
*/

View File

@@ -153,6 +153,13 @@ public:
*/
void set_shape(const ov::Shape& shape);
/**
* @brief Set new element type for tensor, deallocate/allocate if new total size is bigger than previous one.
* @note Memory allocation may happen
* @param element_type A new element type
*/
void set_element_type(const ov::element::Type& element_type);
/**
* @return A tensor element type
*/

View File

@@ -19,6 +19,10 @@ size_t ITensor::get_size() const {
return shape_size(get_shape());
}
void ITensor::set_element_type(ov::element::Type element_type) {
OPENVINO_NOT_IMPLEMENTED;
}
size_t ITensor::get_byte_size() const {
return (get_size() * get_element_type().bitwidth() + 8 - 1) / 8;
}

View File

@@ -71,6 +71,10 @@ void Tensor::set_shape(const ov::Shape& shape) {
OV_TENSOR_STATEMENT(_impl->set_shape(shape));
}
void Tensor::set_element_type(const ov::element::Type& element_type) {
OV_TENSOR_STATEMENT(_impl->set_element_type(element_type));
}
const Shape& Tensor::get_shape() const {
OV_TENSOR_STATEMENT(return _impl->get_shape());
}

View File

@@ -62,6 +62,16 @@ public:
update_strides();
}
void set_element_type(ov::element::Type new_element_type) override {
OPENVINO_ASSERT(
shape_size(m_shape) * new_element_type.bitwidth() <= ov::shape_size(m_capacity) * m_element_type.bitwidth(),
"Could set new shape: ",
new_element_type);
m_element_type = std::move(new_element_type);
m_strides.clear();
update_strides();
}
const Strides& get_strides() const override {
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
@@ -148,6 +158,12 @@ public:
}
m_shape = std::move(new_shape);
}
void set_element_type(ov::element::Type new_element_type) override {
OPENVINO_ASSERT(m_element_type.bitwidth() == new_element_type.bitwidth(),
"Element type is incompatible for strided tensor.");
ViewTensor::set_element_type(std::move(new_element_type));
}
};
/**
@@ -255,6 +271,10 @@ public:
OPENVINO_THROW("Shapes cannot be changed for ROI Tensor");
}
void set_element_type(ov::element::Type new_element_type) override {
OPENVINO_THROW("Element type cannot be changed for ROI Tensor");
}
void* data(const element::Type& element_type) const override {
auto owner_data = m_owner->data(element_type);
return static_cast<uint8_t*>(owner_data) + m_offset;
@@ -326,6 +346,10 @@ public:
update_strides();
}
void set_element_type(ov::element::Type new_element_type) override {
OPENVINO_THROW("Element type cannot be changed for Blob Tensor");
}
const Shape& get_shape() const override {
m_shape = blob->getTensorDesc().getBlockingDesc().getBlockDims();
return m_shape;