Remove code duplicates for shape inference utils (#14721)
* Remove code duplicates for shape inference utils * Fix typos and comments
This commit is contained in:
parent
a0be13be57
commit
000a634429
@ -98,9 +98,9 @@ template <class TShape,
|
|||||||
class TData,
|
class TData,
|
||||||
class TRes = std::vector<TData>,
|
class TRes = std::vector<TData>,
|
||||||
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
|
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
|
||||||
std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
|
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
|
||||||
size_t idx,
|
size_t idx,
|
||||||
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
|
const std::map<size_t, HostTensorPtr>& constant_data = {}) {
|
||||||
if (constant_data.count(idx)) {
|
if (constant_data.count(idx)) {
|
||||||
return std::unique_ptr<TRes>(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector<TData>()));
|
return std::unique_ptr<TRes>(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector<TData>()));
|
||||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
||||||
@ -113,86 +113,76 @@ std::unique_ptr<std::vector<TData>> get_input_const_data_as(const ov::Node* op,
|
|||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
|
||||||
template <class T>
|
// Helper to reduce duplicates of code for get_data_as_... specific type functions.
|
||||||
inline bool get_data_as_int64(
|
template <class TShape, class TData>
|
||||||
size_t idx,
|
inline bool get_data_as(const ov::Node* op,
|
||||||
const ov::Node* op,
|
size_t idx,
|
||||||
std::vector<int64_t>& axes_value,
|
std::vector<TData>& data_out,
|
||||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||||
if (constant_data.count(idx)) {
|
if (auto out = ov::op::get_input_const_data_as<TShape, TData>(op, idx, constant_data)) {
|
||||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
|
data_out = std::move(*out);
|
||||||
} else {
|
return true;
|
||||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
|
||||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
|
||||||
axes_value = constant->cast_vector<int64_t>();
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline bool get_data_as_int64<ov::PartialShape>(
|
|
||||||
size_t idx,
|
|
||||||
const ov::Node* op,
|
|
||||||
std::vector<int64_t>& axes_value,
|
|
||||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
|
|
||||||
if (constant_data.count(idx)) {
|
|
||||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<int64_t>();
|
|
||||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
|
||||||
axes_value = constant->cast_vector<int64_t>();
|
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class TShape>
|
||||||
|
inline bool get_data_as_int64(size_t idx,
|
||||||
|
const ov::Node* op,
|
||||||
|
std::vector<int64_t>& axes_value,
|
||||||
|
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||||
|
return get_data_as<TShape>(op, idx, axes_value, constant_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class TShape>
|
||||||
|
inline bool get_data_as_float(size_t idx,
|
||||||
|
const ov::Node* op,
|
||||||
|
std::vector<float>& axes_value,
|
||||||
|
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||||
|
return get_data_as<TShape>(op, idx, axes_value, constant_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Get the operator's constant data as shape of type T.
|
||||||
|
*
|
||||||
|
* \note The constant data are get as size_t (Dimension value type for static shape). If pointed input is signed the
|
||||||
|
* output shape dimension can be wrongly interpreted.
|
||||||
|
*
|
||||||
|
* \tparam TShape Shape type.
|
||||||
|
*
|
||||||
|
* \param idx Operator's input index.
|
||||||
|
* \param op Pointer to operator.
|
||||||
|
* \param shape Output shape made from constant data.
|
||||||
|
* \param constant_data Map with constant tensors. Optional default empty.
|
||||||
|
*
|
||||||
|
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
|
||||||
|
*/
|
||||||
|
template <class TShape>
|
||||||
|
inline bool get_data_as_shape(size_t idx,
|
||||||
|
const ov::Node* op,
|
||||||
|
TShape& shape,
|
||||||
|
const std::map<size_t, ov::HostTensorPtr>& constant_data = {}) {
|
||||||
|
// Note, assumes that get_input_const_data_as throws exception for TShape different then ov::PartialShape.
|
||||||
|
shape = *ov::op::get_input_const_data_as<TShape, size_t, TShape>(op, idx, constant_data);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
/**
|
||||||
inline bool get_data_as_float(
|
* \brief Get the operator's constant data as ov::PartialShape.
|
||||||
size_t idx,
|
*
|
||||||
const ov::Node* op,
|
* If data not get as constant then try evaluate this input as partial shape from input's bounds and labels.
|
||||||
std::vector<float>& axes_value,
|
*
|
||||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
* \note The constant data are get as int64_t. If pointed input is unsigned then output shape
|
||||||
if (constant_data.count(idx)) {
|
* dimension can be wrongly interpreted.
|
||||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
|
*
|
||||||
} else {
|
* \param idx Operator's input index.
|
||||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
* \param op Pointer to operator.
|
||||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
* \param shape Output shape made from constant data.
|
||||||
axes_value = constant->cast_vector<float>();
|
* \param constant_data Map with constant tensors. Optional default empty.
|
||||||
}
|
*
|
||||||
return true;
|
* \return true If constant data acquired as shape otherwise throws NodeValidation exception.
|
||||||
}
|
*/
|
||||||
|
|
||||||
template <>
|
|
||||||
inline bool get_data_as_float<ov::PartialShape>(
|
|
||||||
size_t idx,
|
|
||||||
const ov::Node* op,
|
|
||||||
std::vector<float>& axes_value,
|
|
||||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
|
|
||||||
if (constant_data.count(idx)) {
|
|
||||||
axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector<float>();
|
|
||||||
} else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) {
|
|
||||||
axes_value = constant->cast_vector<float>();
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class T>
|
|
||||||
inline bool get_data_as_shape(
|
|
||||||
size_t idx,
|
|
||||||
const ov::Node* op,
|
|
||||||
T& shape,
|
|
||||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
|
||||||
if (constant_data.count(idx)) {
|
|
||||||
shape = T(ov::opset1::Constant(constant_data.at(idx)).cast_vector<size_t>());
|
|
||||||
} else {
|
|
||||||
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
|
|
||||||
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
|
|
||||||
shape = T(constant->cast_vector<size_t>());
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline bool get_data_as_shape<ov::PartialShape>(
|
inline bool get_data_as_shape<ov::PartialShape>(
|
||||||
size_t idx,
|
size_t idx,
|
||||||
|
Loading…
Reference in New Issue
Block a user