Added set_element_type for tensor
This commit is contained in:
@@ -23,6 +23,13 @@ public:
|
||||
*/
|
||||
virtual void set_shape(ov::Shape shape) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set new element type for tensor
|
||||
* @note Memory allocation may happen
|
||||
* @param shape A new shape
|
||||
*/
|
||||
virtual void set_element_type(ov::element::Type element_type);
|
||||
|
||||
/**
|
||||
* @return A tensor element type
|
||||
*/
|
||||
|
||||
@@ -153,6 +153,13 @@ public:
|
||||
*/
|
||||
void set_shape(const ov::Shape& shape);
|
||||
|
||||
/**
|
||||
* @brief Set new element type for tensor, deallocate/allocate if new total size is bigger than previous one.
|
||||
* @note Memory allocation may happen
|
||||
* @param element_type A new element type
|
||||
*/
|
||||
void set_element_type(const ov::element::Type& element_type);
|
||||
|
||||
/**
|
||||
* @return A tensor element type
|
||||
*/
|
||||
|
||||
@@ -19,6 +19,10 @@ size_t ITensor::get_size() const {
|
||||
return shape_size(get_shape());
|
||||
}
|
||||
|
||||
void ITensor::set_element_type(ov::element::Type element_type) {
|
||||
OPENVINO_NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
size_t ITensor::get_byte_size() const {
|
||||
return (get_size() * get_element_type().bitwidth() + 8 - 1) / 8;
|
||||
}
|
||||
|
||||
@@ -71,6 +71,10 @@ void Tensor::set_shape(const ov::Shape& shape) {
|
||||
OV_TENSOR_STATEMENT(_impl->set_shape(shape));
|
||||
}
|
||||
|
||||
void Tensor::set_element_type(const ov::element::Type& element_type) {
|
||||
OV_TENSOR_STATEMENT(_impl->set_element_type(element_type));
|
||||
}
|
||||
|
||||
const Shape& Tensor::get_shape() const {
|
||||
OV_TENSOR_STATEMENT(return _impl->get_shape());
|
||||
}
|
||||
|
||||
@@ -62,6 +62,16 @@ public:
|
||||
update_strides();
|
||||
}
|
||||
|
||||
void set_element_type(ov::element::Type new_element_type) override {
|
||||
OPENVINO_ASSERT(
|
||||
shape_size(m_shape) * new_element_type.bitwidth() <= ov::shape_size(m_capacity) * m_element_type.bitwidth(),
|
||||
"Could set new shape: ",
|
||||
new_element_type);
|
||||
m_element_type = std::move(new_element_type);
|
||||
m_strides.clear();
|
||||
update_strides();
|
||||
}
|
||||
|
||||
const Strides& get_strides() const override {
|
||||
OPENVINO_ASSERT(m_element_type.bitwidth() >= 8,
|
||||
"Could not get strides for types with bitwidths less then 8 bit. Tensor type: ",
|
||||
@@ -148,6 +158,12 @@ public:
|
||||
}
|
||||
m_shape = std::move(new_shape);
|
||||
}
|
||||
|
||||
void set_element_type(ov::element::Type new_element_type) override {
|
||||
OPENVINO_ASSERT(m_element_type.bitwidth() == new_element_type.bitwidth(),
|
||||
"Element type is incompatible for strided tensor.");
|
||||
ViewTensor::set_element_type(std::move(new_element_type));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -255,6 +271,10 @@ public:
|
||||
OPENVINO_THROW("Shapes cannot be changed for ROI Tensor");
|
||||
}
|
||||
|
||||
void set_element_type(ov::element::Type new_element_type) override {
|
||||
OPENVINO_THROW("Element type cannot be changed for ROI Tensor");
|
||||
}
|
||||
|
||||
void* data(const element::Type& element_type) const override {
|
||||
auto owner_data = m_owner->data(element_type);
|
||||
return static_cast<uint8_t*>(owner_data) + m_offset;
|
||||
@@ -326,6 +346,10 @@ public:
|
||||
update_strides();
|
||||
}
|
||||
|
||||
void set_element_type(ov::element::Type new_element_type) override {
|
||||
OPENVINO_THROW("Element type cannot be changed for Blob Tensor");
|
||||
}
|
||||
|
||||
const Shape& get_shape() const override {
|
||||
m_shape = blob->getTensorDesc().getBlockingDesc().getBlockDims();
|
||||
return m_shape;
|
||||
|
||||
Reference in New Issue
Block a user