Remove code duplicates for shape inference utils (#14721)

* Remove code duplicates for shape inference utils

* Fix typos and comments
This commit is contained in:
Pawel Raasz 2023-01-04 10:56:35 +01:00 committed by GitHub
parent a0be13be57
commit 000a634429
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -98,9 +98,9 @@ template <class TShape,
class TData,
class TRes = std::vector<TData>,
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
size_t idx,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
if (constant_data.count(idx)) {
return std::unique_ptr<TRes>(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector<TData>()));
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
@ -113,86 +113,76 @@ std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
} // namespace op
} // namespace ov
template <class T>
inline bool get_data_as_int64(
size_t idx,
const ov::Node* op,
std::vector<int64_t>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
axes_value = constant->cast_vector<int64_t>();
}
return true;
}
template <>
inline bool get_data_as_int64<ov::PartialShape>(
size_t idx,
const ov::Node* op,
std::vector<int64_t>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
axes_value = constant->cast_vector<int64_t>();
// Helper to reduce duplicates of code for get_data_as_... specific type functions.
template <class TShape, class TData>
inline bool get_data_as(const ov::Node* op,
size_t idx,
std::vector<TData>& data_out,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
if (auto out = ov::op::get_input_const_data_as<TShape, TData>(op, idx, constant_data)) {
data_out = std::move(*out);
return true;
} else {
return false;
}
}
template <class TShape>
inline bool get_data_as_int64(size_t idx,
const ov::Node* op,
std::vector<int64_t>& axes_value,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
return get_data_as<TShape>(op, idx, axes_value, constant_data);
}
template <class TShape>
inline bool get_data_as_float(size_t idx,
const ov::Node* op,
std::vector<float>& axes_value,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
return get_data_as<TShape>(op, idx, axes_value, constant_data);
}
/**
* \brief Get the operator's constant data as shape of type T.
*
* \note The constant data are get as size_t (Dimension value type for static shape). If pointed input is signed the
* output shape dimension can be wrongly interpreted.
*
* \tparam TShape Shape type.
*
* \param idx Operator's input index.
* \param op Pointer to operator.
* \param shape Output shape made from constant data.
* \param constant_data Map with constant tensors. Optional default empty.
*
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
*/
template <class TShape>
inline bool get_data_as_shape(size_t idx,
const ov::Node* op,
TShape& shape,
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
// Note, assumes that get_input_const_data_as throws exception for TShape different then ov::PartialShape.
shape = *ov::op::get_input_const_data_as<TShape, size_t, TShape>(op, idx, constant_data);
return true;
}
template <class T>
inline bool get_data_as_float(
size_t idx,
const ov::Node* op,
std::vector<float>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
axes_value = constant->cast_vector<float>();
}
return true;
}
template <>
inline bool get_data_as_float<ov::PartialShape>(
size_t idx,
const ov::Node* op,
std::vector<float>& axes_value,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
if (constant_data.count(idx)) {
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
axes_value = constant->cast_vector<float>();
} else {
return false;
}
return true;
}
template <class T>
inline bool get_data_as_shape(
size_t idx,
const ov::Node* op,
T& shape,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
if (constant_data.count(idx)) {
shape = T(ov::opset1::Constant(constant_data.at(idx)).cast_vector<size_t>());
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
shape = T(constant->cast_vector<size_t>());
}
return true;
}
/**
* \brief Get the operator's constant data as ov::PartialShape.
*
* If data not get as constant then try evaluate this input as partial shape from input's bounds and labels.
*
* \note The constant data are get as int64_t. If pointed input is unsigned then output shape
* dimension can be wrongly interpreted.
*
* \param idx Operator's input index.
* \param op Pointer to operator.
* \param shape Output shape made from constant data.
* \param constant_data Map with constant tensors. Optional default empty.
*
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
*/
template <>
inline bool get_data_as_shape<ov::PartialShape>(
size_t idx,