[shape infer]BroadcastV3 and BroadcastV1 shape inference (#8976)
* Implement broadcastv3 shape infer * Implement BroadcastV1 shape infer * Use shape_inference in test case * Fix myriadx test case failure * Apply review comments * Change file name * Apply review comments * Apply review comments * Change broadcast bidirection logic to align with master change
This commit is contained in:
parent
dce2aa2c0e
commit
8b93880b37
@ -50,6 +50,10 @@ public:
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
const BroadcastModeSpec& get_broadcast_spec() const {
|
||||
return m_mode;
|
||||
}
|
||||
|
||||
protected:
|
||||
BroadcastModeSpec m_mode;
|
||||
|
||||
|
301
src/core/shape_inference/include/broadcast_shape_inference.hpp
Normal file
301
src/core/shape_inference/include/broadcast_shape_inference.hpp
Normal file
@ -0,0 +1,301 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/util/broadcast_base.hpp>
|
||||
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "openvino/core/axis_vector.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
|
||||
template <typename T>
|
||||
void validate_target_shape_none(const ov::Node* op,
|
||||
const T& arg_shape,
|
||||
const AxisVector& axes_mapping_val,
|
||||
const T& target_shape) {
|
||||
if (arg_shape.rank().is_static() && target_shape.rank().is_static()) {
|
||||
const auto target_rank_length = target_shape.size();
|
||||
// axes_mapping needs to be in sorted order
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
std::is_sorted(axes_mapping_val.begin(), axes_mapping_val.end()),
|
||||
"Broadcast doesn't permit transposes. axes_mapping ",
|
||||
axes_mapping_val,
|
||||
" not in sorted order");
|
||||
|
||||
if (arg_shape.size() == 0 && axes_mapping_val.size() > 0) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
target_shape[axes_mapping_val[0]].compatible(1),
|
||||
"Broadcast target[axes_mapping[0]]. Expected 1. Got ",
|
||||
target_shape[axes_mapping_val[0]]);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < axes_mapping_val.size(); i++) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
axes_mapping_val[i] < target_rank_length,
|
||||
"Broadcast axes_mapping[",
|
||||
i,
|
||||
"]: ",
|
||||
axes_mapping_val[i],
|
||||
" exceeds target rank ",
|
||||
target_rank_length);
|
||||
|
||||
if (arg_shape.size() > 0) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
target_shape[axes_mapping_val[i]].compatible(arg_shape[i]) || arg_shape[i].compatible(1),
|
||||
"Broadcast target[axes_mapping[",
|
||||
i,
|
||||
"]]",
|
||||
" Expected ",
|
||||
arg_shape[i],
|
||||
". Got ",
|
||||
target_shape[axes_mapping_val[i]]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void validate_target_shape_numpy(const ov::Node* op, const T& arg_shape, const T& target_shape) {
|
||||
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||
return;
|
||||
}
|
||||
const auto arg_rank_length = arg_shape.size();
|
||||
const auto target_rank_length = target_shape.size();
|
||||
const int64_t start_axis = target_rank_length - arg_rank_length;
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
start_axis >= 0,
|
||||
"Broadcast target_shape has smaller rank ",
|
||||
target_rank_length,
|
||||
" than arg shape ",
|
||||
arg_rank_length);
|
||||
for (auto i = start_axis; i < target_rank_length; i++) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
arg_shape[i - start_axis].is_dynamic() || target_shape[i].is_dynamic() ||
|
||||
arg_shape[i - start_axis].compatible(1) ||
|
||||
arg_shape[i - start_axis].compatible(target_shape[i]),
|
||||
"Input shape dimension equal ",
|
||||
arg_shape[i - start_axis],
|
||||
" cannot be broadcasted (numpy mode) to ",
|
||||
target_shape[i],
|
||||
". Allowed input dimension value would be 1",
|
||||
target_shape[i] != 1 ? " or " : "",
|
||||
target_shape[i] != 1 ? std::to_string(target_shape[i].get_length()) : "");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_result_shape_pdpd(const ov::Node* op,
|
||||
const T& arg0_shape,
|
||||
const T& target_shape,
|
||||
T& result_shape,
|
||||
const ov::op::BroadcastModeSpec& broadcast_spec) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
if (arg0_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||
result_shape = PartialShape::dynamic(target_shape.rank());
|
||||
return;
|
||||
}
|
||||
result_shape = target_shape;
|
||||
auto& start_axis = broadcast_spec.m_axis;
|
||||
|
||||
NODE_VALIDATION_CHECK(op, start_axis >= 0, "Broadcast start_axis must be greater than 0");
|
||||
|
||||
for (size_t i = start_axis; i < target_shape.size(); i++) {
|
||||
const auto& arg_dim = arg0_shape[i - start_axis];
|
||||
if (arg_dim == 1) {
|
||||
result_shape[i] = target_shape[i];
|
||||
} else if (target_shape[i] == 1) {
|
||||
result_shape[i] = arg_dim;
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(result_shape[i], arg_dim, target_shape[i]),
|
||||
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||
arg_dim,
|
||||
" . Got ",
|
||||
target_shape[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_result_shape_bidirectional(const ov::Node* op, const T& arg_shape, T& target_shape, T& result_shape) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||
result_shape = PartialShape::dynamic();
|
||||
return;
|
||||
}
|
||||
auto arg_shape_vec = arg_shape;
|
||||
|
||||
// Add left padding to shorter target or argument shape
|
||||
const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
|
||||
while (arg_shape_vec.size() < target_padded_rank) {
|
||||
arg_shape_vec.insert(arg_shape_vec.begin(), 1);
|
||||
}
|
||||
while (target_shape.size() < target_padded_rank) {
|
||||
target_shape.insert(target_shape.begin(), 1);
|
||||
}
|
||||
|
||||
result_shape.resize(target_padded_rank);
|
||||
for (size_t i = 0; i < target_shape.size(); ++i) {
|
||||
if (arg_shape_vec[i] == 1) {
|
||||
result_shape[i] = target_shape[i];
|
||||
} else if (target_shape[i] == 1) {
|
||||
result_shape[i] = arg_shape_vec[i];
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(result_shape[i], arg_shape_vec[i], target_shape[i]),
|
||||
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||
arg_shape_vec[i],
|
||||
". Got ",
|
||||
target_shape[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void broadcase_base_shape_infer(
|
||||
const ov::op::util::BroadcastBase* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
|
||||
// shape node should produce a one dimensional shape.
|
||||
auto broadcast_shape_rank = input_shapes[1].rank();
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
broadcast_shape_rank.compatible(1),
|
||||
"Broadcast shape rank must be 1, but has ",
|
||||
broadcast_shape_rank);
|
||||
|
||||
const auto& mode = op->get_broadcast_spec();
|
||||
if (mode.m_type == BroadcastType::NONE) {
|
||||
// axes_mapping node should produce a one dimensional shape.
|
||||
auto axes_shape_rank = input_shapes[2].rank();
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
axes_shape_rank.compatible(1),
|
||||
"Broadcast axes rank must be 1, but has ",
|
||||
axes_shape_rank);
|
||||
}
|
||||
|
||||
auto& result_shape = output_shapes[0];
|
||||
const auto& input_shape = input_shapes[0];
|
||||
const auto& target_shape = input_shapes[1];
|
||||
const bool is_target_shape_known = target_shape.is_static();
|
||||
|
||||
T output_shape;
|
||||
bool output_shape_defined = get_data_as_shape<T>(1, op, output_shape, constant_data);
|
||||
|
||||
if (!output_shape_defined) {
|
||||
if (auto concat = ov::as_type_ptr<ov::opset1::Concat>(op->get_input_node_shared_ptr(1))) {
|
||||
const auto concat_inputs = concat->input_values();
|
||||
if (concat->get_output_partial_shape(0).is_static() && concat->get_shape().size() == 1 &&
|
||||
concat_inputs.size() == shape_size(concat->get_shape())) {
|
||||
for (const auto& concat_input : concat_inputs) {
|
||||
auto source_node_ptr = concat_input.get_node_shared_ptr();
|
||||
if (auto source_const_ptr = ov::as_type_ptr<ov::opset1::Constant>(source_node_ptr)) {
|
||||
output_shape.push_back(source_const_ptr->get_axis_vector_val()[0]);
|
||||
} else {
|
||||
output_shape.push_back(Dimension::dynamic());
|
||||
}
|
||||
}
|
||||
output_shape_defined = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mode.m_type == BroadcastType::NONE) {
|
||||
if (output_shape_defined) {
|
||||
result_shape = output_shape;
|
||||
} else if (is_target_shape_known) {
|
||||
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||
} else {
|
||||
result_shape = PartialShape::dynamic();
|
||||
}
|
||||
// Validate axes_mapping
|
||||
const auto& axes_shape = input_shapes[2];
|
||||
if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) {
|
||||
auto input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size();
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
axes_shape[0].get_length() == input_rank,
|
||||
"Broadcast axes_mapping shape ",
|
||||
axes_shape,
|
||||
" doesn't match rank of input tensor ",
|
||||
input_rank);
|
||||
std::vector<int64_t> axes_mapping_val;
|
||||
if (output_shape_defined && get_data_as_int64<T>(2, op, axes_mapping_val, constant_data)) {
|
||||
AxisVector axes_mapping =
|
||||
AxisVector(std::vector<size_t>(axes_mapping_val.begin(), axes_mapping_val.end()));
|
||||
validate_target_shape_none(op, input_shape, axes_mapping, output_shape);
|
||||
}
|
||||
}
|
||||
} else if (mode.m_type == BroadcastType::NUMPY) {
|
||||
if (output_shape_defined) {
|
||||
result_shape = output_shape;
|
||||
validate_target_shape_numpy(op, input_shape, output_shape);
|
||||
} else if (is_target_shape_known) {
|
||||
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||
} else {
|
||||
result_shape = PartialShape::dynamic();
|
||||
}
|
||||
} else if (mode.m_type == BroadcastType::PDPD) {
|
||||
if (output_shape_defined) {
|
||||
set_result_shape_pdpd(op, input_shape, output_shape, result_shape, mode);
|
||||
} else if (is_target_shape_known) {
|
||||
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||
} else {
|
||||
result_shape = PartialShape::dynamic();
|
||||
}
|
||||
} else if (mode.m_type == BroadcastType::BIDIRECTIONAL) {
|
||||
if (output_shape_defined) {
|
||||
set_result_shape_bidirectional(op, input_shape, output_shape, result_shape);
|
||||
} else if (input_shape.rank().is_static() && is_target_shape_known) {
|
||||
auto output_rank = std::max(input_shape.size(), static_cast<size_t>(target_shape[0].get_length()));
|
||||
result_shape = PartialShape::dynamic(output_rank);
|
||||
} else {
|
||||
result_shape = PartialShape::dynamic();
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace util
|
||||
|
||||
namespace v3 {
|
||||
template <class T>
|
||||
void shape_infer(const ov::op::v3::Broadcast* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1);
|
||||
auto& mode = op->get_broadcast_spec();
|
||||
if (mode.m_type == BroadcastType::NONE) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes.size() == 3,
|
||||
"axes_mapping input should be provided if explicit mode is used");
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes.size() == 2,
|
||||
"axes_mapping input should not be provided for mode other than explicit");
|
||||
}
|
||||
broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data);
|
||||
}
|
||||
} // namespace v3
|
||||
|
||||
namespace v1 {
|
||||
template <class T>
|
||||
void shape_infer(const ov::op::v1::Broadcast* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1 && (input_shapes.size() == 2 || input_shapes.size() == 3));
|
||||
|
||||
broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data);
|
||||
}
|
||||
} // namespace v1
|
||||
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
|
||||
#include <broadcast_shape_inference.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <numeric>
|
||||
|
||||
@ -141,25 +142,39 @@ void op::v3::Broadcast::validate_and_infer_types() {
|
||||
"axes_mapping input should not be provided for mode other than explicit");
|
||||
}
|
||||
|
||||
util::BroadcastBase::validate_and_infer_types();
|
||||
const auto& shape_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_et.is_integral_number(),
|
||||
"Broadcast shape must be an integral number, but is: ",
|
||||
shape_et);
|
||||
if (m_mode.m_type == BroadcastType::NONE) {
|
||||
// axes_mapping node should have integer data type. For now we only allow i64
|
||||
const auto& axes_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et.is_integral_number(),
|
||||
"Broadcast axes must be integral numbers, but are: ",
|
||||
axes_et);
|
||||
}
|
||||
|
||||
auto result_shape = get_output_partial_shape(0);
|
||||
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) {
|
||||
if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static()) {
|
||||
auto arg_shape = get_input_partial_shape(0);
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
|
||||
std::vector<ov::PartialShape> input_shapes;
|
||||
const auto& arg_shape = get_input_partial_shape(0);
|
||||
const auto& target_shape = get_input_partial_shape(1);
|
||||
if (input_values().size() == 2) {
|
||||
input_shapes = {arg_shape, target_shape};
|
||||
} else {
|
||||
const auto& axes_mapping = get_input_partial_shape(2);
|
||||
input_shapes = {arg_shape, target_shape, axes_mapping};
|
||||
}
|
||||
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
PartialShape target_shape;
|
||||
if (evaluate_as_partial_shape(input_value(1), target_shape)) {
|
||||
result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
set_input_is_relevant_to_shape(0); // arg - Result element type
|
||||
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
||||
if (get_input_size() == 3) {
|
||||
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), result_shape);
|
||||
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
@ -253,10 +268,32 @@ void op::v1::Broadcast::validate_and_infer_types() {
|
||||
util::BroadcastBase::m_mode = base_spec;
|
||||
}
|
||||
|
||||
util::BroadcastBase::validate_and_infer_types();
|
||||
const auto& shape_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
shape_et.is_integral_number(),
|
||||
"Broadcast shape must be an integral number, but is: ",
|
||||
shape_et);
|
||||
if (m_mode.m_type == BroadcastType::NONE) {
|
||||
// axes_mapping node should have integer data type. For now we only allow i64
|
||||
const auto& axes_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et.is_integral_number(),
|
||||
"Broadcast axes must be integral numbers, but are: ",
|
||||
axes_et);
|
||||
}
|
||||
|
||||
const auto& arg_shape = get_input_partial_shape(0);
|
||||
const auto& target_shape = get_input_partial_shape(1);
|
||||
const auto& axes_mapping = get_input_partial_shape(2);
|
||||
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
|
||||
std::vector<ov::PartialShape> input_shapes = {arg_shape, target_shape, axes_mapping};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
set_input_is_relevant_to_shape(0); // arg - Result element type
|
||||
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
||||
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
||||
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
|
@ -302,10 +302,15 @@ TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_i
|
||||
|
||||
// const axes mapping
|
||||
const auto axes_mapping_const = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
|
||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
|
||||
try {
|
||||
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||
FAIL() << "Broadcast: Broadcast axes_mapping shape doesn't match rank of input tensor";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Broadcast axes_mapping shape {4} doesn't match rank of input tensor 3"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) {
|
||||
|
@ -59,6 +59,7 @@
|
||||
#include "detection_output_shape_inference.hpp"
|
||||
#include "select_shape_inference.hpp"
|
||||
#include "shuffle_channels_shape_inference.hpp"
|
||||
#include "broadcast_shape_inference.hpp"
|
||||
#include "static_shape.hpp"
|
||||
#include "tile_shape_inference.hpp"
|
||||
#include "utils.hpp"
|
||||
@ -229,6 +230,10 @@ void shape_inference(ov::Node* op,
|
||||
shape_infer(node, input_shapes, output_shapes);
|
||||
} else if (auto node = ov::as_type<ov::opset1::ShuffleChannels>(op)) {
|
||||
shape_infer(node, input_shapes, output_shapes);
|
||||
} else if (auto node = ov::as_type<ov::opset4::Broadcast>(op)) {
|
||||
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||
} else if (auto node = ov::as_type<ov::opset1::Broadcast>(op)) {
|
||||
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||
} else {
|
||||
ngraph::OutputVector new_inputs;
|
||||
for (size_t i = 0; i < op->get_input_size(); ++i) {
|
||||
|
@ -0,0 +1,217 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <broadcast_shape_inference.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastBidirectionalTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||
|
||||
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastBidirectionalConstantTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{16, 1, 40});
|
||||
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 16, 50, 1}, StaticShape{3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 40}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastPDPDTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v3 =
|
||||
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
|
||||
|
||||
int32_t target_shape_val[] = {2, 3, 6};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||
|
||||
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastPDPDConstantTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{2, 3, 6});
|
||||
auto broadcast_v3 =
|
||||
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastNumpyTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
|
||||
|
||||
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||
|
||||
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastNumpyConstantTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto target_shape =
|
||||
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
|
||||
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastExplicitTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v3 =
|
||||
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
|
||||
|
||||
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||
int32_t axes_mapping_val[] = {1};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||
constant_data[2] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{1}, axes_mapping_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
|
||||
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||
|
||||
constant_data.erase(1);
|
||||
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data),
|
||||
NodeValidationFailure);
|
||||
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastExplicitConstantTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||
auto target_shape =
|
||||
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
|
||||
auto axes_mapping = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{1}, std::vector<int32_t>{1});
|
||||
auto broadcast_v3 =
|
||||
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
|
||||
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||
}
|
||||
|
||||
// BroadcastV1 test
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastV1PDPDTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v1 =
|
||||
std::make_shared<op::v1::Broadcast>(input, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
|
||||
|
||||
int32_t target_shape_val[] = {2, 3, 6};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||
|
||||
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastV1NumpyTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape);
|
||||
|
||||
int32_t target_shape_val[] = {2, 3, 6};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||
|
||||
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, BroadcastV1ExplicitTest) {
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape, axes_mapping);
|
||||
|
||||
int32_t target_shape_val[] = {2, 3, 1};
|
||||
int32_t axes_mapping_val[] = {1, 2};
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||
constant_data[1] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||
constant_data[2] =
|
||||
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{2}, axes_mapping_val);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 1}));
|
||||
|
||||
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}};
|
||||
static_output_shapes = {StaticShape{}};
|
||||
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||
}
|
Loading…
Reference in New Issue
Block a user