Fixed access to the data of FP16 IRs with nGraph Python API (#1707)

This commit is contained in:
Jan Iwaszkiewicz
2020-08-11 06:16:11 +02:00
committed by GitHub
parent 51b564e9d8
commit 2b474c8a47

View File

@@ -46,11 +46,24 @@ py::buffer_info _get_buffer_info(const ngraph::op::Constant& c)
return py::buffer_info(
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
py::format_descriptor<T>::format(), /* Python struct-style format
descriptor */
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
_get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
py::format_descriptor<T>::format(), /* Python struct-style format descriptor */
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
_get_byte_strides<T>(shape) /* Strides (in bytes) for each index */
);
}
template <>
py::buffer_info _get_buffer_info<ngraph::float16>(const ngraph::op::Constant& c)
{
ngraph::Shape shape = c.get_shape();
return py::buffer_info(
const_cast<void*>(c.get_data_ptr()), /* Pointer to buffer */
static_cast<ssize_t>(c.get_element_type().size()), /* Size of one scalar */
std::string(1, 'H'), /* Python struct-style format descriptor */
static_cast<ssize_t>(shape.size()), /* Number of dimensions */
std::vector<ssize_t>{shape.begin(), shape.end()}, /* Buffer dimensions */
_get_byte_strides<ngraph::float16>(shape) /* Strides (in bytes) for each index */
);
}
@@ -61,6 +74,9 @@ void regclass_pyngraph_op_Constant(py::module m)
constant.doc() = "ngraph.impl.op.Constant wraps ngraph::op::Constant";
constant.def(
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<char>&>());
constant.def(py::init<const ngraph::element::Type&,
const ngraph::Shape&,
const std::vector<ngraph::float16>&>());
constant.def(
py::init<const ngraph::element::Type&, const ngraph::Shape&, const std::vector<float>&>());
constant.def(
@@ -97,6 +113,10 @@ void regclass_pyngraph_op_Constant(py::module m)
{
return _get_buffer_info<char>(self);
}
else if (element_type == ngraph::element::f16)
{
return _get_buffer_info<ngraph::float16>(self);
}
else if (element_type == ngraph::element::f32)
{
return _get_buffer_info<float>(self);
@@ -139,7 +159,7 @@ void regclass_pyngraph_op_Constant(py::module m)
}
else
{
throw std::runtime_error("Unsupproted data type!");
throw std::runtime_error("Unsupported data type!");
}
});
}