diff --git a/src/core/dev_api/validation_util.hpp b/src/core/dev_api/validation_util.hpp index e93fefd1411..2495fd10299 100644 --- a/src/core/dev_api/validation_util.hpp +++ b/src/core/dev_api/validation_util.hpp @@ -34,7 +34,7 @@ OPENVINO_API bool are_unique(const std::vector& data); /// /// \param value Value to be clipped. /// \param min Minimum value bound. -/// \param max Maximum value boiund +/// \param max Maximum value bound. /// /// \return Value if between min, max otherwise min or max. OPENVINO_API int64_t clip(const int64_t& value, const int64_t& min, const int64_t& max); @@ -43,18 +43,21 @@ OPENVINO_API int64_t clip(const int64_t& value, const int64_t& min, const int64_ /// /// \param subgraph sink /// -/// \return Constant node or nullptr if unable to constantfold the subgraph +/// \return Constant node or nullptr if unable to constant fold the subgraph OPENVINO_API std::shared_ptr constantfold_subgraph(const Output& subgraph_sink); -/** - * @brief Runs an estimation of source tensor. If it succeeded to calculate both bounds and - * they are the same returns Constant operation from the resulting bound, otherwise nullptr. - * - * @param source Node output used to get its tensor data as constant. - * @return Shared pointer to constant data or nullptr. - */ +/// \brief Runs an estimation of source tensor. If it succeeded to calculate both bounds and +/// they are the same returns Constant operation from the resulting bound, otherwise nullptr. +/// +/// \param source Node output used to get its tensor data as constant. +/// \return Shared pointer to constant data or nullptr. OPENVINO_API std::shared_ptr get_constant_from_source(const Output& source); +/// \brief Make scalar tensor which stores maximum value of ov::element::Type. +/// \param et Element type to get its maximum. +/// \return Tensor with maximum value. +Tensor make_tensor_of_max_value(const element::Type_t et); + /// \brief Apply auto padding to padding_above and padding_below inputs /// if all needed informations are known. /// diff --git a/src/core/include/openvino/op/mod.hpp b/src/core/include/openvino/op/mod.hpp index 5e58a2ec03d..defb1c65163 100644 --- a/src/core/include/openvino/op/mod.hpp +++ b/src/core/include/openvino/op/mod.hpp @@ -29,6 +29,8 @@ public: std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override; + bool evaluate_lower(TensorVector& outputs) const override; + bool evaluate_upper(TensorVector& outputs) const override; bool has_evaluate() const override; }; } // namespace v1 diff --git a/src/core/reference/include/openvino/reference/mod.hpp b/src/core/reference/include/openvino/reference/mod.hpp index 81ae69e32eb..671ee012393 100644 --- a/src/core/reference/include/openvino/reference/mod.hpp +++ b/src/core/reference/include/openvino/reference/mod.hpp @@ -6,6 +6,7 @@ #include #include +#include #include "openvino/reference/autobroadcast_binop.hpp" #include "openvino/reference/utils/type_util.hpp" @@ -22,6 +23,72 @@ template ()>::type* = T mod(const T x, const T y) { return x - (std::trunc(x / y) * y); } + +/** + * @brief Estimates division remainder `[v1, v2] % m = [r0, r1]` as interval. + * + * Assumes that ` 0 <= v1 <= v2 and m != 0`, in other cases result is undefined behaviour. + * The result interval estimate minimum and maximum but is not true that value can be any value between min and max. + * e.g. + * - [4,6] % 5 = [0, 4], but in fact accurate result is set of [0,1,4] + + * @param v1 Minimum of value interval. + * @param v2 Maximum of value interval. + * @param m Modulo divisor. + * @return Remainder of division as interval range. + */ +template ::value>::type* = nullptr> +std::pair mod_interval_value(const T v1, const T v2, const T m) { + const auto v_diff = v2 - v1; + auto r = std::make_pair(func::mod(v1, m), func::mod(v2, m)); + + if ((r.second < r.first) || ((v_diff != T{0}) && (v_diff >= m))) { + r.first = T{0}; + r.second = m - T{1}; + } + return r; +} + +/** + * @brief Estimates division reminder of `[v1, v2] & [m1, m2] = [r0, r1]` as interval. + * + * * Assumes that ` 0 <= v1 <= v2 and 0 < m1 <= m2`, in other cases result is undefined behaviour. + * + * @param v1 Minimum of value interval. + * @param v2 Maximum of value interval. + * @param m1 Minimum of modulo divisor. + * @param m2 Maximum of modulo divisor. + * @return Remainder of division as interval range. + */ +template ::value>::type* = nullptr> +std::pair mod_interval(const T v1, const T v2, const T m1, const T m2) { + auto r = mod_interval_value(v1, v2, m1); + if (v2 != 0) { + if (m1 != m2) { + const auto v_diff = v2 - v1; + const auto m_diff = m2 - m1; + + auto r2 = mod_interval_value(v1, v2, m2); + r.first = std::min(r.first, r2.first); + r.second = std::max(r.second, r2.second); + + if (v_diff == T{0} && m_diff != T{1}) { + const T v2_half = v2 / T{2}; + if ((m1 < v2_half) || ((m1 < v2) && (v2 < m2))) { + r.first = T{0}; + + if ((v2_half < m2) && (m2 < v2)) { + const T v2_half_next = v2_half + T{1}; + r.second = func::mod(v2, v2_half_next); + } else { + r.second = m2 - T{1}; + } + } + } + } + } + return r; +} } // namespace func /** @@ -42,7 +109,7 @@ void mod(InputIt arg0, const Shape& arg_shape1, const op::AutoBroadcastSpec& broadcast_spec) { using T = typename std::iterator_traits::value_type; - autobroadcast_binop(arg0, arg1, out, arg_shape0, arg_shape1, broadcast_spec, &func::mod); + autobroadcast_binop(arg0, arg1, out, arg_shape0, arg_shape1, broadcast_spec, func::mod); } } // namespace reference } // namespace ov diff --git a/src/core/src/op/mod.cpp b/src/core/src/op/mod.cpp index e8aa1a8a009..816d605a292 100644 --- a/src/core/src/op/mod.cpp +++ b/src/core/src/op/mod.cpp @@ -4,13 +4,30 @@ #include "openvino/op/mod.hpp" +#include "bound_evaluate.hpp" #include "element_visitor.hpp" #include "itt.hpp" -#include "openvino/core/shape_util.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/equal.hpp" +#include "openvino/op/logical_or.hpp" +#include "openvino/op/select.hpp" #include "openvino/reference/mod.hpp" #include "utils.hpp" +#include "validation_util.hpp" namespace ov { +namespace util { +namespace { +Tensor make_tensor_of_value(const element::Type_t et, const int64_t value) { + auto c = op::v0::Constant(et, Shape{}, value); + auto t = Tensor(et, Shape{}); + std::memcpy(t.data(), c.get_data_ptr(), t.get_byte_size()); + return t; +} +} // namespace +} // namespace util + namespace op { namespace mod { struct Evaluate : ov::element::NoAction { @@ -31,6 +48,185 @@ struct Evaluate : ov::element::NoAction { return true; } }; + +struct EvaluateBound : element::NoAction { + using element::NoAction::visit; + + template > + static result_type visit(const Tensor& v_lb, + const Tensor& v_ub, + const Tensor& m_lb, + const Tensor& m_ub, + Tensor& out, + const bool is_lower) { + auto v_lb_first = v_lb.data(); + auto v_lb_last = std::next(v_lb_first, v_lb.get_size()); + auto v_ub_first = v_ub.data(); + auto m_lb_first = m_lb.data(); + auto m_ub_first = m_ub.data(); + auto out_first = out.data(); + + if (is_lower) { + while (v_lb_first != v_lb_last) { + *out_first++ = + reference::func::mod_interval(*v_lb_first++, *v_ub_first++, *m_lb_first++, *m_ub_first++).first; + } + } else { + while (v_lb_first != v_lb_last) { + *out_first++ = + reference::func::mod_interval(*v_lb_first++, *v_ub_first++, *m_lb_first++, *m_ub_first++).second; + } + } + return true; + } +}; + +namespace { + +/** + * @brief Get node inputs bounds as TensorVector. + * + * The inputs bounds are stored as [lower0, upper0, lower1, upper1]. + * + * @param op Pointer to the node. + * @return Vector with inputs bounds tensors. + */ +TensorVector get_bounds(const Node* const op) { + auto&& v_bounds = ov::evaluate_both_bounds(op->input_value(0)); + auto&& m_bounds = ov::evaluate_both_bounds(op->input_value(1)); + return {std::move(v_bounds.first), + std::move(v_bounds.second), + std::move(m_bounds.first), + std::move(m_bounds.second)}; +} + +/** + * @brief Check if all bounds in vector are valid (allocated). + * + * @param bounds TensorVector of bounds for check. + * @return True if bounds area valid otherwise false. + */ +bool are_bounds_valid(const TensorVector& bounds) { + return std::all_of(bounds.begin(), bounds.end(), [](const Tensor& t) { + return static_cast(t); + }); +} + +/** + * @brief Evaluate binary mask of values which cannot be calculated by modulo. + * + * @param bounds Modulo inputs bounds. + * @return Tensor with binary mask or empty tensor if evaluate failed. + */ +Tensor evaluate_undefined_result_mask(const TensorVector& bounds) { + const auto eq_op = v1::Equal(); + const auto or_op = v1::LogicalOr(); + + const auto& in_et = bounds.front().get_element_type(); + + auto zero_t = ov::util::make_tensor_of_value(in_et, 0); + auto max_t = ov::util::make_tensor_of_max_value(in_et); + + const auto& v_ub = bounds[1]; + const auto& m_lb = bounds[2]; + const auto& m_ub = bounds[3]; + + auto m_mask = TensorVector{{element::boolean, m_ub.get_shape()}}; + if (!eq_op.evaluate(m_mask, {m_lb, zero_t})) { + return {}; + } + + auto out_masks = TensorVector{{element::boolean, m_lb.get_shape()}}; + if (!eq_op.evaluate(out_masks, {m_ub, zero_t})) { + return {}; + } + + auto m_or_inputs = TensorVector{out_masks[0], m_mask[0]}; + or_op.evaluate(m_mask, m_or_inputs); + if (!eq_op.evaluate(out_masks, {m_lb, max_t})) { + return {}; + } + + or_op.evaluate(m_mask, m_or_inputs); + auto v_mask = TensorVector{{element::boolean, v_ub.get_shape()}}; + if (!eq_op.evaluate(v_mask, {v_ub, max_t})) { + return {}; + } + + out_masks[0].set_shape(ov::op::infer_broadcast_shape(&or_op, v_mask[0].get_shape(), m_mask[0].get_shape())); + return or_op.evaluate(out_masks, {v_mask[0], m_mask[0]}) ? out_masks[0] : Tensor{}; +} + +/** + * @brief Get the inputs bound with valid values only. + * + * The values which result modulo to give undefined result are replaced by one. + * The auto broadcast is applied to have inputs same shape. + * + * @param bounds Modulo operator inputs bounds. + * @param mask Mask with undefined result values. + * @return Vector of bounds tensors. + */ +TensorVector get_bounds_with_valid_values(const TensorVector& bounds, const Tensor& mask) { + const auto select_op = v1::Select(); + const auto one_t = ov::util::make_tensor_of_value(bounds.front().get_element_type(), 1); + + auto m_bounds = TensorVector(); + m_bounds.reserve(bounds.size()); + std::transform(bounds.cbegin(), bounds.cend(), std::back_inserter(m_bounds), [&](const Tensor& b) { + auto tmp = TensorVector{{b.get_element_type(), mask.get_shape()}}; + return select_op.evaluate(tmp, {mask, one_t, b}) ? tmp.front() : Tensor{}; + }); + return m_bounds; +} + +/** + * @brief Evaluate modulo upper or lower bound. + * + * @param op Pointer to modulo node. + * @param outputs Tensor vector with one tensor to store bounds result. + * @param is_lower True to evaluate lower otherwise evaluate upper. + * @return True if outputs has valid data otherwise false. + */ +bool evaluate_bound(const Node* const op, TensorVector& outputs, bool is_lower) { + const auto bounds = mod::get_bounds(op); + + if (mod::are_bounds_valid(bounds)) { + const auto& in_et = bounds[0].get_element_type(); + + const auto undefined_result_mask = mod::evaluate_undefined_result_mask(bounds); + if (!undefined_result_mask) { + return false; + } + + // Set inputs values to 1 for undefined results mask (0, inf, etc.) + const auto m_bounds = mod::get_bounds_with_valid_values(bounds, undefined_result_mask); + if (!mod::are_bounds_valid(m_bounds)) { + return false; + } + + // Evaluate bound. + outputs[0].set_shape(undefined_result_mask.get_shape()); + using namespace ov::element; + if (!IfTypeOf::apply(in_et, + m_bounds[0], + m_bounds[1], + m_bounds[2], + m_bounds[3], + outputs[0], + is_lower)) { + return false; + } + // Set undefined bound value for results which cannot be calculated. + const auto select_op = v1::Select(); + const auto undefined_bound = + is_lower ? ov::util::make_tensor_of_value(in_et, 0) : ov::util::make_tensor_of_max_value(in_et); + return select_op.evaluate(outputs, {undefined_result_mask, undefined_bound, outputs.front()}); + } else { + return false; + } +} +} // namespace } // namespace mod namespace v1 { @@ -59,6 +255,16 @@ bool Mod::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) co get_autob()); } +bool Mod::evaluate_lower(TensorVector& outputs) const { + OV_OP_SCOPE(v1_Mod_evaluate_lower); + return mod::evaluate_bound(this, outputs, true); +} + +bool Mod::evaluate_upper(TensorVector& outputs) const { + OV_OP_SCOPE(v1_Mod_evaluate_upper); + return mod::evaluate_bound(this, outputs, false); +} + bool Mod::has_evaluate() const { OV_OP_SCOPE(v1_Mod_has_evaluate); diff --git a/src/core/src/validation_util.cpp b/src/core/src/validation_util.cpp index 803364b2890..4a7bd1958f1 100644 --- a/src/core/src/validation_util.cpp +++ b/src/core/src/validation_util.cpp @@ -910,32 +910,8 @@ void evaluate_nodes(std::map& value_map, } std::shared_ptr get_constant_max_of_type(element::Type_t t) { -#define OPENVINO_TYPE_TO_MAX_CONST(t) \ - case t: \ - return ov::op::v0::Constant::create( \ - t, \ - {}, \ - {std::numeric_limits::value_type>::max()}); \ - break - - switch (t) { - OPENVINO_TYPE_TO_MAX_CONST(element::boolean); - OPENVINO_TYPE_TO_MAX_CONST(element::bf16); - OPENVINO_TYPE_TO_MAX_CONST(element::f16); - OPENVINO_TYPE_TO_MAX_CONST(element::f32); - OPENVINO_TYPE_TO_MAX_CONST(element::f64); - OPENVINO_TYPE_TO_MAX_CONST(element::i8); - OPENVINO_TYPE_TO_MAX_CONST(element::i16); - OPENVINO_TYPE_TO_MAX_CONST(element::i32); - OPENVINO_TYPE_TO_MAX_CONST(element::i64); - OPENVINO_TYPE_TO_MAX_CONST(element::u1); - OPENVINO_TYPE_TO_MAX_CONST(element::u8); - OPENVINO_TYPE_TO_MAX_CONST(element::u16); - OPENVINO_TYPE_TO_MAX_CONST(element::u32); - OPENVINO_TYPE_TO_MAX_CONST(element::u64); - default: - return nullptr; - } + auto tensor = ov::util::make_tensor_of_max_value(t); + return tensor ? std::make_shared(tensor) : nullptr; } std::shared_ptr get_constant_min_of_type(element::Type_t t) { @@ -1385,6 +1361,48 @@ std::shared_ptr get_constant_from_source(const Output& source) { } } +template +Tensor make_tensor_of_max_value(const element::Type_t et) { + Tensor t{et, Shape{}}; + *t.data() = std::numeric_limits::max(); + return t; +} + +Tensor make_tensor_of_max_value(const element::Type_t et) { + switch (et) { + case element::boolean: + return make_tensor_of_max_value>(et); + case element::bf16: + return make_tensor_of_max_value>(et); + case element::f16: + return make_tensor_of_max_value>(et); + case element::f32: + return make_tensor_of_max_value>(et); + case element::f64: + return make_tensor_of_max_value>(et); + case element::i8: + return make_tensor_of_max_value>(et); + case element::i16: + return make_tensor_of_max_value>(et); + case element::i32: + return make_tensor_of_max_value>(et); + case element::i64: + return make_tensor_of_max_value>(et); + case element::u1: + return make_tensor_of_max_value>(et); + case element::u8: + return make_tensor_of_max_value>(et); + case element::u16: + return make_tensor_of_max_value>(et); + case element::u32: + return make_tensor_of_max_value>(et); + case element::u64: + return make_tensor_of_max_value>(et); + default: + return {}; + } +} + std::vector get_tensors_partial_shapes(const TensorVector& tensors) { std::vector shapes; shapes.reserve(tensors.size()); diff --git a/src/core/tests/type_prop/mod.cpp b/src/core/tests/type_prop/mod.cpp index b1dbab11eea..0e5af52401b 100644 --- a/src/core/tests/type_prop/mod.cpp +++ b/src/core/tests/type_prop/mod.cpp @@ -5,7 +5,171 @@ #include "openvino/op/mod.hpp" #include "arithmetic_ops.hpp" +#include "openvino/core/validation_util.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/squeeze.hpp" using Type = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_mod, ArithmeticOperator, Type); + +using ov::op::v0::Constant; +using ov::op::v0::Parameter; +using ov::op::v0::Squeeze; +using ov::op::v3::Broadcast; +using ov::op::v3::ShapeOf; + +class TypePropModV1Test : public TypePropOpTest {}; + +TEST_F(TypePropModV1Test, preserve_constant_data_on_inputs) { + const auto a = Constant::create(ov::element::i32, ov::Shape{4}, {4, 10, 22, 5}); + const auto b = Constant::create(ov::element::i32, ov::Shape{4}, {3, 4, 8, 3}); + const auto op = make_op(a, b); + + const auto param = std::make_shared(ov::element::i32, ov::Shape{1}); + auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + const auto& output_shape = bc->get_output_partial_shape(0); + EXPECT_EQ(output_shape, ov::PartialShape({1, 2, 6, 2})); +} + +TEST_F(TypePropModV1Test, preserve_partial_values_on_inputs) { + const auto a = std::make_shared(ov::element::i64, ov::PartialShape{{5, 6}, 22, {3, 7}, -1, {7, 9}}); + const auto b = std::make_shared(ov::element::i64, ov::PartialShape{3, {12, 18}, {4, 6}, -1, {0, 4}}); + const auto op = make_op(std::make_shared(a), std::make_shared(b)); + + const auto param = std::make_shared(ov::element::i64, ov::Shape{1}); + auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + + const auto& output_shape = bc->get_output_partial_shape(0); + EXPECT_EQ(output_shape, ov::PartialShape({{0, 2}, {4, 10}, {0, 5}, -1, -1})); +} + +TEST_F(TypePropModV1Test, preserve_partial_values_when_m_is_interval_scalar) { + const auto a = std::make_shared(ov::element::i64, ov::PartialShape{{5, 6}, 22, {3, 7}, -1, {7, 9}}); + const auto b = std::make_shared(ov::element::i64, ov::PartialShape{{12, 18}}); + const auto b_scalar = std::make_shared(std::make_shared(b)); + const auto op = make_op(std::make_shared(a), b_scalar); + + const auto param = std::make_shared(ov::element::i64, ov::Shape{1}); + auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + + const auto& output_shape = bc->get_output_partial_shape(0); + EXPECT_EQ(output_shape, ov::PartialShape({{5, 6}, {4, 10}, {3, 7}, -1, {7, 9}})); +} + +TEST_F(TypePropModV1Test, preserve_partial_values_when_value_is_interval_scalar) { + const auto a = std::make_shared(ov::element::i64, ov::PartialShape{{3, 7}}); + const auto b = std::make_shared(ov::element::i64, ov::PartialShape{3, {12, 18}, {4, 6}, -1, {0, 4}}); + const auto a_scalar = std::make_shared(std::make_shared(a)); + const auto op = make_op(a_scalar, std::make_shared(b)); + + const auto param = std::make_shared(ov::element::i64, ov::Shape{1}); + auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + + const auto& output_shape = bc->get_output_partial_shape(0); + EXPECT_EQ(output_shape, ov::PartialShape({{0, 2}, {3, 7}, {0, 5}, -1, -1})); +} + +// test params as {a, b, exp_result} +using IntervalModuloParams = std::tuple; + +class SingleDimModV1Test : public TypePropModV1Test, public testing::WithParamInterface { +protected: + void SetUp() override { + std::tie(a_dim, b_dim, exp_dim) = GetParam(); + } + + ov::Dimension a_dim, b_dim, exp_dim; +}; + +const auto v_and_m_static = testing::Values(IntervalModuloParams{{0, 0}, {1, 1}, {0, 0}}, + IntervalModuloParams{{0, 0}, {9, 9}, {0, 0}}, + IntervalModuloParams{{0, 0}, {1000, 1000}, {0, 0}}, + IntervalModuloParams{{10, 10}, {3, 3}, {1, 1}}, + IntervalModuloParams{{10, 10}, {6, 6}, {4, 4}}, + IntervalModuloParams{{10, 10}, {5, 5}, {0, 0}}, + IntervalModuloParams{{10, 10}, {15, 15}, {10, 10}}); + +const auto v_interval_m_static = testing::Values(IntervalModuloParams{{6, 7}, {4, 4}, {2, 3}}, + IntervalModuloParams{{6, 8}, {4, 4}, {0, 3}}, // Result [0,2,3] + IntervalModuloParams{{6, 8}, {10, 10}, {6, 8}}, + IntervalModuloParams{{6, 8}, {7, 7}, {0, 6}}, + IntervalModuloParams{{4, 8}, {7, 7}, {0, 6}}, + IntervalModuloParams{{15, 16}, {7, 7}, {1, 2}}, + IntervalModuloParams{{5, 20}, {5, 5}, {0, 4}}, + + IntervalModuloParams{{5, 10}, {7, 7}, {0, 6}}); + +const auto v_static_m_interval = testing::Values(IntervalModuloParams{{0, 0}, {3, 13}, {0, 0}}, + IntervalModuloParams{{10, 10}, {2, 4}, {0, 3}}, + IntervalModuloParams{{10, 10}, {2, 6}, {0, 4}}, + IntervalModuloParams{{10, 10}, {6, 9}, {1, 4}}, + IntervalModuloParams{{10, 10}, {9, 11}, {0, 10}}, + IntervalModuloParams{{10, 10}, {3, 11}, {0, 10}}, + IntervalModuloParams{{10, 10}, {3, 10}, {0, 9}}, + IntervalModuloParams{{10, 10}, {7, 8}, {2, 3}}, + IntervalModuloParams{{100, 100}, {2, 20}, {0, 19}}, + // can be estimated accurate as only two results are possible + IntervalModuloParams{{100, 100}, {15, 16}, {4, 10}}, + // can not be estimated accurate as there are three results [10,4,15] + // Requires to calculate all possibilities and pick min, max + IntervalModuloParams{{100, 100}, {15, 17}, {0, 16}}); + +const auto v_and_m_intervals = testing::Values(IntervalModuloParams{{1, 10}, {2, 9}, {0, 8}}, + IntervalModuloParams{{1, 10}, {6, 9}, {0, 8}}, + IntervalModuloParams{{1, 10}, {2, 12}, {0, 10}}, + IntervalModuloParams{{1, 10}, {6, 12}, {0, 10}}, + IntervalModuloParams{{1, 10}, {11, 12}, {1, 10}}, + IntervalModuloParams{{1, 10}, {11, 15}, {1, 10}}, + IntervalModuloParams{{4, 10}, {10, 13}, {0, 10}}, + IntervalModuloParams{{10, 20}, {3, 5}, {0, 4}}, + IntervalModuloParams{{10, 10}, {3, 10}, {0, 9}}, + IntervalModuloParams{{5, 20}, {5, 10}, {0, 9}}, + IntervalModuloParams{{10, 100}, {3, 20}, {0, 19}}, + IntervalModuloParams{{10, 100}, {2, 20}, {0, 19}}, + IntervalModuloParams{{10, 100}, {51, 60}, {0, 59}}); + +// If input is infinite or m has 0 then output is undefined. +const auto v_and_m_special_values = testing::Values(IntervalModuloParams{{0, -1}, {5, 5}, {0, -1}}, + IntervalModuloParams{{10, -1}, {4, 4}, {0, -1}}, + // Evaluate low/up return [0, max] + // but evaluate both bounds return [0] as `m` has same bounds + IntervalModuloParams{{11, 11}, {0, 0}, {0, 0}}, + IntervalModuloParams{{11, 11}, {0, 5}, {0, -1}}, + IntervalModuloParams{{11, 20}, {0, 5}, {0, -1}}, + IntervalModuloParams{{11, 20}, {0, -1}, {0, -1}}, + IntervalModuloParams{{0, -1}, {0, -1}, {0, -1}}); + +INSTANTIATE_TEST_SUITE_P(v_and_m_static, SingleDimModV1Test, v_and_m_static); +INSTANTIATE_TEST_SUITE_P(value_interval_m_static, SingleDimModV1Test, v_interval_m_static); +INSTANTIATE_TEST_SUITE_P(value_static_m_interval, SingleDimModV1Test, v_static_m_interval); +INSTANTIATE_TEST_SUITE_P(value_and_m_as_intervals, SingleDimModV1Test, v_and_m_intervals); +INSTANTIATE_TEST_SUITE_P(value_and_m_special_values, SingleDimModV1Test, v_and_m_special_values); + +TEST_P(SingleDimModV1Test, preserve_value_on_inputs_i64) { + constexpr auto et = ov::element::i64; + const auto a = std::make_shared(et, ov::PartialShape{a_dim}); + const auto b = std::make_shared(et, ov::PartialShape{b_dim}); + const auto op = make_op(std::make_shared(a), std::make_shared(b)); + + const auto param = std::make_shared(et, ov::Shape{1}); + const auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + const auto& output_shape = bc->get_output_partial_shape(0); + + EXPECT_EQ(output_shape, ov::PartialShape({exp_dim})); +} + +TEST_P(SingleDimModV1Test, preserve_value_on_inputs_i32) { + constexpr auto et = ov::element::i32; + const auto a = std::make_shared(et, ov::PartialShape{a_dim}); + const auto b = std::make_shared(et, ov::PartialShape{b_dim}); + const auto op = make_op(std::make_shared(a, et), std::make_shared(b, et)); + + const auto param = std::make_shared(et, ov::Shape{1}); + const auto bc = std::make_shared(param, op, ov::op::BroadcastType::BIDIRECTIONAL); + const auto& output_shape = bc->get_output_partial_shape(0); + + EXPECT_EQ(output_shape, ov::PartialShape({exp_dim})); +}