[shape_infer] Preserve partial values on inputs for the Mod operator (#20169)

* Preserve partial values on Mod inputs:
- static values: full range of integers
- intervals: only if non-negative

* Fix bounds evaluation when inputs are scalars
Pawel Raasz 2023-10-24 12:53:54 +02:00 committed by GitHub
parent 1daa4b9e5e
commit 750f62fd04
6 changed files with 497 additions and 37 deletions
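For illustration, a minimal sketch of what the change enables (mirroring the type_prop tests at the end of this commit; assumes the usual OpenVINO C++ node API):

#include "openvino/op/mod.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/shape_of.hpp"

// Dimension {5,6} % 3 is now estimated as {0,2}; 22 % {12,18} as {4,10}.
const auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{{5, 6}, 22});
const auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3, {12, 18}});
const auto mod = std::make_shared<ov::op::v1::Mod>(std::make_shared<ov::op::v3::ShapeOf>(a),
                                                   std::make_shared<ov::op::v3::ShapeOf>(b));
// Used as the target shape of a bidirectional Broadcast, `mod` yields the
// partial shape {{0, 2}, {4, 10}} instead of fully dynamic dimensions.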

View File

@@ -34,7 +34,7 @@ OPENVINO_API bool are_unique(const std::vector<int64_t>& data);
///
/// \param value Value to be clipped.
/// \param min Minimum value bound.
-/// \param max Maximum value boiund
+/// \param max Maximum value bound.
///
/// \return Value if between min, max otherwise min or max.
OPENVINO_API int64_t clip(const int64_t& value, const int64_t& min, const int64_t& max);
@@ -43,18 +43,21 @@ OPENVINO_API int64_t clip(const int64_t& value, const int64_t& min, const int64_
///
/// \param subgraph sink
///
-/// \return Constant node or nullptr if unable to constantfold the subgraph
+/// \return Constant node or nullptr if unable to constant fold the subgraph
OPENVINO_API std::shared_ptr<op::v0::Constant> constantfold_subgraph(const Output<Node>& subgraph_sink);
-/**
- * @brief Runs an estimation of source tensor. If it succeeded to calculate both bounds and
- * they are the same returns Constant operation from the resulting bound, otherwise nullptr.
- *
- * @param source Node output used to get its tensor data as constant.
- * @return Shared pointer to constant data or nullptr.
- */
+/// \brief Runs an estimation of the source tensor. If both bounds can be calculated and
+/// they are equal, returns a Constant operation built from the resulting bound, otherwise nullptr.
+///
+/// \param source Node output used to get its tensor data as constant.
+/// \return Shared pointer to constant data or nullptr.
OPENVINO_API std::shared_ptr<op::v0::Constant> get_constant_from_source(const Output<Node>& source);
+/// \brief Make a scalar tensor which stores the maximum value of ov::element::Type.
+/// \param et Element type to get its maximum.
+/// \return Tensor with the maximum value.
+Tensor make_tensor_of_max_value(const element::Type_t et);
/// \brief Apply auto padding to padding_above and padding_below inputs
/// if all needed information is known.
///
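A minimal usage sketch of the new helper declared above (namespace per its use in the Mod sources below):

// Scalar tensor holding std::numeric_limits<int32_t>::max().
const auto max_t = ov::util::make_tensor_of_max_value(ov::element::i32);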

View File

@@ -29,6 +29,8 @@ public:
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& outputs) const override;
bool evaluate_upper(TensorVector& outputs) const override;
bool has_evaluate() const override;
};
} // namespace v1

View File

@@ -6,6 +6,7 @@
#include <cmath>
#include <cstddef>
#include <utility>
#include "openvino/reference/autobroadcast_binop.hpp"
#include "openvino/reference/utils/type_util.hpp"
@@ -22,6 +23,72 @@ template <class T, typename std::enable_if<ov::is_floating_point<T>()>::type* =
T mod(const T x, const T y) {
return x - (std::trunc(x / y) * y);
}
/**
 * @brief Estimates the division remainder `[v1, v2] % m = [r0, r1]` as an interval.
 *
 * Assumes that `0 <= v1 <= v2` and `m != 0`; for other inputs the result is undefined behaviour.
 * The result interval estimates the minimum and maximum, but not every value between them is attainable,
 * e.g.
 * - [4,6] % 5 = [0, 4], while the accurate result is the set {0, 1, 4}
 * @param v1 Minimum of value interval.
 * @param v2 Maximum of value interval.
 * @param m Modulo divisor.
 * @return Remainder of division as interval range.
 */
template <class T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
std::pair<T, T> mod_interval_value(const T v1, const T v2, const T m) {
const auto v_diff = v2 - v1;
auto r = std::make_pair(func::mod(v1, m), func::mod(v2, m));
if ((r.second < r.first) || ((v_diff != T{0}) && (v_diff >= m))) {
r.first = T{0};
r.second = m - T{1};
}
return r;
}
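// A few sample estimates:
//   mod_interval_value(3, 4, 5)   -> {3, 4}   (no wrap: remainders stay ordered)
//   mod_interval_value(4, 6, 5)   -> {0, 4}   (wrap: full range [0, m-1])
//   mod_interval_value(10, 10, 3) -> {1, 1}   (static value: exact remainder)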
/**
 * @brief Estimates the division remainder of `[v1, v2] % [m1, m2] = [r0, r1]` as an interval.
 *
 * Assumes that `0 <= v1 <= v2` and `0 < m1 <= m2`; for other inputs the result is undefined behaviour.
 *
 * @param v1 Minimum of value interval.
 * @param v2 Maximum of value interval.
 * @param m1 Minimum of modulo divisor.
 * @param m2 Maximum of modulo divisor.
 * @return Remainder of division as interval range.
 */
template <class T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
std::pair<T, T> mod_interval(const T v1, const T v2, const T m1, const T m2) {
auto r = mod_interval_value(v1, v2, m1);
if (v2 != 0) {
if (m1 != m2) {
const auto v_diff = v2 - v1;
const auto m_diff = m2 - m1;
auto r2 = mod_interval_value(v1, v2, m2);
r.first = std::min(r.first, r2.first);
r.second = std::max(r.second, r2.second);
if (v_diff == T{0} && m_diff != T{1}) {
const T v2_half = v2 / T{2};
if ((m1 < v2_half) || ((m1 < v2) && (v2 < m2))) {
r.first = T{0};
if ((v2_half < m2) && (m2 < v2)) {
const T v2_half_next = v2_half + T{1};
r.second = func::mod(v2, v2_half_next);
} else {
r.second = m2 - T{1};
}
}
}
}
}
return r;
}
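// Sample estimates (matching the type_prop test expectations in this commit):
//   mod_interval(10, 10, 2, 4)     -> {0, 3}
//   mod_interval(10, 10, 6, 9)     -> {1, 4}
//   mod_interval(100, 100, 15, 17) -> {0, 16}  (accurate result set is {10, 4, 15})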
} // namespace func
/**
@@ -42,7 +109,7 @@ void mod(InputIt arg0,
const Shape& arg_shape1,
const op::AutoBroadcastSpec& broadcast_spec) {
using T = typename std::iterator_traits<OutputIt>::value_type;
-autobroadcast_binop(arg0, arg1, out, arg_shape0, arg_shape1, broadcast_spec, &func::mod<T>);
+autobroadcast_binop(arg0, arg1, out, arg_shape0, arg_shape1, broadcast_spec, func::mod<T>);
}
} // namespace reference
} // namespace ov

View File

@@ -4,13 +4,30 @@
#include "openvino/op/mod.hpp"
#include "bound_evaluate.hpp"
#include "element_visitor.hpp"
#include "itt.hpp"
#include "openvino/core/shape_util.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/logical_or.hpp"
#include "openvino/op/select.hpp"
#include "openvino/reference/mod.hpp"
#include "utils.hpp"
#include "validation_util.hpp"
namespace ov {
namespace util {
namespace {
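// Make a scalar Tensor of type `et` holding `value`, converted via a temporary Constant.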
Tensor make_tensor_of_value(const element::Type_t et, const int64_t value) {
auto c = op::v0::Constant(et, Shape{}, value);
auto t = Tensor(et, Shape{});
std::memcpy(t.data(), c.get_data_ptr(), t.get_byte_size());
return t;
}
} // namespace
} // namespace util
namespace op {
namespace mod {
struct Evaluate : ov::element::NoAction<bool> {
@@ -31,6 +48,185 @@ struct Evaluate : ov::element::NoAction<bool> {
return true;
}
};
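// Element-wise bound evaluator: zips the four input bounds, applies
// func::mod_interval, and writes the lower (.first) or upper (.second) estimate.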
struct EvaluateBound : element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& v_lb,
const Tensor& v_ub,
const Tensor& m_lb,
const Tensor& m_ub,
Tensor& out,
const bool is_lower) {
auto v_lb_first = v_lb.data<const T>();
auto v_lb_last = std::next(v_lb_first, v_lb.get_size());
auto v_ub_first = v_ub.data<const T>();
auto m_lb_first = m_lb.data<const T>();
auto m_ub_first = m_ub.data<const T>();
auto out_first = out.data<T>();
if (is_lower) {
while (v_lb_first != v_lb_last) {
*out_first++ =
reference::func::mod_interval(*v_lb_first++, *v_ub_first++, *m_lb_first++, *m_ub_first++).first;
}
} else {
while (v_lb_first != v_lb_last) {
*out_first++ =
reference::func::mod_interval(*v_lb_first++, *v_ub_first++, *m_lb_first++, *m_ub_first++).second;
}
}
return true;
}
};
namespace {
/**
* @brief Get node inputs bounds as TensorVector.
*
* The inputs bounds are stored as [lower0, upper0, lower1, upper1].
*
* @param op Pointer to the node.
* @return Vector with inputs bounds tensors.
*/
TensorVector get_bounds(const Node* const op) {
auto&& v_bounds = ov::evaluate_both_bounds(op->input_value(0));
auto&& m_bounds = ov::evaluate_both_bounds(op->input_value(1));
return {std::move(v_bounds.first),
std::move(v_bounds.second),
std::move(m_bounds.first),
std::move(m_bounds.second)};
}
/**
* @brief Check if all bounds in vector are valid (allocated).
*
* @param bounds TensorVector of bounds for check.
 * @return True if all bounds are valid, otherwise false.
*/
bool are_bounds_valid(const TensorVector& bounds) {
return std::all_of(bounds.begin(), bounds.end(), [](const Tensor& t) {
return static_cast<bool>(t);
});
}
/**
 * @brief Evaluate a binary mask of the values for which the modulo result cannot be calculated.
 *
 * @param bounds Modulo inputs bounds.
 * @return Tensor with the binary mask, or an empty tensor if evaluation failed.
 */
Tensor evaluate_undefined_result_mask(const TensorVector& bounds) {
const auto eq_op = v1::Equal();
const auto or_op = v1::LogicalOr();
const auto& in_et = bounds.front().get_element_type();
auto zero_t = ov::util::make_tensor_of_value(in_et, 0);
auto max_t = ov::util::make_tensor_of_max_value(in_et);
const auto& v_ub = bounds[1];
const auto& m_lb = bounds[2];
const auto& m_ub = bounds[3];
auto m_mask = TensorVector{{element::boolean, m_ub.get_shape()}};
if (!eq_op.evaluate(m_mask, {m_lb, zero_t})) {
return {};
}
auto out_masks = TensorVector{{element::boolean, m_lb.get_shape()}};
if (!eq_op.evaluate(out_masks, {m_ub, zero_t})) {
return {};
}
auto m_or_inputs = TensorVector{out_masks[0], m_mask[0]};
or_op.evaluate(m_mask, m_or_inputs);
if (!eq_op.evaluate(out_masks, {m_lb, max_t})) {
return {};
}
or_op.evaluate(m_mask, m_or_inputs);
auto v_mask = TensorVector{{element::boolean, v_ub.get_shape()}};
if (!eq_op.evaluate(v_mask, {v_ub, max_t})) {
return {};
}
out_masks[0].set_shape(ov::op::infer_broadcast_shape(&or_op, v_mask[0].get_shape(), m_mask[0].get_shape()));
return or_op.evaluate(out_masks, {v_mask[0], m_mask[0]}) ? out_masks[0] : Tensor{};
}
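// In effect the mask is: (m_lb == 0) || (m_ub == 0) || (m_lb == max) || (v_ub == max),
// broadcast to the common shape of both inputs.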
/**
 * @brief Get the input bounds with valid values only.
 *
 * Values for which modulo would give an undefined result are replaced by one.
 * Auto broadcast is applied so that all bounds have the same shape.
 *
 * @param bounds Modulo operator inputs bounds.
 * @param mask Mask of undefined result values.
 * @return Vector of bounds tensors.
 */
TensorVector get_bounds_with_valid_values(const TensorVector& bounds, const Tensor& mask) {
const auto select_op = v1::Select();
const auto one_t = ov::util::make_tensor_of_value(bounds.front().get_element_type(), 1);
auto m_bounds = TensorVector();
m_bounds.reserve(bounds.size());
std::transform(bounds.cbegin(), bounds.cend(), std::back_inserter(m_bounds), [&](const Tensor& b) {
auto tmp = TensorVector{{b.get_element_type(), mask.get_shape()}};
return select_op.evaluate(tmp, {mask, one_t, b}) ? tmp.front() : Tensor{};
});
return m_bounds;
}
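// Note: Select broadcasts each bound against the mask, so all returned tensors
// share the mask's (already broadcast) shape.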
/**
* @brief Evaluate modulo upper or lower bound.
*
* @param op Pointer to modulo node.
* @param outputs Tensor vector with one tensor to store bounds result.
 * @param is_lower True to evaluate the lower bound, otherwise the upper bound.
 * @return True if outputs have valid data, otherwise false.
*/
bool evaluate_bound(const Node* const op, TensorVector& outputs, bool is_lower) {
const auto bounds = mod::get_bounds(op);
if (mod::are_bounds_valid(bounds)) {
const auto& in_et = bounds[0].get_element_type();
const auto undefined_result_mask = mod::evaluate_undefined_result_mask(bounds);
if (!undefined_result_mask) {
return false;
}
// Set input values to 1 for lanes masked as undefined results (0, inf, etc.)
const auto m_bounds = mod::get_bounds_with_valid_values(bounds, undefined_result_mask);
if (!mod::are_bounds_valid(m_bounds)) {
return false;
}
// Evaluate bound.
outputs[0].set_shape(undefined_result_mask.get_shape());
using namespace ov::element;
if (!IfTypeOf<i8, i16, i32, i64, u8, u16, u32, u64>::apply<mod::EvaluateBound>(in_et,
m_bounds[0],
m_bounds[1],
m_bounds[2],
m_bounds[3],
outputs[0],
is_lower)) {
return false;
}
// Set undefined bound value for results which cannot be calculated.
const auto select_op = v1::Select();
const auto undefined_bound =
is_lower ? ov::util::make_tensor_of_value(in_et, 0) : ov::util::make_tensor_of_max_value(in_et);
return select_op.evaluate(outputs, {undefined_result_mask, undefined_bound, outputs.front()});
} else {
return false;
}
}
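// Flow summary: get input bounds -> mask lanes where the result is undefined ->
// substitute 1 on masked lanes -> run mod_interval element-wise -> write back
// 0 (lower) or the type maximum (upper) on the masked lanes.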
} // namespace
} // namespace mod
namespace v1 {
@@ -59,6 +255,16 @@ bool Mod::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) co
get_autob());
}
bool Mod::evaluate_lower(TensorVector& outputs) const {
OV_OP_SCOPE(v1_Mod_evaluate_lower);
return mod::evaluate_bound(this, outputs, true);
}
bool Mod::evaluate_upper(TensorVector& outputs) const {
OV_OP_SCOPE(v1_Mod_evaluate_upper);
return mod::evaluate_bound(this, outputs, false);
}
bool Mod::has_evaluate() const {
OV_OP_SCOPE(v1_Mod_has_evaluate);

View File

@@ -910,32 +910,8 @@ void evaluate_nodes(std::map<RawNodeOutput, HostTensorPtr>& value_map,
}
std::shared_ptr<op::v0::Constant> get_constant_max_of_type(element::Type_t t) {
-#define OPENVINO_TYPE_TO_MAX_CONST(t) \
-case t: \
-return ov::op::v0::Constant::create( \
-t, \
-{}, \
-{std::numeric_limits<typename element_type_traits<t>::value_type>::max()}); \
-break
-switch (t) {
-OPENVINO_TYPE_TO_MAX_CONST(element::boolean);
-OPENVINO_TYPE_TO_MAX_CONST(element::bf16);
-OPENVINO_TYPE_TO_MAX_CONST(element::f16);
-OPENVINO_TYPE_TO_MAX_CONST(element::f32);
-OPENVINO_TYPE_TO_MAX_CONST(element::f64);
-OPENVINO_TYPE_TO_MAX_CONST(element::i8);
-OPENVINO_TYPE_TO_MAX_CONST(element::i16);
-OPENVINO_TYPE_TO_MAX_CONST(element::i32);
-OPENVINO_TYPE_TO_MAX_CONST(element::i64);
-OPENVINO_TYPE_TO_MAX_CONST(element::u1);
-OPENVINO_TYPE_TO_MAX_CONST(element::u8);
-OPENVINO_TYPE_TO_MAX_CONST(element::u16);
-OPENVINO_TYPE_TO_MAX_CONST(element::u32);
-OPENVINO_TYPE_TO_MAX_CONST(element::u64);
-default:
-return nullptr;
-}
+auto tensor = ov::util::make_tensor_of_max_value(t);
+return tensor ? std::make_shared<op::v0::Constant>(tensor) : nullptr;
}
std::shared_ptr<op::v0::Constant> get_constant_min_of_type(element::Type_t t) {
@@ -1385,6 +1361,48 @@ std::shared_ptr<Constant> get_constant_from_source(const Output<Node>& source) {
}
}
template <class T>
Tensor make_tensor_of_max_value(const element::Type_t et) {
Tensor t{et, Shape{}};
*t.data<T>() = std::numeric_limits<T>::max();
return t;
}
Tensor make_tensor_of_max_value(const element::Type_t et) {
switch (et) {
case element::boolean:
return make_tensor_of_max_value<ov::fundamental_type_for<element::boolean>>(et);
case element::bf16:
return make_tensor_of_max_value<ov::fundamental_type_for<element::bf16>>(et);
case element::f16:
return make_tensor_of_max_value<ov::fundamental_type_for<element::f16>>(et);
case element::f32:
return make_tensor_of_max_value<ov::fundamental_type_for<element::f32>>(et);
case element::f64:
return make_tensor_of_max_value<ov::fundamental_type_for<element::f64>>(et);
case element::i8:
return make_tensor_of_max_value<ov::fundamental_type_for<element::i8>>(et);
case element::i16:
return make_tensor_of_max_value<ov::fundamental_type_for<element::i16>>(et);
case element::i32:
return make_tensor_of_max_value<ov::fundamental_type_for<element::i32>>(et);
case element::i64:
return make_tensor_of_max_value<ov::fundamental_type_for<element::i64>>(et);
case element::u1:
return make_tensor_of_max_value<ov::fundamental_type_for<element::u1>>(et);
case element::u8:
return make_tensor_of_max_value<ov::fundamental_type_for<element::u8>>(et);
case element::u16:
return make_tensor_of_max_value<ov::fundamental_type_for<element::u16>>(et);
case element::u32:
return make_tensor_of_max_value<ov::fundamental_type_for<element::u32>>(et);
case element::u64:
return make_tensor_of_max_value<ov::fundamental_type_for<element::u64>>(et);
default:
return {};
}
}
std::vector<PartialShape> get_tensors_partial_shapes(const TensorVector& tensors) {
std::vector<PartialShape> shapes;
shapes.reserve(tensors.size());

View File

@@ -5,7 +5,171 @@
#include "openvino/op/mod.hpp"
#include "arithmetic_ops.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
using Type = ::testing::Types<ov::op::v1::Mod>;
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop_mod, ArithmeticOperator, Type);
using ov::op::v0::Constant;
using ov::op::v0::Parameter;
using ov::op::v0::Squeeze;
using ov::op::v3::Broadcast;
using ov::op::v3::ShapeOf;
class TypePropModV1Test : public TypePropOpTest<op::v1::Mod> {};
TEST_F(TypePropModV1Test, preserve_constant_data_on_inputs) {
const auto a = Constant::create(ov::element::i32, ov::Shape{4}, {4, 10, 22, 5});
const auto b = Constant::create(ov::element::i32, ov::Shape{4}, {3, 4, 8, 3});
const auto op = make_op(a, b);
const auto param = std::make_shared<Parameter>(ov::element::i32, ov::Shape{1});
auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({1, 2, 6, 2}));
}
TEST_F(TypePropModV1Test, preserve_partial_values_on_inputs) {
const auto a = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{{5, 6}, 22, {3, 7}, -1, {7, 9}});
const auto b = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{3, {12, 18}, {4, 6}, -1, {0, 4}});
const auto op = make_op(std::make_shared<ShapeOf>(a), std::make_shared<ShapeOf>(b));
const auto param = std::make_shared<Parameter>(ov::element::i64, ov::Shape{1});
auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({{0, 2}, {4, 10}, {0, 5}, -1, -1}));
}
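// A worked reading of the expected intervals above (each lane is value % m on the ShapeOf values):
//   {5,6}  % 3       -> {0, 2}   (5%3=2, 6%3=0: remainders wrap, so full range)
//   22     % {12,18} -> {4, 10}  (22%18=4 ... 22%12=10)
//   {3,7}  % {4,6}   -> {0, 5}   (wrap possible; upper bounded by m_max - 1)
//   -1     % -1      -> -1       (dynamic stays dynamic)
//   {7,9}  % {0,4}   -> -1       (m may be 0, so the result is undefined)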
TEST_F(TypePropModV1Test, preserve_partial_values_when_m_is_interval_scalar) {
const auto a = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{{5, 6}, 22, {3, 7}, -1, {7, 9}});
const auto b = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{{12, 18}});
const auto b_scalar = std::make_shared<Squeeze>(std::make_shared<ShapeOf>(b));
const auto op = make_op(std::make_shared<ShapeOf>(a), b_scalar);
const auto param = std::make_shared<Parameter>(ov::element::i64, ov::Shape{1});
auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({{5, 6}, {4, 10}, {3, 7}, -1, {7, 9}}));
}
TEST_F(TypePropModV1Test, preserve_partial_values_when_value_is_interval_scalar) {
const auto a = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{{3, 7}});
const auto b = std::make_shared<Parameter>(ov::element::i64, ov::PartialShape{3, {12, 18}, {4, 6}, -1, {0, 4}});
const auto a_scalar = std::make_shared<Squeeze>(std::make_shared<ShapeOf>(a));
const auto op = make_op(a_scalar, std::make_shared<ShapeOf>(b));
const auto param = std::make_shared<Parameter>(ov::element::i64, ov::Shape{1});
auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({{0, 2}, {3, 7}, {0, 5}, -1, -1}));
}
// test params as {a, b, exp_result}
using IntervalModuloParams = std::tuple<ov::Dimension, ov::Dimension, ov::Dimension>;
class SingleDimModV1Test : public TypePropModV1Test, public testing::WithParamInterface<IntervalModuloParams> {
protected:
void SetUp() override {
std::tie(a_dim, b_dim, exp_dim) = GetParam();
}
ov::Dimension a_dim, b_dim, exp_dim;
};
const auto v_and_m_static = testing::Values(IntervalModuloParams{{0, 0}, {1, 1}, {0, 0}},
IntervalModuloParams{{0, 0}, {9, 9}, {0, 0}},
IntervalModuloParams{{0, 0}, {1000, 1000}, {0, 0}},
IntervalModuloParams{{10, 10}, {3, 3}, {1, 1}},
IntervalModuloParams{{10, 10}, {6, 6}, {4, 4}},
IntervalModuloParams{{10, 10}, {5, 5}, {0, 0}},
IntervalModuloParams{{10, 10}, {15, 15}, {10, 10}});
const auto v_interval_m_static = testing::Values(IntervalModuloParams{{6, 7}, {4, 4}, {2, 3}},
IntervalModuloParams{{6, 8}, {4, 4}, {0, 3}}, // accurate result set is {0, 2, 3}
IntervalModuloParams{{6, 8}, {10, 10}, {6, 8}},
IntervalModuloParams{{6, 8}, {7, 7}, {0, 6}},
IntervalModuloParams{{4, 8}, {7, 7}, {0, 6}},
IntervalModuloParams{{15, 16}, {7, 7}, {1, 2}},
IntervalModuloParams{{5, 20}, {5, 5}, {0, 4}},
IntervalModuloParams{{5, 10}, {7, 7}, {0, 6}});
const auto v_static_m_interval = testing::Values(IntervalModuloParams{{0, 0}, {3, 13}, {0, 0}},
IntervalModuloParams{{10, 10}, {2, 4}, {0, 3}},
IntervalModuloParams{{10, 10}, {2, 6}, {0, 4}},
IntervalModuloParams{{10, 10}, {6, 9}, {1, 4}},
IntervalModuloParams{{10, 10}, {9, 11}, {0, 10}},
IntervalModuloParams{{10, 10}, {3, 11}, {0, 10}},
IntervalModuloParams{{10, 10}, {3, 10}, {0, 9}},
IntervalModuloParams{{10, 10}, {7, 8}, {2, 3}},
IntervalModuloParams{{100, 100}, {2, 20}, {0, 19}},
// can be estimated accurately as only two results are possible
IntervalModuloParams{{100, 100}, {15, 16}, {4, 10}},
// cannot be estimated accurately as there are three possible results {10, 4, 15};
// picking min/max would require evaluating all possibilities
IntervalModuloParams{{100, 100}, {15, 17}, {0, 16}});
const auto v_and_m_intervals = testing::Values(IntervalModuloParams{{1, 10}, {2, 9}, {0, 8}},
IntervalModuloParams{{1, 10}, {6, 9}, {0, 8}},
IntervalModuloParams{{1, 10}, {2, 12}, {0, 10}},
IntervalModuloParams{{1, 10}, {6, 12}, {0, 10}},
IntervalModuloParams{{1, 10}, {11, 12}, {1, 10}},
IntervalModuloParams{{1, 10}, {11, 15}, {1, 10}},
IntervalModuloParams{{4, 10}, {10, 13}, {0, 10}},
IntervalModuloParams{{10, 20}, {3, 5}, {0, 4}},
IntervalModuloParams{{10, 10}, {3, 10}, {0, 9}},
IntervalModuloParams{{5, 20}, {5, 10}, {0, 9}},
IntervalModuloParams{{10, 100}, {3, 20}, {0, 19}},
IntervalModuloParams{{10, 100}, {2, 20}, {0, 19}},
IntervalModuloParams{{10, 100}, {51, 60}, {0, 59}});
// If the value is unbounded or `m` may be 0, the output is undefined.
const auto v_and_m_special_values = testing::Values(IntervalModuloParams{{0, -1}, {5, 5}, {0, -1}},
IntervalModuloParams{{10, -1}, {4, 4}, {0, -1}},
// Evaluating lower/upper alone returns [0, max],
// but evaluating both bounds returns [0] as `m` has the same bounds
IntervalModuloParams{{11, 11}, {0, 0}, {0, 0}},
IntervalModuloParams{{11, 11}, {0, 5}, {0, -1}},
IntervalModuloParams{{11, 20}, {0, 5}, {0, -1}},
IntervalModuloParams{{11, 20}, {0, -1}, {0, -1}},
IntervalModuloParams{{0, -1}, {0, -1}, {0, -1}});
INSTANTIATE_TEST_SUITE_P(v_and_m_static, SingleDimModV1Test, v_and_m_static);
INSTANTIATE_TEST_SUITE_P(value_interval_m_static, SingleDimModV1Test, v_interval_m_static);
INSTANTIATE_TEST_SUITE_P(value_static_m_interval, SingleDimModV1Test, v_static_m_interval);
INSTANTIATE_TEST_SUITE_P(value_and_m_as_intervals, SingleDimModV1Test, v_and_m_intervals);
INSTANTIATE_TEST_SUITE_P(value_and_m_special_values, SingleDimModV1Test, v_and_m_special_values);
TEST_P(SingleDimModV1Test, preserve_value_on_inputs_i64) {
constexpr auto et = ov::element::i64;
const auto a = std::make_shared<Parameter>(et, ov::PartialShape{a_dim});
const auto b = std::make_shared<Parameter>(et, ov::PartialShape{b_dim});
const auto op = make_op(std::make_shared<ShapeOf>(a), std::make_shared<ShapeOf>(b));
const auto param = std::make_shared<Parameter>(et, ov::Shape{1});
const auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({exp_dim}));
}
TEST_P(SingleDimModV1Test, preserve_value_on_inputs_i32) {
constexpr auto et = ov::element::i32;
const auto a = std::make_shared<Parameter>(et, ov::PartialShape{a_dim});
const auto b = std::make_shared<Parameter>(et, ov::PartialShape{b_dim});
const auto op = make_op(std::make_shared<ShapeOf>(a, et), std::make_shared<ShapeOf>(b, et));
const auto param = std::make_shared<Parameter>(et, ov::Shape{1});
const auto bc = std::make_shared<Broadcast>(param, op, ov::op::BroadcastType::BIDIRECTIONAL);
const auto& output_shape = bc->get_output_partial_shape(0);
EXPECT_EQ(output_shape, ov::PartialShape({exp_dim}));
}