Remove code duplicates for shape inference utils (#14721)
* Remove code duplicates for shape inference utils * Fix typos and comments
This commit is contained in:
parent
a0be13be57
commit
000a634429
@ -98,7 +98,7 @@ template <class TShape,
|
||||
class TData,
|
||||
class TRes = std::vector<TData>,
|
||||
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
|
||||
std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
|
||||
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
|
||||
size_t idx,
|
||||
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
|
||||
if (constant_data.count(idx)) {
|
||||
@ -113,86 +113,76 @@ std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
|
||||
template <class T>
|
||||
inline bool get_data_as_int64(
|
||||
// Helper to reduce duplicates of code for get_data_as_... specific type functions.
|
||||
template <class TShape, class TData>
|
||||
inline bool get_data_as(const ov::Node* op,
|
||||
size_t idx,
|
||||
const ov::Node* op,
|
||||
std::vector<int64_t>& axes_value,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
if (constant_data.count(idx)) {
|
||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
|
||||
} else {
|
||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
||||
axes_value = constant->cast_vector<int64_t>();
|
||||
}
|
||||
std::vector<TData>& data_out,
|
||||
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||
if (auto out = ov::op::get_input_const_data_as<TShape, TData>(op, idx, constant_data)) {
|
||||
data_out = std::move(*out);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool get_data_as_int64<ov::PartialShape>(
|
||||
size_t idx,
|
||||
const ov::Node* op,
|
||||
std::vector<int64_t>& axes_value,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
|
||||
if (constant_data.count(idx)) {
|
||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
|
||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
||||
axes_value = constant->cast_vector<int64_t>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline bool get_data_as_float(
|
||||
size_t idx,
|
||||
template <class TShape>
|
||||
inline bool get_data_as_int64(size_t idx,
|
||||
const ov::Node* op,
|
||||
std::vector<int64_t>& axes_value,
|
||||
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||
return get_data_as<TShape>(op, idx, axes_value, constant_data);
|
||||
}
|
||||
|
||||
template <class TShape>
|
||||
inline bool get_data_as_float(size_t idx,
|
||||
const ov::Node* op,
|
||||
std::vector<float>& axes_value,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
if (constant_data.count(idx)) {
|
||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
|
||||
} else {
|
||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
||||
axes_value = constant->cast_vector<float>();
|
||||
}
|
||||
return true;
|
||||
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||
return get_data_as<TShape>(op, idx, axes_value, constant_data);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool get_data_as_float<ov::PartialShape>(
|
||||
size_t idx,
|
||||
/**
|
||||
* \brief Get the operator's constant data as shape of type T.
|
||||
*
|
||||
* \note The constant data are get as size_t (Dimension value type for static shape). If pointed input is signed the
|
||||
* output shape dimension can be wrongly interpreted.
|
||||
*
|
||||
* \tparam TShape Shape type.
|
||||
*
|
||||
* \param idx Operator's input index.
|
||||
* \param op Pointer to operator.
|
||||
* \param shape Output shape made from constant data.
|
||||
* \param constant_data Map with constant tensors. Optional default empty.
|
||||
*
|
||||
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
|
||||
*/
|
||||
template <class TShape>
|
||||
inline bool get_data_as_shape(size_t idx,
|
||||
const ov::Node* op,
|
||||
std::vector<float>& axes_value,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
|
||||
if (constant_data.count(idx)) {
|
||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
|
||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
||||
axes_value = constant->cast_vector<float>();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline bool get_data_as_shape(
|
||||
size_t idx,
|
||||
const ov::Node* op,
|
||||
T& shape,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
if (constant_data.count(idx)) {
|
||||
shape = T(ov::opset1::Constant(constant_data.at(idx)).cast_vector<size_t>());
|
||||
} else {
|
||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
||||
shape = T(constant->cast_vector<size_t>());
|
||||
}
|
||||
TShape& shape,
|
||||
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||
// Note, assumes that get_input_const_data_as throws exception for TShape different then ov::PartialShape.
|
||||
shape = *ov::op::get_input_const_data_as<TShape, size_t, TShape>(op, idx, constant_data);
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get the operator's constant data as ov::PartialShape.
|
||||
*
|
||||
* If data not get as constant then try evaluate this input as partial shape from input's bounds and labels.
|
||||
*
|
||||
* \note The constant data are get as int64_t. If pointed input is unsigned then output shape
|
||||
* dimension can be wrongly interpreted.
|
||||
*
|
||||
* \param idx Operator's input index.
|
||||
* \param op Pointer to operator.
|
||||
* \param shape Output shape made from constant data.
|
||||
* \param constant_data Map with constant tensors. Optional default empty.
|
||||
*
|
||||
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
|
||||
*/
|
||||
template <>
|
||||
inline bool get_data_as_shape<ov::PartialShape>(
|
||||
size_t idx,
|
||||
|
Loading…
Reference in New Issue
Block a user