[shape infer]BroadcastV3 and BroadcastV1 shape inference (#8976)
* Implement broadcastv3 shape infer * Implement BroadcastV1 shape infer * Use shape_inference in test case * Fix myriadx test case failure * Apply review comments * Change file name * Apply review comments * Apply review comments * Change broadcast bidirection logic to align with master change
This commit is contained in:
parent
dce2aa2c0e
commit
8b93880b37
@ -50,6 +50,10 @@ public:
|
|||||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
|
const BroadcastModeSpec& get_broadcast_spec() const {
|
||||||
|
return m_mode;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BroadcastModeSpec m_mode;
|
BroadcastModeSpec m_mode;
|
||||||
|
|
||||||
|
301
src/core/shape_inference/include/broadcast_shape_inference.hpp
Normal file
301
src/core/shape_inference/include/broadcast_shape_inference.hpp
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ngraph/validation_util.hpp>
|
||||||
|
#include <openvino/op/broadcast.hpp>
|
||||||
|
#include <openvino/op/util/broadcast_base.hpp>
|
||||||
|
|
||||||
|
#include "ngraph/op/concat.hpp"
|
||||||
|
#include "openvino/core/axis_vector.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace util {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void validate_target_shape_none(const ov::Node* op,
|
||||||
|
const T& arg_shape,
|
||||||
|
const AxisVector& axes_mapping_val,
|
||||||
|
const T& target_shape) {
|
||||||
|
if (arg_shape.rank().is_static() && target_shape.rank().is_static()) {
|
||||||
|
const auto target_rank_length = target_shape.size();
|
||||||
|
// axes_mapping needs to be in sorted order
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
std::is_sorted(axes_mapping_val.begin(), axes_mapping_val.end()),
|
||||||
|
"Broadcast doesn't permit transposes. axes_mapping ",
|
||||||
|
axes_mapping_val,
|
||||||
|
" not in sorted order");
|
||||||
|
|
||||||
|
if (arg_shape.size() == 0 && axes_mapping_val.size() > 0) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
target_shape[axes_mapping_val[0]].compatible(1),
|
||||||
|
"Broadcast target[axes_mapping[0]]. Expected 1. Got ",
|
||||||
|
target_shape[axes_mapping_val[0]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < axes_mapping_val.size(); i++) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
axes_mapping_val[i] < target_rank_length,
|
||||||
|
"Broadcast axes_mapping[",
|
||||||
|
i,
|
||||||
|
"]: ",
|
||||||
|
axes_mapping_val[i],
|
||||||
|
" exceeds target rank ",
|
||||||
|
target_rank_length);
|
||||||
|
|
||||||
|
if (arg_shape.size() > 0) {
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
op,
|
||||||
|
target_shape[axes_mapping_val[i]].compatible(arg_shape[i]) || arg_shape[i].compatible(1),
|
||||||
|
"Broadcast target[axes_mapping[",
|
||||||
|
i,
|
||||||
|
"]]",
|
||||||
|
" Expected ",
|
||||||
|
arg_shape[i],
|
||||||
|
". Got ",
|
||||||
|
target_shape[axes_mapping_val[i]]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void validate_target_shape_numpy(const ov::Node* op, const T& arg_shape, const T& target_shape) {
|
||||||
|
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto arg_rank_length = arg_shape.size();
|
||||||
|
const auto target_rank_length = target_shape.size();
|
||||||
|
const int64_t start_axis = target_rank_length - arg_rank_length;
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
start_axis >= 0,
|
||||||
|
"Broadcast target_shape has smaller rank ",
|
||||||
|
target_rank_length,
|
||||||
|
" than arg shape ",
|
||||||
|
arg_rank_length);
|
||||||
|
for (auto i = start_axis; i < target_rank_length; i++) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
arg_shape[i - start_axis].is_dynamic() || target_shape[i].is_dynamic() ||
|
||||||
|
arg_shape[i - start_axis].compatible(1) ||
|
||||||
|
arg_shape[i - start_axis].compatible(target_shape[i]),
|
||||||
|
"Input shape dimension equal ",
|
||||||
|
arg_shape[i - start_axis],
|
||||||
|
" cannot be broadcasted (numpy mode) to ",
|
||||||
|
target_shape[i],
|
||||||
|
". Allowed input dimension value would be 1",
|
||||||
|
target_shape[i] != 1 ? " or " : "",
|
||||||
|
target_shape[i] != 1 ? std::to_string(target_shape[i].get_length()) : "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set_result_shape_pdpd(const ov::Node* op,
|
||||||
|
const T& arg0_shape,
|
||||||
|
const T& target_shape,
|
||||||
|
T& result_shape,
|
||||||
|
const ov::op::BroadcastModeSpec& broadcast_spec) {
|
||||||
|
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||||
|
if (arg0_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||||
|
result_shape = PartialShape::dynamic(target_shape.rank());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
result_shape = target_shape;
|
||||||
|
auto& start_axis = broadcast_spec.m_axis;
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(op, start_axis >= 0, "Broadcast start_axis must be greater than 0");
|
||||||
|
|
||||||
|
for (size_t i = start_axis; i < target_shape.size(); i++) {
|
||||||
|
const auto& arg_dim = arg0_shape[i - start_axis];
|
||||||
|
if (arg_dim == 1) {
|
||||||
|
result_shape[i] = target_shape[i];
|
||||||
|
} else if (target_shape[i] == 1) {
|
||||||
|
result_shape[i] = arg_dim;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
DimType::merge(result_shape[i], arg_dim, target_shape[i]),
|
||||||
|
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||||
|
arg_dim,
|
||||||
|
" . Got ",
|
||||||
|
target_shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set_result_shape_bidirectional(const ov::Node* op, const T& arg_shape, T& target_shape, T& result_shape) {
|
||||||
|
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||||
|
if (arg_shape.rank().is_dynamic() || target_shape.rank().is_dynamic()) {
|
||||||
|
result_shape = PartialShape::dynamic();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto arg_shape_vec = arg_shape;
|
||||||
|
|
||||||
|
// Add left padding to shorter target or argument shape
|
||||||
|
const auto target_padded_rank = std::max(arg_shape_vec.size(), target_shape.size());
|
||||||
|
while (arg_shape_vec.size() < target_padded_rank) {
|
||||||
|
arg_shape_vec.insert(arg_shape_vec.begin(), 1);
|
||||||
|
}
|
||||||
|
while (target_shape.size() < target_padded_rank) {
|
||||||
|
target_shape.insert(target_shape.begin(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
result_shape.resize(target_padded_rank);
|
||||||
|
for (size_t i = 0; i < target_shape.size(); ++i) {
|
||||||
|
if (arg_shape_vec[i] == 1) {
|
||||||
|
result_shape[i] = target_shape[i];
|
||||||
|
} else if (target_shape[i] == 1) {
|
||||||
|
result_shape[i] = arg_shape_vec[i];
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
DimType::merge(result_shape[i], arg_shape_vec[i], target_shape[i]),
|
||||||
|
"Broadcast incorrect target shape. Expecting either 1 or ",
|
||||||
|
arg_shape_vec[i],
|
||||||
|
". Got ",
|
||||||
|
target_shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void broadcase_base_shape_infer(
|
||||||
|
const ov::op::util::BroadcastBase* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes,
|
||||||
|
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||||
|
|
||||||
|
// shape node should produce a one dimensional shape.
|
||||||
|
auto broadcast_shape_rank = input_shapes[1].rank();
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
broadcast_shape_rank.compatible(1),
|
||||||
|
"Broadcast shape rank must be 1, but has ",
|
||||||
|
broadcast_shape_rank);
|
||||||
|
|
||||||
|
const auto& mode = op->get_broadcast_spec();
|
||||||
|
if (mode.m_type == BroadcastType::NONE) {
|
||||||
|
// axes_mapping node should produce a one dimensional shape.
|
||||||
|
auto axes_shape_rank = input_shapes[2].rank();
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
axes_shape_rank.compatible(1),
|
||||||
|
"Broadcast axes rank must be 1, but has ",
|
||||||
|
axes_shape_rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& result_shape = output_shapes[0];
|
||||||
|
const auto& input_shape = input_shapes[0];
|
||||||
|
const auto& target_shape = input_shapes[1];
|
||||||
|
const bool is_target_shape_known = target_shape.is_static();
|
||||||
|
|
||||||
|
T output_shape;
|
||||||
|
bool output_shape_defined = get_data_as_shape<T>(1, op, output_shape, constant_data);
|
||||||
|
|
||||||
|
if (!output_shape_defined) {
|
||||||
|
if (auto concat = ov::as_type_ptr<ov::opset1::Concat>(op->get_input_node_shared_ptr(1))) {
|
||||||
|
const auto concat_inputs = concat->input_values();
|
||||||
|
if (concat->get_output_partial_shape(0).is_static() && concat->get_shape().size() == 1 &&
|
||||||
|
concat_inputs.size() == shape_size(concat->get_shape())) {
|
||||||
|
for (const auto& concat_input : concat_inputs) {
|
||||||
|
auto source_node_ptr = concat_input.get_node_shared_ptr();
|
||||||
|
if (auto source_const_ptr = ov::as_type_ptr<ov::opset1::Constant>(source_node_ptr)) {
|
||||||
|
output_shape.push_back(source_const_ptr->get_axis_vector_val()[0]);
|
||||||
|
} else {
|
||||||
|
output_shape.push_back(Dimension::dynamic());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output_shape_defined = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mode.m_type == BroadcastType::NONE) {
|
||||||
|
if (output_shape_defined) {
|
||||||
|
result_shape = output_shape;
|
||||||
|
} else if (is_target_shape_known) {
|
||||||
|
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||||
|
} else {
|
||||||
|
result_shape = PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
// Validate axes_mapping
|
||||||
|
const auto& axes_shape = input_shapes[2];
|
||||||
|
if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) {
|
||||||
|
auto input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size();
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
axes_shape[0].get_length() == input_rank,
|
||||||
|
"Broadcast axes_mapping shape ",
|
||||||
|
axes_shape,
|
||||||
|
" doesn't match rank of input tensor ",
|
||||||
|
input_rank);
|
||||||
|
std::vector<int64_t> axes_mapping_val;
|
||||||
|
if (output_shape_defined && get_data_as_int64<T>(2, op, axes_mapping_val, constant_data)) {
|
||||||
|
AxisVector axes_mapping =
|
||||||
|
AxisVector(std::vector<size_t>(axes_mapping_val.begin(), axes_mapping_val.end()));
|
||||||
|
validate_target_shape_none(op, input_shape, axes_mapping, output_shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (mode.m_type == BroadcastType::NUMPY) {
|
||||||
|
if (output_shape_defined) {
|
||||||
|
result_shape = output_shape;
|
||||||
|
validate_target_shape_numpy(op, input_shape, output_shape);
|
||||||
|
} else if (is_target_shape_known) {
|
||||||
|
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||||
|
} else {
|
||||||
|
result_shape = PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
} else if (mode.m_type == BroadcastType::PDPD) {
|
||||||
|
if (output_shape_defined) {
|
||||||
|
set_result_shape_pdpd(op, input_shape, output_shape, result_shape, mode);
|
||||||
|
} else if (is_target_shape_known) {
|
||||||
|
result_shape = PartialShape::dynamic(target_shape[0].get_length());
|
||||||
|
} else {
|
||||||
|
result_shape = PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
} else if (mode.m_type == BroadcastType::BIDIRECTIONAL) {
|
||||||
|
if (output_shape_defined) {
|
||||||
|
set_result_shape_bidirectional(op, input_shape, output_shape, result_shape);
|
||||||
|
} else if (input_shape.rank().is_static() && is_target_shape_known) {
|
||||||
|
auto output_rank = std::max(input_shape.size(), static_cast<size_t>(target_shape[0].get_length()));
|
||||||
|
result_shape = PartialShape::dynamic(output_rank);
|
||||||
|
} else {
|
||||||
|
result_shape = PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace util
|
||||||
|
|
||||||
|
namespace v3 {
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const ov::op::v3::Broadcast* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes,
|
||||||
|
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||||
|
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1);
|
||||||
|
auto& mode = op->get_broadcast_spec();
|
||||||
|
if (mode.m_type == BroadcastType::NONE) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
input_shapes.size() == 3,
|
||||||
|
"axes_mapping input should be provided if explicit mode is used");
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
input_shapes.size() == 2,
|
||||||
|
"axes_mapping input should not be provided for mode other than explicit");
|
||||||
|
}
|
||||||
|
broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data);
|
||||||
|
}
|
||||||
|
} // namespace v3
|
||||||
|
|
||||||
|
namespace v1 {
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const ov::op::v1::Broadcast* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes,
|
||||||
|
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||||
|
NODE_VALIDATION_CHECK(op, output_shapes.size() == 1 && (input_shapes.size() == 2 || input_shapes.size() == 3));
|
||||||
|
|
||||||
|
broadcase_base_shape_infer(op, input_shapes, output_shapes, constant_data);
|
||||||
|
}
|
||||||
|
} // namespace v1
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "ngraph/op/broadcast.hpp"
|
#include "ngraph/op/broadcast.hpp"
|
||||||
|
|
||||||
|
#include <broadcast_shape_inference.hpp>
|
||||||
#include <ngraph/validation_util.hpp>
|
#include <ngraph/validation_util.hpp>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
@ -141,25 +142,39 @@ void op::v3::Broadcast::validate_and_infer_types() {
|
|||||||
"axes_mapping input should not be provided for mode other than explicit");
|
"axes_mapping input should not be provided for mode other than explicit");
|
||||||
}
|
}
|
||||||
|
|
||||||
util::BroadcastBase::validate_and_infer_types();
|
const auto& shape_et = get_input_element_type(1);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
auto result_shape = get_output_partial_shape(0);
|
shape_et.is_integral_number(),
|
||||||
if (m_mode.m_type == BroadcastType::BIDIRECTIONAL) {
|
"Broadcast shape must be an integral number, but is: ",
|
||||||
if (get_input_partial_shape(0).rank().is_static() && get_input_partial_shape(1).is_static()) {
|
shape_et);
|
||||||
auto arg_shape = get_input_partial_shape(0);
|
if (m_mode.m_type == BroadcastType::NONE) {
|
||||||
|
// axes_mapping node should have integer data type. For now we only allow i64
|
||||||
PartialShape target_shape;
|
const auto& axes_et = get_input_element_type(2);
|
||||||
if (evaluate_as_partial_shape(input_value(1), target_shape)) {
|
NODE_VALIDATION_CHECK(this,
|
||||||
result_shape = get_result_shape_bidirectional(this, arg_shape, target_shape);
|
axes_et.is_integral_number(),
|
||||||
}
|
"Broadcast axes must be integral numbers, but are: ",
|
||||||
}
|
axes_et);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
|
||||||
|
std::vector<ov::PartialShape> input_shapes;
|
||||||
|
const auto& arg_shape = get_input_partial_shape(0);
|
||||||
|
const auto& target_shape = get_input_partial_shape(1);
|
||||||
|
if (input_values().size() == 2) {
|
||||||
|
input_shapes = {arg_shape, target_shape};
|
||||||
|
} else {
|
||||||
|
const auto& axes_mapping = get_input_partial_shape(2);
|
||||||
|
input_shapes = {arg_shape, target_shape, axes_mapping};
|
||||||
|
}
|
||||||
|
|
||||||
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
|
|
||||||
set_input_is_relevant_to_shape(0); // arg - Result element type
|
set_input_is_relevant_to_shape(0); // arg - Result element type
|
||||||
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
||||||
if (get_input_size() == 3) {
|
if (get_input_size() == 3) {
|
||||||
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
||||||
}
|
}
|
||||||
set_output_type(0, get_input_element_type(0), result_shape);
|
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
shared_ptr<Node> op::v3::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
@ -253,10 +268,32 @@ void op::v1::Broadcast::validate_and_infer_types() {
|
|||||||
util::BroadcastBase::m_mode = base_spec;
|
util::BroadcastBase::m_mode = base_spec;
|
||||||
}
|
}
|
||||||
|
|
||||||
util::BroadcastBase::validate_and_infer_types();
|
const auto& shape_et = get_input_element_type(1);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
shape_et.is_integral_number(),
|
||||||
|
"Broadcast shape must be an integral number, but is: ",
|
||||||
|
shape_et);
|
||||||
|
if (m_mode.m_type == BroadcastType::NONE) {
|
||||||
|
// axes_mapping node should have integer data type. For now we only allow i64
|
||||||
|
const auto& axes_et = get_input_element_type(2);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
axes_et.is_integral_number(),
|
||||||
|
"Broadcast axes must be integral numbers, but are: ",
|
||||||
|
axes_et);
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& arg_shape = get_input_partial_shape(0);
|
||||||
|
const auto& target_shape = get_input_partial_shape(1);
|
||||||
|
const auto& axes_mapping = get_input_partial_shape(2);
|
||||||
|
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
|
||||||
|
std::vector<ov::PartialShape> input_shapes = {arg_shape, target_shape, axes_mapping};
|
||||||
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
|
|
||||||
set_input_is_relevant_to_shape(0); // arg - Result element type
|
set_input_is_relevant_to_shape(0); // arg - Result element type
|
||||||
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
set_input_is_relevant_to_shape(1); // target_shape - Result shape
|
||||||
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
set_input_is_relevant_to_shape(2); // axes_mapping - Broadcast type
|
||||||
|
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
shared_ptr<Node> op::v1::Broadcast::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
|
@ -302,10 +302,15 @@ TYPED_TEST_P(BroadcastTests, broadcast_explicit_const_target_shape_static_rank_i
|
|||||||
|
|
||||||
// const axes mapping
|
// const axes mapping
|
||||||
const auto axes_mapping_const = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
|
const auto axes_mapping_const = op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 2, 1, 3});
|
||||||
bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
try {
|
||||||
ASSERT_TRUE(bc->get_output_partial_shape(0).is_static());
|
auto bc = make_shared<TypeParam>(data, target_shape, axes_mapping_const, "EXPLICIT");
|
||||||
ASSERT_EQ(bc->get_output_partial_shape(0).rank().get_length(), 4);
|
FAIL() << "Broadcast: Broadcast axes_mapping shape doesn't match rank of input tensor";
|
||||||
ASSERT_EQ(bc->get_shape(), (Shape{1, 1, 5, 10}));
|
} catch (const NodeValidationFailure& error) {
|
||||||
|
EXPECT_HAS_SUBSTRING(error.what(),
|
||||||
|
std::string("Broadcast axes_mapping shape {4} doesn't match rank of input tensor 3"));
|
||||||
|
} catch (...) {
|
||||||
|
FAIL() << "Deduced type check failed for unexpected reason";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) {
|
TYPED_TEST_P(BroadcastTests, broadcast_explicit_static_input_shape) {
|
||||||
|
@ -59,6 +59,7 @@
|
|||||||
#include "detection_output_shape_inference.hpp"
|
#include "detection_output_shape_inference.hpp"
|
||||||
#include "select_shape_inference.hpp"
|
#include "select_shape_inference.hpp"
|
||||||
#include "shuffle_channels_shape_inference.hpp"
|
#include "shuffle_channels_shape_inference.hpp"
|
||||||
|
#include "broadcast_shape_inference.hpp"
|
||||||
#include "static_shape.hpp"
|
#include "static_shape.hpp"
|
||||||
#include "tile_shape_inference.hpp"
|
#include "tile_shape_inference.hpp"
|
||||||
#include "utils.hpp"
|
#include "utils.hpp"
|
||||||
@ -229,6 +230,10 @@ void shape_inference(ov::Node* op,
|
|||||||
shape_infer(node, input_shapes, output_shapes);
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset1::ShuffleChannels>(op)) {
|
} else if (auto node = ov::as_type<ov::opset1::ShuffleChannels>(op)) {
|
||||||
shape_infer(node, input_shapes, output_shapes);
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset4::Broadcast>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset1::Broadcast>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||||
} else {
|
} else {
|
||||||
ngraph::OutputVector new_inputs;
|
ngraph::OutputVector new_inputs;
|
||||||
for (size_t i = 0; i < op->get_input_size(); ++i) {
|
for (size_t i = 0; i < op->get_input_size(); ++i) {
|
||||||
|
@ -0,0 +1,217 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <broadcast_shape_inference.hpp>
|
||||||
|
#include <openvino/op/broadcast.hpp>
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastBidirectionalTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastBidirectionalConstantTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{16, 1, 40});
|
||||||
|
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::BIDIRECTIONAL);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 16, 50, 1}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 40}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastPDPDTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v3 =
|
||||||
|
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {2, 3, 6};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastPDPDConstantTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{3}, std::vector<int32_t>{2, 3, 6});
|
||||||
|
auto broadcast_v3 =
|
||||||
|
std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastModeSpec(op::BroadcastType::PDPD, 1));
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastNumpyTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastNumpyConstantTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto target_shape =
|
||||||
|
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
|
||||||
|
auto broadcast_v3 = std::make_shared<op::v3::Broadcast>(input, target_shape, op::BroadcastType::NUMPY);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{16, 1, 1}, StaticShape{4}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastExplicitTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v3 =
|
||||||
|
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {1, 16, 50, 50};
|
||||||
|
int32_t axes_mapping_val[] = {1};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{4}, target_shape_val);
|
||||||
|
constant_data[2] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{1}, axes_mapping_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
|
||||||
|
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||||
|
|
||||||
|
constant_data.erase(1);
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, constant_data),
|
||||||
|
NodeValidationFailure);
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastExplicitConstantTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
auto target_shape =
|
||||||
|
std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{4}, std::vector<int32_t>{1, 16, 50, 50});
|
||||||
|
auto axes_mapping = std::make_shared<ov::op::v0::Constant>(element::i32, ov::Shape{1}, std::vector<int32_t>{1});
|
||||||
|
auto broadcast_v3 =
|
||||||
|
std::make_shared<op::v3::Broadcast>(input, target_shape, axes_mapping, op::BroadcastType::EXPLICIT);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{16}, StaticShape{4}, StaticShape{1}};
|
||||||
|
std::vector<StaticShape> static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v3.get(), static_input_shapes, static_output_shapes, {});
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 16, 50, 50}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastV1 test
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastV1PDPDTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v1 =
|
||||||
|
std::make_shared<op::v1::Broadcast>(input, target_shape, op::AutoBroadcastSpec(op::AutoBroadcastType::PDPD, 1));
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {2, 3, 6};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastV1NumpyTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape);
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {2, 3, 6};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 6}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, BroadcastV1ExplicitTest) {
|
||||||
|
auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto target_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto axes_mapping = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
auto broadcast_v1 = std::make_shared<op::v1::Broadcast>(input, target_shape, axes_mapping);
|
||||||
|
|
||||||
|
int32_t target_shape_val[] = {2, 3, 1};
|
||||||
|
int32_t axes_mapping_val[] = {1, 2};
|
||||||
|
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
|
||||||
|
constant_data[1] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{3}, target_shape_val);
|
||||||
|
constant_data[2] =
|
||||||
|
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, ov::Shape{2}, axes_mapping_val);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, constant_data);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({2, 3, 1}));
|
||||||
|
|
||||||
|
static_input_shapes = {StaticShape{3, 1}, StaticShape{3}, StaticShape{2}};
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
EXPECT_THROW(shape_inference(broadcast_v1.get(), static_input_shapes, static_output_shapes, {}), NodeValidationFailure);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user