Fixed access to the data of FP16 IRs with nGraph Python API (#1707)
This commit is contained in:
@@ -46,11 +46,24 @@ py::buffer_info _get_buffer_info(const ngraph::op::Constant& c)
|
||||
return py::buffer_info(
|
||||
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
|
||||
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
|
||||
py::format_descriptor<T>::format(), /* Python struct-style format
|
||||
descriptor */
|
||||
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
|
||||
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
|
||||
_get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
|
||||
py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
|
||||
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
|
||||
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
|
||||
_get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
|
||||
);
|
||||
}
|
||||
|
||||
template <>
|
||||
py::buffer_info _get_buffer_info<ngraph::float16>(const ngraph::op::Constant& c)
|
||||
{
|
||||
ngraph::Shape shape = c.get_shape();
|
||||
return py::buffer_info(
|
||||
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
|
||||
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
|
||||
std::string(1, 'H'), /* Python struct-style format descriptor */
|
||||
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
|
||||
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
|
||||
_get_byte_strides<ngraph::float16>(shape) /* Strides (in bytes) for each index */
|
||||
);
|
||||
}
|
||||
|
||||
@@ -61,6 +74,9 @@ void regclass_pyngraph_op_Constant(py::module m)
|
||||
constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
|
||||
constant.def(
|
||||
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
|
||||
constant.def(py::init<const ngraph::element::Type&,
|
||||
const ngraph::Shape&,
|
||||
const std::vector<ngraph::float16>&>());
|
||||
constant.def(
|
||||
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<float>&>());
|
||||
constant.def(
|
||||
@@ -97,6 +113,10 @@ void regclass_pyngraph_op_Constant(py::module m)
|
||||
{
|
||||
return _get_buffer_info<char>(self);
|
||||
}
|
||||
else if (element_type == ngraph::element::f16)
|
||||
{
|
||||
return _get_buffer_info<ngraph::float16>(self);
|
||||
}
|
||||
else if (element_type == ngraph::element::f32)
|
||||
{
|
||||
return _get_buffer_info<float>(self);
|
||||
@@ -139,7 +159,7 @@ void regclass_pyngraph_op_Constant(py::module m)
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupproted data type!");
|
||||
throw std::runtime_error("Unsupported data type!");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user