From 8b93880b37a55623da6865cfc8683b80f767e57d Mon Sep 17 00:00:00 2001 From: Mang Guo Date: Wed, 12 Jan 2022 05:33:33 +0800 Subject: [PATCH] [shape infer]BroadcastV3 and BroadcastV1 shape inference (#8976) * Implement broadcastv3 shape infer * Implement BroadcastV1 shape infer * Use shape_inference in test case * Fix myriadx test case failure * Apply review comments * Change file name * Apply review comments * Apply review comments * Change broadcast bidirection logic to align with master change --- .../openvino/op/util/broadcast_base.hpp | 4 + .../include/broadcast_shape_inference.hpp | 301 ++++++++++++++++++ src/core/src/op/broadcast.cpp | 65 +++- src/core/tests/type_prop/broadcast.cpp | 13 +- .../utils/shape_inference/shape_inference.cpp | 5 + .../broadcast_shape_inference.cpp | 217 +++++++++++++ 6 files changed, 587 insertions(+), 18 deletions(-) create mode 100644 src/core/shape_inference/include/broadcast_shape_inference.hpp create mode 100644 src/tests/unit/cpu/shape_inference_test/broadcast_shape_inference.cpp diff --git a/src/core/include/openvino/op/util/broadcast_base.hpp b/src/core/include/openvino/op/util/broadcast_base.hpp index f6d52e91106..e2e0ba87d90 100644 --- a/src/core/include/openvino/op/util/broadcast_base.hpp +++ b/src/core/include/openvino/op/util/broadcast_base.hpp @@ -50,6 +50,10 @@ public: bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override; OPENVINO_SUPPRESS_DEPRECATED_END + const BroadcastModeSpec& get_broadcast_spec() const { + return m_mode; + } + protected: BroadcastModeSpec m_mode; diff --git a/src/core/shape_inference/include/broadcast_shape_inference.hpp b/src/core/shape_inference/include/broadcast_shape_inference.hpp new file mode 100644 index 00000000000..8ea755cc817 --- /dev/null +++ b/src/core/shape_inference/include/broadcast_shape_inference.hpp @@ -0,0 +1,301 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include + +#include "ngraph/op/concat.hpp" +#include "openvino/core/axis_vector.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { +namespace util { + +template +void validate_target_shape_none(const ov::Node* op, + const T& arg_shape, + const AxisVector& axes_mapping_val, + const T& target_shape) { + if (arg_shape.rank().is_static() && target_shape.rank().is_static()) { + const auto target_rank_length = target_shape.size(); + // axes_mapping needs to be in sorted order + NODE_VALIDATION_CHECK(op, + std::is_sorted(axes_mapping_val.begin(), axes_mapping_val.end()), + "Broadcast doesn't permit transposes. axes_mapping ", + axes_mapping_val, + " not in sorted order"); + + if (arg_shape.size() == 0 && axes_mapping_val.size() > 0) { + NODE_VALIDATION_CHECK(op, + target_shape[axes_mapping_val[0]].compatible(1), + "Broadcast target[axes_mapping[0]]. Expected 1. Got ", + target_shape[axes_mapping_val[0]]); + } + + for (size_t i = 0; i < axes_mapping_val.size(); i++) { + NODE_VALIDATION_CHECK(op, + axes_mapping_val[i] < target_rank_length, + "Broadcast axes_mapping[", + i, + "]: ", + axes_mapping_val[i], + " exceeds target rank ", + target_rank_length); + + if (arg_shape.size() > 0) { + NODE_VALIDATION_CHECK( + op, + target_shape[axes_mapping_val[i]].compatible(arg_shape[i]) || arg_shape[i].compatible(1), + "Broadcast target[axes_mapping[", + i, + "]]", + " Expected ", + arg_shape[i], + ". Got ", + target_shape[axes_mapping_val[i]]); + } + } + } +} + +template +void validate_target_shape_numpy(const ov::Node* op, const T& arg_shape, const T& target_shape) { + if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) { + return; + } + const auto arg_rank_length = arg_shape.size(); + const auto target_rank_length = target_shape.size(); + const int64_t start_axis = target_rank_length - arg_rank_length; + NODE_VALIDATION_CHECK(op, + start_axis >= 0, + "Broadcast target_shape has smaller rank ", + target_rank_length, + " than arg shape ", + arg_rank_length); + for (auto i = start_axis; i < target_rank_length; i++) { + NODE_VALIDATION_CHECK(op, + arg_shape[i - start_axis].is_dynamic() || target_shape[i].is_dynamic() || + arg_shape[i - start_axis].compatible(1) || + arg_shape[i - start_axis].compatible(target_shape[i]), + "Input shape dimension equal ", + arg_shape[i - start_axis], + " cannot be broadcasted (numpy mode) to ", + target_shape[i], + ". Allowed input dimension value would be 1", + target_shape[i] != 1 ? " or " : "", + target_shape[i] != 1 ? std::to_string(target_shape[i].get_length()) : ""); + } +} + +template +void set_result_shape_pdpd(const ov::Node* op, + const T& arg0_shape, + const T& target_shape, + T& result_shape, + const ov::op::BroadcastModeSpec& broadcast_spec) { + using DimType = typename std::iterator_traits::value_type; + if (arg0_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) { + result_shape = PartialShape::dynamic(target_shape.rank()); + return; + } + result_shape = target_shape; + auto& start_axis = broadcast_spec.m_axis; + + NODE_VALIDATION_CHECK(op, start_axis >= 0, "Broadcast start_axis must be greater than 0"); + + for (size_t i = start_axis; i < target_shape.size(); i++) { + const auto& arg_dim = arg0_shape[i - start_axis]; + if (arg_dim == 1) { + result_shape[i] = target_shape[i]; + } else if (target_shape[i] == 1) { + result_shape[i] = arg_dim; + } else { + NODE_VALIDATION_CHECK(op, + DimType::merge(result_shape[i], arg_dim, target_shape[i]), + "Broadcast incorrect target shape. Expecting either 1 or ", + arg_dim, + " . Got ", + target_shape[i]); + } + } +} + +template +void set_result_shape_bidirectional(const ov::Node* op, const T& arg_shape, T& target_shape, T& result_shape) { + using DimType = typename std::iterator_traits::value_type; + if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) { + result_shape = PartialShape::dynamic(); + return; + } + auto arg_shape_vec = arg_shape; + + // Add left padding to shorter target or argument shape + const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size()); + while (arg_shape_vec.size() < target_padded_rank) { + arg_shape_vec.insert(arg_shape_vec.begin(), 1); + } + while (target_shape.size() < target_padded_rank) { + target_shape.insert(target_shape.begin(), 1); + } + + result_shape.resize(target_padded_rank); + for (size_t i = 0; i < target_shape.size(); ++i) { + if (arg_shape_vec[i] == 1) { + result_shape[i] = target_shape[i]; + } else if (target_shape[i] == 1) { + result_shape[i] = arg_shape_vec[i]; + } else { + NODE_VALIDATION_CHECK(op, + DimType::merge(result_shape[i], arg_shape_vec[i], target_shape[i]), + "Broadcast incorrect target shape. Expecting either 1 or ", + arg_shape_vec[i], + ". Got ", + target_shape[i]); + } + } +} + +template +void broadcase_base_shape_infer( + const ov::op::util::BroadcastBase* op, + const std::vector& input_shapes, + std::vector& output_shapes, + const std::map>& constant_data = {}) { + + // shape node should produce a one dimensional shape. + auto broadcast_shape_rank = input_shapes[1].rank(); + NODE_VALIDATION_CHECK(op, + broadcast_shape_rank.compatible(1), + "Broadcast shape rank must be 1, but has ", + broadcast_shape_rank); + + const auto& mode = op->get_broadcast_spec(); + if (mode.m_type == BroadcastType::NONE) { + // axes_mapping node should produce a one dimensional shape. + auto axes_shape_rank = input_shapes[2].rank(); + NODE_VALIDATION_CHECK(op, + axes_shape_rank.compatible(1), + "Broadcast axes rank must be 1, but has ", + axes_shape_rank); + } + + auto& result_shape = output_shapes[0]; + const auto& input_shape = input_shapes[0]; + const auto& target_shape = input_shapes[1]; + const bool is_target_shape_known = target_shape.is_static(); + + T output_shape; + bool output_shape_defined = get_data_as_shape(1, op, output_shape, constant_data); + + if (!output_shape_defined) { + if (auto concat = ov::as_type_ptr(op->get_input_node_shared_ptr(1))) { + const auto concat_inputs = concat->input_values(); + if (concat->get_output_partial_shape(0).is_static() && concat->get_shape().size() == 1 && + concat_inputs.size() == shape_size(concat->get_shape())) { + for (const auto& concat_input : concat_inputs) { + auto source_node_ptr = concat_input.get_node_shared_ptr(); + if (auto source_const_ptr = ov::as_type_ptr(source_node_ptr)) { + output_shape.push_back(source_const_ptr->get_axis_vector_val()[0]); + } else { + output_shape.push_back(Dimension::dynamic()); + } + } + output_shape_defined = true; + } + } + } + + if (mode.m_type == BroadcastType::NONE) { + if (output_shape_defined) { + result_shape = output_shape; + } else if (is_target_shape_known) { + result_shape = PartialShape::dynamic(target_shape[0].get_length()); + } else { + result_shape = PartialShape::dynamic(); + } + // Validate axes_mapping + const auto& axes_shape = input_shapes[2]; + if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) { + auto input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size(); + NODE_VALIDATION_CHECK(op, + axes_shape[0].get_length() == input_rank, + "Broadcast axes_mapping shape ", + axes_shape, + " doesn't match rank of input tensor ", + input_rank); + std::vector axes_mapping_val; + if (output_shape_defined && get_data_as_int64(2, op, axes_mapping_val, constant_data)) { + AxisVector axes_mapping = + AxisVector(std::vector(axes_mapping_val.begin(), axes_mapping_val.end())); + validate_target_shape_none(op, input_shape, axes_mapping, output_shape); + } + } + } else if (mode.m_type == BroadcastType::NUMPY) { + if (output_shape_defined) { + result_shape = output_shape; + validate_target_shape_numpy(op, input_shape, output_shape); + } else if (is_target_shape_known) { + result_shape = PartialShape::dynamic(target_shape[0].get_length()); + } else { + result_shape = PartialShape::dynamic(); + } + } else if (mode.m_type == BroadcastType::PDPD) { + if (output_shape_defined) { + set_result_shape_pdpd(op, input_shape, output_shape, result_shape, mode); + } else if (is_target_shape_known) { + result_shape = PartialShape::dynamic(target_shape[0].get_length()); + } else { + result_shape = PartialShape::dynamic(); + } + } else if (mode.m_type == BroadcastType::BIDIRECTIONAL) { + if (output_shape_defined) { + set_result_shape_bidirectional(op, input_shape, output_shape, result_shape); + } else if (input_shape.rank().is_static() && is_target_shape_known) { + auto output_rank = std::max(input_shape.size(), static_cast(target_shape[0].get_length())); + result_shape = PartialShape::dynamic(output_rank); + } else { + result_shape = PartialShape::dynamic(); + } + } +} +} // namespace util + +namespace v3 { +template +void shape_infer(const ov::op::v3::Broadcast* op, + const std::vector& input_shapes, + std::vector& output_shapes, + const std::map>& constant_data = {}) { + NODE_VALIDATION_CHECK(op, output_shapes.size() == 1); + auto& mode = op->get_broadcast_spec(); + if (mode.m_type == BroadcastType::NONE) { + NODE_VALIDATION_CHECK(op, + input_shapes.size() == 3, + "axes_mapping input should be provided if explicit mode is used"); + } else { + NODE_VALIDATION_CHECK(op, + input_shapes.size() == 2, + "axes_mapping input should not be provided for mode other than explicit"); + } + broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data); +} +} // namespace v3 + +namespace v1 { +template +void shape_infer(const ov::op::v1::Broadcast* op, + const std::vector& input_shapes, + std::vector& output_shapes, + const std::map>& constant_data = {}) { + NODE_VALIDATION_CHECK(op, output_shapes.size() == 1 && (input_shapes.size() == 2 || input_shapes.size() == 3)); + + broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data); +} +} // namespace v1 + +} // namespace op +} // namespace ov diff --git a/src/core/src/op/broadcast.cpp b/src/core/src/op/broadcast.cpp index 9a8d5a6fdff..3e7cc7e8fac 100644 --- a/src/core/src/op/broadcast.cpp +++ b/src/core/src/op/broadcast.cpp @@ -4,6 +4,7 @@ #include "ngraph/op/broadcast.hpp" +#include #include #include @@ -141,25 +142,39 @@ void op::v3::Broadcast::validate_and_infer_types() { "axes_mapping input should not be provided for mode other than explicit"); } - util::BroadcastBase::validate_and_infer_types(); - - auto result_shape = get_output_partial_shape(0); - if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) { - if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static()) { - auto arg_shape = get_input_partial_shape(0); - - PartialShape target_shape; - if (evaluate_as_partial_shape(input_value(1), target_shape)) { - result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape); - } - } + const auto& shape_et = get_input_element_type(1); + NODE_VALIDATION_CHECK(this, + shape_et.is_integral_number(), + "Broadcast shape must be an integral number, but is: ", + shape_et); + if (m_mode.m_type == BroadcastType::NONE) { + // axes_mapping node should have integer data type. For now we only allow i64 + const auto& axes_et = get_input_element_type(2); + NODE_VALIDATION_CHECK(this, + axes_et.is_integral_number(), + "Broadcast axes must be integral numbers, but are: ", + axes_et); } + + std::vector output_shapes = {ov::PartialShape()}; + std::vector input_shapes; + const auto& arg_shape = get_input_partial_shape(0); + const auto& target_shape = get_input_partial_shape(1); + if (input_values().size() == 2) { + input_shapes = {arg_shape, target_shape}; + } else { + const auto& axes_mapping = get_input_partial_shape(2); + input_shapes = {arg_shape, target_shape, axes_mapping}; + } + + shape_infer(this, input_shapes, output_shapes); + set_input_is_relevant_to_shape(0); // arg - Result element type set_input_is_relevant_to_shape(1); // target_shape - Result shape if (get_input_size() == 3) { set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type } - set_output_type(0, get_input_element_type(0), result_shape); + set_output_type(0, get_input_element_type(0), output_shapes[0]); } shared_ptr op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const { @@ -253,10 +268,32 @@ void op::v1::Broadcast::validate_and_infer_types() { util::BroadcastBase::m_mode = base_spec; } - util::BroadcastBase::validate_and_infer_types(); + const auto& shape_et = get_input_element_type(1); + NODE_VALIDATION_CHECK(this, + shape_et.is_integral_number(), + "Broadcast shape must be an integral number, but is: ", + shape_et); + if (m_mode.m_type == BroadcastType::NONE) { + // axes_mapping node should have integer data type. For now we only allow i64 + const auto& axes_et = get_input_element_type(2); + NODE_VALIDATION_CHECK(this, + axes_et.is_integral_number(), + "Broadcast axes must be integral numbers, but are: ", + axes_et); + } + + const auto& arg_shape = get_input_partial_shape(0); + const auto& target_shape = get_input_partial_shape(1); + const auto& axes_mapping = get_input_partial_shape(2); + + std::vector output_shapes = {ov::PartialShape()}; + std::vector input_shapes = {arg_shape, target_shape, axes_mapping}; + shape_infer(this, input_shapes, output_shapes); + set_input_is_relevant_to_shape(0); // arg - Result element type set_input_is_relevant_to_shape(1); // target_shape - Result shape set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type + set_output_type(0, get_input_element_type(0), output_shapes[0]); } shared_ptr op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const { diff --git a/src/core/tests/type_prop/broadcast.cpp b/src/core/tests/type_prop/broadcast.cpp index 51f2d855ef1..23eb9006972 100644 --- a/src/core/tests/type_prop/broadcast.cpp +++ b/src/core/tests/type_prop/broadcast.cpp @@ -302,10 +302,15 @@ TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_i // const axes mapping const auto axes_mapping_const = op::Constant::create(element::i64, Shape{4}, vector{0, 2, 1, 3}); - bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); - ASSERT_TRUE(bc->get_output_partial_shape(0).is_static()); - ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4); - ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10})); + try { + auto bc = make_shared(data, target_shape, axes_mapping_const, "EXPLICIT"); + FAIL() << "Broadcast: Broadcast axes_mapping shape doesn't match rank of input tensor"; + } catch (const NodeValidationFailure& error) { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Broadcast axes_mapping shape {4} doesn't match rank of input tensor 3")); + } catch (...) { + FAIL() << "Deduced type check failed for unexpected reason"; + } } TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) { diff --git a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp index 273b8265e3e..2595fd39961 100644 --- a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp @@ -59,6 +59,7 @@ #include "detection_output_shape_inference.hpp" #include "select_shape_inference.hpp" #include "shuffle_channels_shape_inference.hpp" +#include "broadcast_shape_inference.hpp" #include "static_shape.hpp" #include "tile_shape_inference.hpp" #include "utils.hpp" @@ -229,6 +230,10 @@ void shape_inference(ov::Node* op, shape_infer(node, input_shapes, output_shapes); } else if (auto node = ov::as_type(op)) { shape_infer(node, input_shapes, output_shapes); + } else if (auto node = ov::as_type(op)) { + shape_infer(node, input_shapes, output_shapes, constant_data); + } else if (auto node = ov::as_type(op)) { + shape_infer(node, input_shapes, output_shapes, constant_data); } else { ngraph::OutputVector new_inputs; for (size_t i = 0; i < op->get_input_size(); ++i) { diff --git a/src/tests/unit/cpu/shape_inference_test/broadcast_shape_inference.cpp b/src/tests/unit/cpu/shape_inference_test/broadcast_shape_inference.cpp new file mode 100644 index 00000000000..014cb3897e5 --- /dev/null +++ b/src/tests/unit/cpu/shape_inference_test/broadcast_shape_inference.cpp @@ -0,0 +1,217 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +using namespace ov; + +TEST(StaticShapeInferenceTest, BroadcastBidirectionalTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v3 = std::make_shared(input, target_shape, op::BroadcastType::BIDIRECTIONAL); + + int32_t target_shape_val[] = {1, 16, 50, 50}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val); + + std::vector static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50})); + + static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastBidirectionalConstantTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1, -1, -1}); + auto target_shape = std::make_shared(element::i32, ov::Shape{3}, std::vector{16, 1, 40}); + auto broadcast_v3 = std::make_shared(input, target_shape, op::BroadcastType::BIDIRECTIONAL); + + std::vector static_input_shapes = {StaticShape{1, 16, 50, 1}, StaticShape{3}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 40})); +} + +TEST(StaticShapeInferenceTest, BroadcastPDPDTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v3 = + std::make_shared(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1)); + + int32_t target_shape_val[] = {2, 3, 6}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val); + + std::vector static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6})); + + static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastPDPDConstantTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1}); + auto target_shape = std::make_shared(element::i32, ov::Shape{3}, std::vector{2, 3, 6}); + auto broadcast_v3 = + std::make_shared(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1)); + + std::vector static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}); + ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6})); +} + +TEST(StaticShapeInferenceTest, BroadcastNumpyTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v3 = std::make_shared(input, target_shape, op::BroadcastType::NUMPY); + + int32_t target_shape_val[] = {1, 16, 50, 50}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val); + + std::vector static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50})); + + static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastNumpyConstantTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1, -1}); + auto target_shape = + std::make_shared(element::i32, ov::Shape{4}, std::vector{1, 16, 50, 50}); + auto broadcast_v3 = std::make_shared(input, target_shape, op::BroadcastType::NUMPY); + + std::vector static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50})); +} + +TEST(StaticShapeInferenceTest, BroadcastExplicitTest) { + auto input = std::make_shared(element::f32, PartialShape{-1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto axes_mapping = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v3 = + std::make_shared(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT); + + int32_t target_shape_val[] = {1, 16, 50, 50}; + int32_t axes_mapping_val[] = {1}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val); + constant_data[2] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{1}, axes_mapping_val); + + std::vector static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}}; + std::vector static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50})); + + constant_data.erase(1); + EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data), + NodeValidationFailure); + EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastExplicitConstantTest) { + auto input = std::make_shared(element::f32, PartialShape{-1}); + auto target_shape = + std::make_shared(element::i32, ov::Shape{4}, std::vector{1, 16, 50, 50}); + auto axes_mapping = std::make_shared(element::i32, ov::Shape{1}, std::vector{1}); + auto broadcast_v3 = + std::make_shared(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT); + + std::vector static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}}; + std::vector static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}); + ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50})); +} + +// BroadcastV1 test + +TEST(StaticShapeInferenceTest, BroadcastV1PDPDTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v1 = + std::make_shared(input, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1)); + + int32_t target_shape_val[] = {2, 3, 6}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val); + + std::vector static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6})); + + static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastV1NumpyTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v1 = std::make_shared(input, target_shape); + + int32_t target_shape_val[] = {2, 3, 6}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val); + + std::vector static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6})); + + static_input_shapes = {StaticShape{3, 1}, StaticShape{3}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} + +TEST(StaticShapeInferenceTest, BroadcastV1ExplicitTest) { + auto input = std::make_shared(element::f32, PartialShape{-1, -1}); + auto target_shape = std::make_shared(element::i32, PartialShape{-1}); + auto axes_mapping = std::make_shared(element::i32, PartialShape{-1}); + auto broadcast_v1 = std::make_shared(input, target_shape, axes_mapping); + + int32_t target_shape_val[] = {2, 3, 1}; + int32_t axes_mapping_val[] = {1, 2}; + std::map> constant_data; + constant_data[1] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val); + constant_data[2] = + std::make_shared(ngraph::element::Type_t::i32, ov::Shape{2}, axes_mapping_val); + + std::vector static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}}, + static_output_shapes = {StaticShape{}}; + shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data); + ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 1})); + + static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}}; + static_output_shapes = {StaticShape{}}; + EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure); +} \ No newline at end of file