diff --git a/src/core/shape_inference/include/utils.hpp b/src/core/shape_inference/include/utils.hpp index 831f7f14062..71dcf069c08 100644 --- a/src/core/shape_inference/include/utils.hpp +++ b/src/core/shape_inference/include/utils.hpp @@ -98,9 +98,9 @@ template , typename std::enable_if::value>::type* = nullptr> -std::unique_ptr> get_input_const_data_as(const ov::Node* op, - size_t idx, - const std::map& constant_data = {}) { +std::unique_ptr get_input_const_data_as(const ov::Node* op, + size_t idx, + const std::map& constant_data = {}) { if (constant_data.count(idx)) { return std::unique_ptr(new TRes(ov::opset1::Constant(constant_data.at(idx)).cast_vector())); } else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) { @@ -113,86 +113,76 @@ std::unique_ptr> get_input_const_data_as(const ov::Node* op, } // namespace op } // namespace ov -template -inline bool get_data_as_int64( - size_t idx, - const ov::Node* op, - std::vector& axes_value, - const std::map>& constant_data = {}) { - if (constant_data.count(idx)) { - axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector(); - } else { - const auto& constant = ov::as_type_ptr(op->get_input_node_shared_ptr(idx)); - NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx); - axes_value = constant->cast_vector(); - } - return true; -} - -template <> -inline bool get_data_as_int64( - size_t idx, - const ov::Node* op, - std::vector& axes_value, - const std::map>& constant_data) { - if (constant_data.count(idx)) { - axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector(); - } else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) { - axes_value = constant->cast_vector(); +// Helper to reduce duplicates of code for get_data_as_... specific type functions. +template +inline bool get_data_as(const ov::Node* op, + size_t idx, + std::vector& data_out, + const std::map& constant_data = {}) { + if (auto out = ov::op::get_input_const_data_as(op, idx, constant_data)) { + data_out = std::move(*out); + return true; } else { return false; } +} + +template +inline bool get_data_as_int64(size_t idx, + const ov::Node* op, + std::vector& axes_value, + const std::map& constant_data = {}) { + return get_data_as(op, idx, axes_value, constant_data); +} + +template +inline bool get_data_as_float(size_t idx, + const ov::Node* op, + std::vector& axes_value, + const std::map& constant_data = {}) { + return get_data_as(op, idx, axes_value, constant_data); +} + +/** + * \brief Get the operator's constant data as shape of type T. + * + * \note The constant data are get as size_t (Dimension value type for static shape). If pointed input is signed the + * output shape dimension can be wrongly interpreted. + * + * \tparam TShape Shape type. + * + * \param idx Operator's input index. + * \param op Pointer to operator. + * \param shape Output shape made from constant data. + * \param constant_data Map with constant tensors. Optional default empty. + * + * \return true If constant data acquired as shape otherwise throws NodeValidation exception. + */ +template +inline bool get_data_as_shape(size_t idx, + const ov::Node* op, + TShape& shape, + const std::map& constant_data = {}) { + // Note, assumes that get_input_const_data_as throws exception for TShape different then ov::PartialShape. + shape = *ov::op::get_input_const_data_as(op, idx, constant_data); return true; } -template -inline bool get_data_as_float( - size_t idx, - const ov::Node* op, - std::vector& axes_value, - const std::map>& constant_data = {}) { - if (constant_data.count(idx)) { - axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector(); - } else { - const auto& constant = ov::as_type_ptr(op->get_input_node_shared_ptr(idx)); - NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx); - axes_value = constant->cast_vector(); - } - return true; -} - -template <> -inline bool get_data_as_float( - size_t idx, - const ov::Node* op, - std::vector& axes_value, - const std::map>& constant_data) { - if (constant_data.count(idx)) { - axes_value = ov::opset1::Constant(constant_data.at(idx)).cast_vector(); - } else if (const auto& constant = ov::get_constant_from_source(op->input_value(idx))) { - axes_value = constant->cast_vector(); - } else { - return false; - } - return true; -} - -template -inline bool get_data_as_shape( - size_t idx, - const ov::Node* op, - T& shape, - const std::map>& constant_data = {}) { - if (constant_data.count(idx)) { - shape = T(ov::opset1::Constant(constant_data.at(idx)).cast_vector()); - } else { - const auto& constant = ov::as_type_ptr(op->get_input_node_shared_ptr(idx)); - NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx); - shape = T(constant->cast_vector()); - } - return true; -} - +/** + * \brief Get the operator's constant data as ov::PartialShape. + * + * If data not get as constant then try evaluate this input as partial shape from input's bounds and labels. + * + * \note The constant data are get as int64_t. If pointed input is unsigned then output shape + * dimension can be wrongly interpreted. + * + * \param idx Operator's input index. + * \param op Pointer to operator. + * \param shape Output shape made from constant data. + * \param constant_data Map with constant tensors. Optional default empty. + * + * \return true If constant data acquired as shape otherwise throws NodeValidation exception. + */ template <> inline bool get_data_as_shape( size_t idx,