[shape infer] BroadcastV3 and BroadcastV1 shape inference (#8976)

* Implement BroadcastV3 shape infer

* Implement BroadcastV1 shape infer

* Use shape_inference in test case

* Fix MyriadX test case failure

* Apply review comments

* Change file name

* Apply review comments

* Apply review comments

* Change broadcast bidirectional logic to align with master change
Mang Guo 2022-01-12 05:33:33 +08:00 committed by GitHub
parent dce2aa2c0e
commit 8b93880b37
6 changed files with 587 additions and 18 deletions

View File

@@ -50,6 +50,10 @@ public:
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
const BroadcastModeSpec& get_broadcast_spec() const {
return m_mode;
}
protected:
BroadcastModeSpec m_mode;

View File

@@ -0,0 +1,301 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/validation_util.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/util/broadcast_base.hpp>
#include "ngraph/op/concat.hpp"
#include "openvino/core/axis_vector.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
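// Validates the target shape for NONE (explicit) broadcast mode: axes_mapping
// must be in sorted order (no transposes), each mapped axis must lie within
// the target rank, and every input dimension must be compatible with (or
// broadcastable as 1 into) the target dimension it maps to.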
template <typename T>
void validate_target_shape_none(const ov::Node* op,
const T& arg_shape,
const AxisVector& axes_mapping_val,
const T& target_shape) {
if (arg_shape.rank().is_static() && target_shape.rank().is_static()) {
const auto target_rank_length = target_shape.size();
// axes_mapping needs to be in sorted order
NODE_VALIDATION_CHECK(op,
std::is_sorted(axes_mapping_val.begin(), axes_mapping_val.end()),
"Broadcast doesn't permit transposes. axes_mapping ",
axes_mapping_val,
" not in sorted order");
if (arg_shape.size() == 0 && axes_mapping_val.size() > 0) {
NODE_VALIDATION_CHECK(op,
target_shape[axes_mapping_val[0]].compatible(1),
"Broadcast target[axes_mapping[0]]. Expected 1. Got ",
target_shape[axes_mapping_val[0]]);
}
for (size_t i = 0; i < axes_mapping_val.size(); i++) {
NODE_VALIDATION_CHECK(op,
axes_mapping_val[i] < target_rank_length,
"Broadcast axes_mapping[",
i,
"]: ",
axes_mapping_val[i],
" exceeds target rank ",
target_rank_length);
if (arg_shape.size() > 0) {
NODE_VALIDATION_CHECK(
op,
target_shape[axes_mapping_val[i]].compatible(arg_shape[i]) || arg_shape[i].compatible(1),
"Broadcast target[axes_mapping[",
i,
"]]",
" Expected ",
arg_shape[i],
". Got ",
target_shape[axes_mapping_val[i]]);
}
}
}
}
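// Validates NUMPY-style broadcasting: the input shape is right-aligned against
// the target shape, so e.g. input {3, 1} is compatible with target {2, 3, 6}
// because each aligned input dimension is either 1 or equals its target.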
template <typename T>
void validate_target_shape_numpy(const ov::Node* op, const T& arg_shape, const T& target_shape) {
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
return;
}
const auto arg_rank_length = arg_shape.size();
const auto target_rank_length = target_shape.size();
const int64_t start_axis = target_rank_length - arg_rank_length;
NODE_VALIDATION_CHECK(op,
start_axis >= 0,
"Broadcast target_shape has smaller rank ",
target_rank_length,
" than arg shape ",
arg_rank_length);
for (auto i = start_axis; i < target_rank_length; i++) {
NODE_VALIDATION_CHECK(op,
arg_shape[i - start_axis].is_dynamic() || target_shape[i].is_dynamic() ||
arg_shape[i - start_axis].compatible(1) ||
arg_shape[i - start_axis].compatible(target_shape[i]),
"Input shape dimension equal ",
arg_shape[i - start_axis],
" cannot be broadcasted (numpy mode) to ",
target_shape[i],
". Allowed input dimension value would be 1",
target_shape[i] != 1 ? " or " : "",
target_shape[i] != 1 ? std::to_string(target_shape[i].get_length()) : "");
}
}
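// Computes the result shape for PDPD mode: the input shape is aligned with the
// target shape starting at broadcast_spec.m_axis; leading target dimensions
// are kept as-is, and each overlapping pair of dimensions is merged (either
// side may be 1, otherwise they must match).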
template <typename T>
void set_result_shape_pdpd(const ov::Node* op,
const T& arg0_shape,
const T& target_shape,
T& result_shape,
const ov::op::BroadcastModeSpec& broadcast_spec) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
if (arg0_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
result_shape = PartialShape::dynamic(target_shape.rank());
return;
}
result_shape = target_shape;
auto& start_axis = broadcast_spec.m_axis;
NODE_VALIDATION_CHECK(op, start_axis >= 0, "Broadcast start_axis must be greater than or equal to 0");
for (size_t i = start_axis; i < target_shape.size(); i++) {
const auto& arg_dim = arg0_shape[i - start_axis];
if (arg_dim == 1) {
result_shape[i] = target_shape[i];
} else if (target_shape[i] == 1) {
result_shape[i] = arg_dim;
} else {
NODE_VALIDATION_CHECK(op,
DimType::merge(result_shape[i], arg_dim, target_shape[i]),
"Broadcast incorrect target shape. Expecting either 1 or ",
arg_dim,
" . Got ",
target_shape[i]);
}
}
}
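// Computes the result shape for BIDIRECTIONAL mode: both shapes are
// left-padded with 1s to a common rank, then merged per dimension, so
// broadcasting can expand the input as well as the target shape, e.g.
// input {16, 1, 1} with target {1, 16, 50, 50} yields {1, 16, 50, 50}.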
template <typename T>
void set_result_shape_bidirectional(const ov::Node* op, const T& arg_shape, T& target_shape, T& result_shape) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
result_shape = PartialShape::dynamic();
return;
}
auto arg_shape_vec = arg_shape;
// Add left padding to shorter target or argument shape
const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
while (arg_shape_vec.size() < target_padded_rank) {
arg_shape_vec.insert(arg_shape_vec.begin(), 1);
}
while (target_shape.size() < target_padded_rank) {
target_shape.insert(target_shape.begin(), 1);
}
result_shape.resize(target_padded_rank);
for (size_t i = 0; i < target_shape.size(); ++i) {
if (arg_shape_vec[i] == 1) {
result_shape[i] = target_shape[i];
} else if (target_shape[i] == 1) {
result_shape[i] = arg_shape_vec[i];
} else {
NODE_VALIDATION_CHECK(op,
DimType::merge(result_shape[i], arg_shape_vec[i], target_shape[i]),
"Broadcast incorrect target shape. Expecting either 1 or ",
arg_shape_vec[i],
". Got ",
target_shape[i]);
}
}
}
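// Common shape inference for all Broadcast variants. Validates the rank of
// the target_shape input (and of axes_mapping in NONE mode), tries to read
// the target shape from constant data or from a constant-fed Concat producer,
// and dispatches to the mode-specific result shape computation above.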
template <class T>
void broadcast_base_shape_infer(
const ov::op::util::BroadcastBase* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
// shape node should produce a one dimensional shape.
auto broadcast_shape_rank = input_shapes[1].rank();
NODE_VALIDATION_CHECK(op,
broadcast_shape_rank.compatible(1),
"Broadcast shape rank must be 1, but has ",
broadcast_shape_rank);
const auto& mode = op->get_broadcast_spec();
if (mode.m_type == BroadcastType::NONE) {
// axes_mapping node should produce a one dimensional shape.
auto axes_shape_rank = input_shapes[2].rank();
NODE_VALIDATION_CHECK(op,
axes_shape_rank.compatible(1),
"Broadcast axes rank must be 1, but has ",
axes_shape_rank);
}
auto& result_shape = output_shapes[0];
const auto& input_shape = input_shapes[0];
const auto& target_shape = input_shapes[1];
const bool is_target_shape_known = target_shape.is_static();
T output_shape;
bool output_shape_defined = get_data_as_shape<T>(1, op, output_shape, constant_data);
if (!output_shape_defined) {
if (auto concat = ov::as_type_ptr<ov::opset1::Concat>(op->get_input_node_shared_ptr(1))) {
const auto concat_inputs = concat->input_values();
if (concat->get_output_partial_shape(0).is_static() && concat->get_shape().size() == 1 &&
concat_inputs.size() == shape_size(concat->get_shape())) {
for (const auto& concat_input : concat_inputs) {
auto source_node_ptr = concat_input.get_node_shared_ptr();
if (auto source_const_ptr = ov::as_type_ptr<ov::opset1::Constant>(source_node_ptr)) {
output_shape.push_back(source_const_ptr->get_axis_vector_val()[0]);
} else {
output_shape.push_back(Dimension::dynamic());
}
}
output_shape_defined = true;
}
}
}
if (mode.m_type == BroadcastType::NONE) {
if (output_shape_defined) {
result_shape = output_shape;
} else if (is_target_shape_known) {
result_shape = PartialShape::dynamic(target_shape[0].get_length());
} else {
result_shape = PartialShape::dynamic();
}
// Validate axes_mapping
const auto& axes_shape = input_shapes[2];
if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) {
auto input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size();
NODE_VALIDATION_CHECK(op,
axes_shape[0].get_length() == input_rank,
"Broadcast axes_mapping shape ",
axes_shape,
" doesn't match rank of input tensor ",
input_rank);
std::vector<int64_t> axes_mapping_val;
if (output_shape_defined && get_data_as_int64<T>(2, op, axes_mapping_val, constant_data)) {
AxisVector axes_mapping =
AxisVector(std::vector<size_t>(axes_mapping_val.begin(), axes_mapping_val.end()));
validate_target_shape_none(op, input_shape, axes_mapping, output_shape);
}
}
} else if (mode.m_type == BroadcastType::NUMPY) {
if (output_shape_defined) {
result_shape = output_shape;
validate_target_shape_numpy(op, input_shape, output_shape);
} else if (is_target_shape_known) {
result_shape = PartialShape::dynamic(target_shape[0].get_length());
} else {
result_shape = PartialShape::dynamic();
}
} else if (mode.m_type == BroadcastType::PDPD) {
if (output_shape_defined) {
set_result_shape_pdpd(op, input_shape, output_shape, result_shape, mode);
} else if (is_target_shape_known) {
result_shape = PartialShape::dynamic(target_shape[0].get_length());
} else {
result_shape = PartialShape::dynamic();
}
} else if (mode.m_type == BroadcastType::BIDIRECTIONAL) {
if (output_shape_defined) {
set_result_shape_bidirectional(op, input_shape, output_shape, result_shape);
} else if (input_shape.rank().is_static() && is_target_shape_known) {
auto output_rank = std::max(input_shape.size(), static_cast<size_t>(target_shape[0].get_length()));
result_shape = PartialShape::dynamic(output_rank);
} else {
result_shape = PartialShape::dynamic();
}
}
}
} // namespace util
namespace v3 {
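// Shape inference entry point for v3 (opset3/opset4) Broadcast: requires the
// axes_mapping input exactly when the mode is EXPLICIT (BroadcastType::NONE),
// then delegates to the common implementation.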
template <class T>
void shape_infer(const ov::op::v3::Broadcast* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1);
auto& mode = op->get_broadcast_spec();
if (mode.m_type == BroadcastType::NONE) {
NODE_VALIDATION_CHECK(op,
input_shapes.size() == 3,
"axes_mapping input should be provided if explicit mode is used");
} else {
NODE_VALIDATION_CHECK(op,
input_shapes.size() == 2,
"axes_mapping input should not be provided for mode other than explicit");
}
broadcast_base_shape_infer(op, input_shapes, output_shapes, constant_data);
}
} // namespace v3
namespace v1 {
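// Shape inference entry point for v1 (opset1) Broadcast, which takes the
// target_shape input and optionally axes_mapping (two or three inputs);
// delegates to the common implementation.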
template <class T>
void shape_infer(const ov::op::v1::Broadcast* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1 && (input_shapes.size() == 2 || input_shapes.size() == 3));
broadcast_base_shape_infer(op, input_shapes, output_shapes, constant_data);
}
} // namespace v1
} // namespace op
} // namespace ov

View File

@@ -4,6 +4,7 @@
#include "ngraph/op/broadcast.hpp"
#include <broadcast_shape_inference.hpp>
#include <ngraph/validation_util.hpp>
#include <numeric>
@@ -141,25 +142,39 @@ void op::v3::Broadcast::validate_and_infer_types() {
"axes_mapping input should not be provided for mode other than explicit");
}
util::BroadcastBase::validate_and_infer_types();
auto result_shape = get_output_partial_shape(0);
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) {
if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static()) {
auto arg_shape = get_input_partial_shape(0);
PartialShape target_shape;
if (evaluate_as_partial_shape(input_value(1), target_shape)) {
result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
}
}
const auto& shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
shape_et.is_integral_number(),
"Broadcast shape must be an integral number, but is: ",
shape_et);
if (m_mode.m_type == BroadcastType::NONE) {
// axes_mapping node should have an integer data type
const auto& axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
axes_et.is_integral_number(),
"Broadcast axes must be integral numbers, but are: ",
axes_et);
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes;
const auto& arg_shape = get_input_partial_shape(0);
const auto& target_shape = get_input_partial_shape(1);
if (input_values().size() == 2) {
input_shapes = {arg_shape, target_shape};
} else {
const auto& axes_mapping = get_input_partial_shape(2);
input_shapes = {arg_shape, target_shape, axes_mapping};
}
shape_infer(this, input_shapes, output_shapes);
set_input_is_relevant_to_shape(0); // arg - Result element type
set_input_is_relevant_to_shape(1); // target_shape - Result shape
if (get_input_size() == 3) {
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
}
set_output_type(0, get_input_element_type(0), result_shape);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}
shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
@@ -253,10 +268,32 @@ void op::v1::Broadcast::validate_and_infer_types() {
util::BroadcastBase::m_mode = base_spec;
}
util::BroadcastBase::validate_and_infer_types();
const auto& shape_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
shape_et.is_integral_number(),
"Broadcast shape must be an integral number, but is: ",
shape_et);
if (m_mode.m_type == BroadcastType::NONE) {
// axes_mapping node should have an integer data type
const auto& axes_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
axes_et.is_integral_number(),
"Broadcast axes must be integral numbers, but are: ",
axes_et);
}
const auto& arg_shape = get_input_partial_shape(0);
const auto& target_shape = get_input_partial_shape(1);
const auto& axes_mapping = get_input_partial_shape(2);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes = {arg_shape, target_shape, axes_mapping};
shape_infer(this, input_shapes, output_shapes);
set_input_is_relevant_to_shape(0); // arg - Result element type
set_input_is_relevant_to_shape(1); // target_shape - Result shape
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}
shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@@ -302,10 +302,15 @@ TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_i
// const axes mapping
const auto axes_mapping_const = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
try {
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
FAIL() << "Broadcast: Broadcast axes_mapping shape doesn't match rank of input tensor";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Broadcast axes_mapping shape {4} doesn't match rank of input tensor 3"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) {

View File

@@ -59,6 +59,7 @@
#include "detection_output_shape_inference.hpp"
#include "select_shape_inference.hpp"
#include "shuffle_channels_shape_inference.hpp"
#include "broadcast_shape_inference.hpp"
#include "static_shape.hpp"
#include "tile_shape_inference.hpp"
#include "utils.hpp"
@@ -229,6 +230,10 @@ void shape_inference(ov::Node* op,
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::ShuffleChannels>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset4::Broadcast>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::Broadcast>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else {
ngraph::OutputVector new_inputs;
for (size_t i = 0; i < op->get_input_size(); ++i) {

View File

@@ -0,0 +1,217 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <broadcast_shape_inference.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
TEST(StaticShapeInferenceTest, BroadcastBidirectionalTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
int32_t target_shape_val[] = {1, 16, 50, 50};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastBidirectionalConstantTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{16, 1, 40});
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 16, 50, 1}, StaticShape{3}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 40}));
}
TEST(StaticShapeInferenceTest, BroadcastPDPDTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v3 =
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
int32_t target_shape_val[] = {2, 3, 6};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastPDPDConstantTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{2, 3, 6});
auto broadcast_v3 =
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
}
TEST(StaticShapeInferenceTest, BroadcastNumpyTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
int32_t target_shape_val[] = {1, 16, 50, 50};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastNumpyConstantTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto target_shape =
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
}
TEST(StaticShapeInferenceTest, BroadcastExplicitTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v3 =
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
int32_t target_shape_val[] = {1, 16, 50, 50};
int32_t axes_mapping_val[] = {1};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
constant_data[2] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{1}, axes_mapping_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
constant_data.erase(1);
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data),
NodeValidationFailure);
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastExplicitConstantTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
auto target_shape =
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
auto axes_mapping = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{1}, std::vector<int32_t>{1});
auto broadcast_v3 =
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
}
// BroadcastV1 test
TEST(StaticShapeInferenceTest, BroadcastV1PDPDTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v1 =
std::make_shared<op::v1::Broadcast>(input, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
int32_t target_shape_val[] = {2, 3, 6};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastV1NumpyTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape);
int32_t target_shape_val[] = {2, 3, 6};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, BroadcastV1ExplicitTest) {
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape, axes_mapping);
int32_t target_shape_val[] = {2, 3, 1};
int32_t axes_mapping_val[] = {1, 2};
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
constant_data[1] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
constant_data[2] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{2}, axes_mapping_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 1}));
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}};
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
}