[GPU] Update Broadcast operation to use dynamic input (#14264)

* Refactor to merge extend_to_6d and get_agnostic_updated_params functions
Kelvin Choi 2023-01-26 04:38:33 +09:00 committed by GitHub
parent 70a0c713d3
commit 8575ad690c
14 changed files with 712 additions and 58 deletions
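Note: the patch replaces the ad-hoc extend_to_6d lambda in primitive_inst::update_impl() with per-primitive extend_input_shape_to_6d / extend_output_shape_to_6d hooks dispatched through primitive_type. The default rule pads a shape with trailing 1s to rank 4 and then lifts it to the 6d bfwzyx layout. A minimal standalone sketch (hypothetical helper; the real code goes through cldnn::layout and format):

#include <cstddef>
#include <vector>

// Sketch of the default 6d-extension rule, assuming rank <= 6 on entry and
// assuming the bfwzyx lift keeps batch/feature and right-aligns spatial dims.
std::vector<size_t> extend_to_6d_sketch(std::vector<size_t> shape) {
    if (shape.size() < 4)
        shape.insert(shape.end(), 4 - shape.size(), 1);    // pad to rank 4: {2, 3} -> {2, 3, 1, 1}
    shape.insert(shape.begin() + 2, 6 - shape.size(), 1);  // lift to bfwzyx: {2, 3, 1, 1} -> {2, 3, 1, 1, 1, 1}
    return shape;
}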


@@ -82,10 +82,13 @@ std::vector<layout> broadcast_inst::calc_output_layouts(broadcast_node const& /*
static_cast<void*>(target_shape.data()));
const_data.emplace(1, target_shape_tensor);
ov::op::v3::shape_infer(&op, input_shapes, output_shapes, const_data);
-} else {
-// Pattern shape is set as second input. Even though the input is scalar, the shape should be propagated as dynamic
-auto output_rank = input_shapes[0].size();
-output_shapes[0] = ShapeType::dynamic(std::max(output_rank, static_cast<size_t>(1)));
+} else if (impl_param.input_layouts.size() >= 2) {
+auto input1 = impl_param.get_input_layout(1);
+int output_rank = input1.get<ShapeType>().size();
+if (input1.is_static()) {
+output_rank = input1.get_dim(0); // target shape rank is set as second input.
+}
+output_shapes[0] = ShapeType::dynamic(std::max(output_rank, static_cast<int>(1)));
}
format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());
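Note: with a non-constant target shape, only the output rank can be inferred here: a static second input of shape {N} encodes an N-dimensional output; otherwise the rank of the pattern input itself is used. A compilable restatement (hypothetical helper, values illustrative):

#include <algorithm>
#include <cstdint>

// If the target-shape tensor's own shape {N} is known, the output is N-d
// (all dims dynamic); otherwise fall back to the pattern input's rank.
int infer_broadcast_output_rank(bool target_shape_is_static, int64_t target_dim0, int fallback_rank) {
    int output_rank = target_shape_is_static ? static_cast<int>(target_dim0) : fallback_rank;
    return std::max(output_rank, 1);  // clamp: output rank is at least 1
}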
@@ -95,6 +98,88 @@ std::vector<layout> broadcast_inst::calc_output_layouts(broadcast_node const& /*
template std::vector<layout> broadcast_inst::calc_output_layouts<ov::PartialShape>(broadcast_node const& node, const kernel_impl_params& impl_param);
std::vector<size_t> broadcast_inst::extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) {
ov::PartialShape ps;
auto orig_input_layout = orig_impl_param.get_input_layout();
auto updated_param = orig_impl_param;
const auto& primitive = updated_param.typed_desc<broadcast>();
// Extend input dimensions with ones
auto i_layout = updated_param.input_layouts[0];
auto o_layout = updated_param.output_layouts[0];
auto input_shape = i_layout.get_shape();
auto output_shape = o_layout.get_shape();
if (primitive->axes_mapping.empty()) {
auto broadcastable = [&](layout a, layout b) {
auto dims_a = a.get_dims();
auto dims_b = b.get_dims();
size_t min_size = (dims_a.size() < dims_b.size()) ? dims_a.size(): dims_b.size();
for (size_t i = 0; i < min_size; i++) {
if (!(dims_a[i] == 1 || dims_b[i] == 1 || dims_a[i] == dims_b[i])) {
return false;
}
}
return true;
};
auto input_rank = input_shape.size();
auto output_rank = output_shape.size();
if (!broadcastable(i_layout, o_layout)) {
input_shape.insert(input_shape.begin(), output_rank - input_rank, 1ul);
}
} else {
// If axis_mapping is specified, then ones are inserted according to it.
ov::Shape tmp_shape;
int prev_axis = -1;
int next_axis = -1;
size_t currentRank = 0;
int axe_idx = 0;
for (auto& axis : primitive->axes_mapping) {
prev_axis = next_axis;
next_axis = static_cast<int>(axis);
int ones_count = std::max(next_axis - prev_axis - 1, 0);
tmp_shape.insert(tmp_shape.begin() + currentRank, ones_count, 1ul);
tmp_shape.push_back(input_shape[axe_idx]); // the Broadcast kernel broadcasts the input dims to the output shape
currentRank += ones_count + 1;
axe_idx += 1;
}
// insert 1 to match with output shape
if (o_layout.get_rank() > tmp_shape.size()) {
tmp_shape.insert(tmp_shape.end(), o_layout.get_rank() - tmp_shape.size(), 1ul);
}
input_shape = tmp_shape;
}
ps = ov::PartialShape(input_shape);
if (ps.size() < 4) {
ps.insert(ps.end(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
std::vector<size_t> broadcast_inst::extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) {
ov::PartialShape ps = orig_impl_param.get_output_layout(output_idx).get_partial_shape();
if (ps.size() < 4) {
ps.insert(ps.end(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
std::string broadcast_inst::to_string(broadcast_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();
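Note: the axes_mapping branch above places each input dimension at its mapped output axis and fills the gaps with 1s. A standalone restatement with a worked example (hypothetical helper; input {3, 5}, axes_mapping {0, 2}, output rank 4 gives {3, 1, 5, 1}):

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<size_t> expand_by_axes_mapping(const std::vector<size_t>& input_shape,
                                           const std::vector<int>& axes_mapping,
                                           size_t output_rank) {
    std::vector<size_t> tmp_shape;
    int prev_axis = -1;
    size_t axe_idx = 0;
    for (int next_axis : axes_mapping) {
        size_t ones_count = static_cast<size_t>(std::max(next_axis - prev_axis - 1, 0));
        tmp_shape.insert(tmp_shape.end(), ones_count, 1);  // 1s for unmapped axes in between
        tmp_shape.push_back(input_shape[axe_idx++]);       // mapped input dim at its axis
        prev_axis = next_axis;
    }
    if (output_rank > tmp_shape.size())                    // trailing 1s up to output rank
        tmp_shape.insert(tmp_shape.end(), output_rank - tmp_shape.size(), 1);
    return tmp_shape;
}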


@@ -111,6 +111,28 @@ std::vector<layout> gemm_inst::calc_output_layouts(gemm_node const& /*node*/, co
template std::vector<layout> gemm_inst::calc_output_layouts<ov::PartialShape>(gemm_node const& node, const kernel_impl_params& impl_param);
std::vector<size_t> gemm_inst::extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) {
ov::PartialShape ps = orig_impl_param.get_input_layout(input_idx).get_partial_shape();
if (ps.size() < 4) {
ps.insert(ps.begin(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
std::vector<size_t> gemm_inst::extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) {
ov::PartialShape ps = orig_impl_param.get_output_layout(output_idx).get_partial_shape();
if (ps.size() < 4) {
ps.insert(ps.begin(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
std::vector<layout> gemm_inst::transform_input_layouts(const std::shared_ptr<const gemm> primitive,
const std::vector<layout>& input_layouts,
const layout& output_layout) {
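Note the asymmetry with broadcast/shape_of: gemm pads missing dimensions at the front so the trailing M x N axes keep their positions, while the default rule pads at the back. Sketch (hypothetical helper):

#include <cstddef>
#include <vector>

std::vector<size_t> pad_to_rank4(std::vector<size_t> shape, bool pad_front) {
    if (shape.size() < 4) {
        if (pad_front)
            shape.insert(shape.begin(), 4 - shape.size(), 1);  // gemm: {M, N} -> {1, 1, M, N}
        else
            shape.insert(shape.end(), 4 - shape.size(), 1);    // broadcast: {a, b} -> {a, b, 1, 1}
    }
    return shape;
}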


@@ -49,6 +49,104 @@ struct broadcast_impl : typed_primitive_impl_ocl<broadcast> {
}
}
// Extend input dimensions with ones
auto i_layout = impl_param.input_layouts[0];
auto o_layout = impl_param.output_layouts[0];
if (i_layout.is_static() && o_layout.is_static()) {
auto data_shape = i_layout.get_shape();
auto output_shape = o_layout.get_shape();
if (primitive->axes_mapping.empty()) {
auto broadcastable = [&](layout a, layout b) {
auto dims_a = a.get_dims();
auto dims_b = b.get_dims();
size_t min_size = (dims_a.size() < dims_b.size()) ? dims_a.size(): dims_b.size();
for (size_t i = 0; i < min_size; i++) {
if (!(dims_a[i] == 1 || dims_b[i] == 1 || dims_a[i] == dims_b[i])) {
return false;
}
}
return true;
};
auto input_rank = data_shape.size();
auto output_rank = output_shape.size();
if (!broadcastable(i_layout, o_layout)) {
data_shape.insert(data_shape.begin(), output_rank - input_rank, 1ul);
}
} else {
// If axis_mapping is specified, then ones are inserted according to it.
ov::Shape tmp_shape;
int prev_axis = -1;
int next_axis = -1;
size_t currentRank = 0;
int axe_idx = 0;
for (auto& axis : primitive->axes_mapping) {
prev_axis = next_axis;
next_axis = static_cast<int>(axis);
int ones_count = std::max(next_axis - prev_axis - 1, 0);
tmp_shape.insert(tmp_shape.begin() + currentRank, ones_count, 1ul);
tmp_shape.push_back(data_shape[axe_idx]); // the Broadcast kernel broadcasts the input dims to the output shape
currentRank += ones_count + 1;
axe_idx += 1;
}
if (o_layout.get_rank() > tmp_shape.size()) {
tmp_shape.insert(tmp_shape.end(), o_layout.get_rank() - tmp_shape.size(), 1ul);
}
data_shape = tmp_shape;
}
layout new_layout = i_layout;
new_layout.format = format::adjust_to_rank(i_layout.format, data_shape.size());
new_layout.set_partial_shape(data_shape);
params.inputs[0] = convert_data_tensor(new_layout);
} else {
// dynamic input
if (primitive->axes_mapping.empty()) {
ov::PartialShape i_shape = i_layout.get_partial_shape();
ov::PartialShape o_shape = o_layout.get_partial_shape();
auto i_rank = i_shape.size();
auto o_rank = o_shape.size();
i_shape.insert(i_shape.begin(), o_rank - i_rank, 1ul);
layout new_layout = i_layout;
new_layout.format = format::adjust_to_rank(i_layout.format, i_shape.size());
new_layout.set_partial_shape(i_shape);
params.inputs[0] = convert_data_tensor(new_layout);
} else {
// insert 1 to extend dimensions by axes_mapping
ov::Shape tmp_shape;
size_t idx = 0;
for (auto& axis : primitive->axes_mapping) {
if (idx == axis) {
tmp_shape.insert(tmp_shape.begin() + idx, 1, -1);
idx += 1;
} else {
tmp_shape.insert(tmp_shape.begin() + idx, axis - idx, 1);
idx = axis;
tmp_shape.insert(tmp_shape.begin() + idx, 1, -1);
idx += 1;
}
}
// insert 1 to match with output shape
if (o_layout.get_rank() > tmp_shape.size()) {
tmp_shape.insert(tmp_shape.end(), o_layout.get_rank() - tmp_shape.size(), 1ul);
}
layout new_layout = i_layout;
new_layout.format = format::adjust_to_rank(i_layout.format, tmp_shape.size());
new_layout.set_partial_shape(tmp_shape);
params.inputs[0] = convert_data_tensor(new_layout);
}
}
return {params, optional_params};
}
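Note: in the dynamic branch, mapped axes are marked with -1 (to be resolved from the runtime input) and unmapped axes with 1. A standalone restatement (hypothetical helper; axes_mapping {1, 3}, output rank 4 gives {1, -1, 1, -1}):

#include <cstdint>
#include <vector>

std::vector<int64_t> placeholder_shape(const std::vector<size_t>& axes_mapping,
                                       size_t output_rank) {
    std::vector<int64_t> tmp_shape;
    size_t idx = 0;
    for (size_t axis : axes_mapping) {
        tmp_shape.insert(tmp_shape.end(), axis - idx, 1);  // 1s for skipped axes
        tmp_shape.push_back(-1);                           // mapped axis: dynamic dim
        idx = axis + 1;
    }
    if (output_rank > tmp_shape.size())
        tmp_shape.insert(tmp_shape.end(), output_rank - tmp_shape.size(), 1);
    return tmp_shape;
}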


@@ -37,6 +37,8 @@ public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(broadcast_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(broadcast_node const& node, kernel_impl_params const& impl_param);
static std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx);
static std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx);
static std::string to_string(broadcast_node const& node);
typed_primitive_inst(network& network, broadcast_node const& node);
};


@@ -32,6 +32,8 @@ public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(gemm_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(gemm_node const& node, kernel_impl_params const& impl_param);
static std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx);
static std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx);
static std::string to_string(gemm_node const& node);
static std::vector<layout> transform_input_layouts(const std::shared_ptr<const gemm> primitive,


@@ -406,6 +406,26 @@ public:
return std::move(orig_impl_param);
}
static std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) {
ov::PartialShape ps = orig_impl_param.get_input_layout(input_idx).get_partial_shape();
if (ps.size() < 4) {
ps.insert(ps.end(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
static std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) {
ov::PartialShape ps = orig_impl_param.get_output_layout(output_idx).get_partial_shape();
if (ps.size() < 4) {
ps.insert(ps.end(), 4 - ps.size(), ov::Dimension(1));
}
layout l(ps, data_types::i32, format::get_default_format(ps.size()));
return l.transform(format::bfwzyx).to_shape();
}
typed_primitive_inst_base(network& network, typed_node const& node)
: typed_primitive_inst_base(network, node, do_allocate_memory(node)) {}


@@ -44,6 +44,8 @@ struct primitive_type {
virtual layout calc_output_layout(const program_node& node, const kernel_impl_params& params) const = 0;
virtual std::vector<layout> calc_output_layouts(const program_node& node, const kernel_impl_params& impl_param) const = 0;
virtual kernel_impl_params get_fake_aligned_params(kernel_impl_params const& orig_impl_param) const = 0;
virtual std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) const = 0;
virtual std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) const = 0;
virtual std::string to_string(const program_node& node) const = 0;
};
} // namespace cldnn


@@ -89,6 +89,12 @@ struct primitive_type_base : primitive_type {
kernel_impl_params get_fake_aligned_params(kernel_impl_params const& orig_impl_param) const override {
return typed_primitive_inst<PType>::get_fake_aligned_params(orig_impl_param);
}
std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) const override {
return typed_primitive_inst<PType>::extend_input_shape_to_6d(orig_impl_param, input_idx);
}
std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) const override {
return typed_primitive_inst<PType>::extend_output_shape_to_6d(orig_impl_param, output_idx);
}
std::string to_string(const cldnn::program_node& node) const override {
OPENVINO_ASSERT(node.type() == this, "[GPU] primitive_type_base::to_string: primitive type mismatch");
return typed_primitive_inst<PType>::to_string(node);
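Note: this mirrors the pattern visible above: the virtual primitive_type interface forwards to a static method on the typed instance class. A toy model (names hypothetical):

#include <cstdint>
#include <vector>

struct iface {
    virtual std::vector<size_t> extend_input_shape_to_6d(int32_t input_idx) const = 0;
    virtual ~iface() = default;
};

template <typename PType>
struct iface_base : iface {
    std::vector<size_t> extend_input_shape_to_6d(int32_t input_idx) const override {
        return PType::extend_input_shape_to_6d(input_idx);  // forward to the typed static method
    }
};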


@@ -35,6 +35,8 @@ public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(shape_of_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(shape_of_node const& node, kernel_impl_params const& impl_param);
static std::vector<size_t> extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx);
static std::vector<size_t> extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx);
static std::string to_string(shape_of_node const& node);
typed_primitive_inst(network& network, shape_of_node const& node);


@@ -267,21 +267,6 @@ void primitive_inst::realloc_if_needed() {
void primitive_inst::update_impl() {
GPU_DEBUG_PROFILED_STAGE(instrumentation::pipeline_stage::update_implementation);
auto prev_impl_str = _impl != nullptr ? _impl->get_kernel_name() : "nullptr";
-auto extend_to_6d = [this](ov::PartialShape ps) -> std::vector<size_t> {
-// For shape_of we extend shape with 1-s to 6d rank to make kernel simpler
-if (_node->is_type<shape_of>()) {
-ps.insert(ps.end(), 6 - ps.size(), ov::Dimension(1));
-return ps.to_shape();
-}
-if (ps.size() < 4) {
-if (_node->is_type<gemm>())
-ps.insert(ps.begin(), 4 - ps.size(), ov::Dimension(1));
-else
-ps.insert(ps.end(), 4 - ps.size(), ov::Dimension(1));
-}
-layout l(ps, data_types::i32, format::get_default_format(ps.size()));
-return l.transform(format::bfwzyx).to_shape();
-};
auto get_layout_key = [&](const kernel_impl_params& params) -> size_t {
size_t seed = 0;
@@ -301,21 +286,23 @@
return seed;
};
-auto update_shape_info = [this, extend_to_6d, prev_impl_str](const kernel_impl_params& params) {
+auto update_shape_info = [this, prev_impl_str](const kernel_impl_params& params) {
mem_lock<int32_t> lock(_shape_info_memory, _network.get_stream());
size_t offset = 0;
for (size_t i = 0; i < _node->get_dependencies().size(); i++) {
if (_node->get_dependency(i).get_output_layout().is_dynamic()) {
-auto input_shape = extend_to_6d(params.get_input_layout(i).get_partial_shape());
+auto input_shape = _node->type()->extend_input_shape_to_6d(params, i);
for (size_t j = 0; j < input_shape.size(); j++)
lock[offset++] = static_cast<int32_t>(input_shape[j]);
}
}
-if (_node->get_output_layout().is_dynamic()) {
-auto output_shape = extend_to_6d(params.get_output_layout().get_partial_shape());
-for (size_t j = 0; j < output_shape.size(); j++)
-lock[offset++] = static_cast<int32_t>(output_shape[j]);
+for (size_t i = 0; i < _node->get_output_layouts().size(); i++) {
+if (_node->get_output_layout(i).is_dynamic()) {
+auto output_shape = _node->type()->extend_output_shape_to_6d(params, i);
+for (size_t j = 0; j < output_shape.size(); j++)
+lock[offset++] = static_cast<int32_t>(output_shape[j]);
+}
}
std::stringstream s;
s << "shapes: ";
@@ -365,6 +352,7 @@ void primitive_inst::update_impl() {
_impl = _dynamic_impl->clone();
_impl->update_dispatch_data(updated_params);
update_shape_info(updated_params);
} else {
_impl = _node->type()->choose_impl(*_node, updated_params);
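Note: update_shape_info packs the 6d-extended shapes of all dynamic inputs, then all dynamic outputs, as consecutive int32 values; the per-primitive hooks only decide how each shape is stretched to six dims. Sketch of the packing (hypothetical helper):

#include <cstdint>
#include <vector>

// Inputs' extended shapes first, then outputs', written back to back.
size_t pack_shape_info(int32_t* shape_info, const std::vector<std::vector<size_t>>& extended_shapes) {
    size_t offset = 0;
    for (const auto& shape : extended_shapes)
        for (size_t dim : shape)
            shape_info[offset++] = static_cast<int32_t>(dim);
    return offset;  // number of int32 slots written
}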


@@ -47,6 +47,18 @@ std::vector<layout> shape_of_inst::calc_output_layouts(shape_of_node const& /*no
template std::vector<layout> shape_of_inst::calc_output_layouts<ov::PartialShape>(shape_of_node const& node, const kernel_impl_params& impl_param);
std::vector<size_t> shape_of_inst::extend_input_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t input_idx) {
ov::PartialShape ps = orig_impl_param.get_input_layout(input_idx).get_partial_shape();
ps.insert(ps.end(), 6 - ps.size(), ov::Dimension(1));
return ps.to_shape();
}
std::vector<size_t> shape_of_inst::extend_output_shape_to_6d(kernel_impl_params const& orig_impl_param, int32_t output_idx) {
ov::PartialShape ps = orig_impl_param.get_output_layout(output_idx).get_partial_shape();
ps.insert(ps.end(), 6 - ps.size(), ov::Dimension(1));
return ps.to_shape();
}
std::string shape_of_inst::to_string(shape_of_node const& node) {
auto node_info = node.desc_to_json();
auto desc = node.get_primitive();
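Note: shape_of pads straight to rank 6 with trailing 1s and skips the bfwzyx transform used by the generic rule (per the removed comment, this keeps the kernel simple). The two rules agree up to rank 4 but differ at rank 5, e.g. {2, 3, 4, 5, 6} -> {2, 3, 4, 5, 6, 1} here versus {2, 3, 1, 4, 5, 6} via the default path. Sketch (hypothetical helper, assumes rank <= 6):

#include <cstddef>
#include <vector>

std::vector<size_t> shape_of_extend_to_6d(std::vector<size_t> shape) {
    shape.insert(shape.end(), 6 - shape.size(), 1);  // {2, 3} -> {2, 3, 1, 1, 1, 1}
    return shape;
}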


@@ -33,26 +33,45 @@ BroadcastKernelBase::DispatchData BroadcastKernelBase::SetDefault(const broadcas
static std::string GetInputBlockND(const broadcast_params& params) {
const auto& input = params.inputs[0];
-auto input_dims = input.LogicalDims();
-std::reverse(input_dims.begin(), input_dims.end());
-const int rank = static_cast<int>(input_dims.size());
-std::vector<size_t> block_nd(rank + 1);
-std::vector<std::string> block_nd_s(rank + 1);
-block_nd[rank] = 1;
-block_nd_s[rank] = "1";
-for (int idx = (rank - 1); idx >= 0; idx--) {
-block_nd[idx] = input_dims[idx] * block_nd[idx + 1];
-block_nd_s[idx] = "(" + toCodeString(input.GetDims()[idx], rank - idx) + " * " + block_nd_s[idx + 1] + ")";
-}
std::stringstream s;
-for (int i = 0; i < (rank + 1); i++) {
-if (i < rank) {
-s << (input.is_dynamic() ? block_nd_s[i] : std::to_string(block_nd[i])) << ",";
-} else {
-s << (input.is_dynamic() ? block_nd_s[i] : std::to_string(block_nd[i]));
+auto input_dims = input.LogicalDims();
+std::reverse(input_dims.begin(), input_dims.end());
+if (input.is_dynamic()) {
+const int rank = static_cast<int>(input_dims.size());
+std::vector<std::string> block_nd_s(rank + 1);
+block_nd_s[rank] = "1";
+for (int idx = (rank - 1); idx >= 0; idx--) {
+int shape_info_idx = idx;
+if (idx >= 2) {
+shape_info_idx += (6 - rank);
+}
+block_nd_s[idx] = "(" + toCodeString(input.GetDims()[rank - idx - 1], shape_info_idx) + " * " + block_nd_s[idx + 1] + ")";
+}
+for (int i = 0; i < (rank + 1); i++) {
+s << block_nd_s[i];
+if (i < rank) {
+s << ",";
+}
+}
+} else {
+const int rank = static_cast<int>(input_dims.size());
+std::vector<size_t> block_nd(rank + 1);
+block_nd[rank] = 1;
+for (int idx = (rank - 1); idx >= 0; idx--) {
+block_nd[idx] = input_dims[idx] * block_nd[idx + 1];
+}
+for (int i = 0; i < (rank + 1); i++) {
+s << std::to_string(block_nd[i]);
+if (i < rank) {
+s << ",";
+}
+}
+}
auto str_result = s.str();
return str_result;
}
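Note: GetInputBlockND emits the row-major stride list consumed by the kernel; for a static input it is plain numbers, for a dynamic input it is a jit expression over shape_info entries (with indices remapped into the 6d layout via idx + (6 - rank) for idx >= 2). A restatement of the static branch with an example (hypothetical helper; dims {2, 3, 4} gives "24,12,4,1"):

#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

std::string input_block_nd(const std::vector<size_t>& dims) {
    const int rank = static_cast<int>(dims.size());
    std::vector<size_t> block_nd(rank + 1);
    block_nd[rank] = 1;
    for (int idx = rank - 1; idx >= 0; idx--)
        block_nd[idx] = dims[idx] * block_nd[idx + 1];  // product of trailing dims
    std::ostringstream s;
    for (int i = 0; i <= rank; i++) {
        if (i > 0)
            s << ",";
        s << block_nd[i];
    }
    return s.str();
}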


@@ -50,23 +50,6 @@ static void CreateCommonBroadcastOp(Program& p, const std::shared_ptr<ngraph::No
if (axis_mapping.empty()) {
// If axis_mapping is not specified, then we prepend the shape with the necessary count of 1-s
inputShape.insert(inputShape.begin(), output_rank - input_rank, 1ul);
-} else {
-// If axis_mapping is specified, then ones are inserted according to it.
-ngraph::Shape tmp_shape;
-int prev_axis = -1;
-int next_axis = -1;
-size_t currentRank = 0;
-for (auto& axis : axis_mapping) {
-prev_axis = next_axis;
-next_axis = static_cast<int>(axis);
-int ones_count = std::max(next_axis - prev_axis - 1, 0);
-tmp_shape.insert(tmp_shape.begin() + currentRank, ones_count, 1ul);
-tmp_shape.push_back(outputShape[axis]);
-currentRank += ones_count + 1;
-}
-inputShape = tmp_shape;
}
auto targetShape = tensor_from_dims(inputShape);


@@ -0,0 +1,413 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_layer/broadcast.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ie_precision.hpp"
#include "ngraph_functions/builders.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>
#include <string>
using namespace ngraph;
using namespace InferenceEngine;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
typedef std::tuple<
std::vector<InputShape>, // Shapes
std::vector<int64_t>, // Target shapes
std::vector<int64_t>, // Axes mapping
ov::op::BroadcastType, // Broadcast mode
ov::element::Type_t, // Network precision
std::vector<bool>, // Const inputs
std::string // Device name
> BroadcastLayerTestParamsSet;
class BroadcastLayerGPUTest : public testing::WithParamInterface<BroadcastLayerTestParamsSet>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(testing::TestParamInfo<BroadcastLayerTestParamsSet> obj) {
std::vector<ov::test::InputShape> inputShapes;
std::vector<int64_t> targetShapes, axesMapping;
ov::op::BroadcastType mode;
ov::element::Type_t netPrecision;
std::vector<bool> isConstInputs;
std::string deviceName;
std::tie(inputShapes, targetShapes, axesMapping, mode, netPrecision, isConstInputs, deviceName) = obj.param;
std::ostringstream result;
result << "IS=(";
for (const auto& shape : inputShapes) {
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
}
result << ")_TS=(";
for (const auto& shape : inputShapes) {
for (const auto& item : shape.second) {
result << CommonTestUtils::vec2str(item) << "_";
}
}
result << "targetShape=" << CommonTestUtils::vec2str(targetShapes) << "_";
result << "axesMapping=" << CommonTestUtils::vec2str(axesMapping) << "_";
result << "mode=" << mode << "_";
result << "netPrec=" << netPrecision << "_";
result << "constIn=(" << (isConstInputs[0] ? "True" : "False") << "." << (isConstInputs[1] ? "True" : "False") << ")_";
result << "trgDevice=" << deviceName;
return result.str();
}
protected:
std::vector<int64_t> targetShape, axesMapping;
void SetUp() override {
std::vector<InputShape> inputShapes;
ov::op::BroadcastType mode;
ov::element::Type_t netPrecision;
std::vector<bool> isConstInput;
std::tie(inputShapes, targetShape, axesMapping, mode, netPrecision, isConstInput, targetDevice) = this->GetParam();
bool isTargetShapeConst = isConstInput[0];
bool isAxesMapConst = isConstInput[1];
const auto targetShapeRank = targetShape.size();
const auto axesMappingRank = axesMapping.size();
if (inputShapes.front().first.rank() != 0) {
inputDynamicShapes.push_back(inputShapes.front().first);
if (!isTargetShapeConst) {
inputDynamicShapes.push_back({ static_cast<int64_t>(targetShape.size()) });
}
if (!isAxesMapConst) {
inputDynamicShapes.push_back({ static_cast<int64_t>(axesMapping.size()) });
}
}
const size_t targetStaticShapeSize = inputShapes.front().second.size();
targetStaticShapes.resize(targetStaticShapeSize);
for (size_t i = 0lu; i < targetStaticShapeSize; ++i) {
targetStaticShapes[i].push_back(inputShapes.front().second[i]);
if (!isTargetShapeConst)
targetStaticShapes[i].push_back({ targetShape.size() });
if (!isAxesMapConst)
targetStaticShapes[i].push_back({ axesMapping.size() });
}
ov::ParameterVector functionParams;
if (inputDynamicShapes.empty()) {
functionParams.push_back(std::make_shared<ov::op::v0::Parameter>(netPrecision, targetStaticShapes.front().front()));
} else {
functionParams.push_back(std::make_shared<ov::op::v0::Parameter>(netPrecision, inputDynamicShapes.front()));
if (!isTargetShapeConst) {
functionParams.push_back(std::make_shared<ov::op::v0::Parameter>(ov::element::i64, inputDynamicShapes[1]));
functionParams.back()->set_friendly_name("targetShape");
}
if (!isAxesMapConst) {
functionParams.push_back(std::make_shared<ov::op::v0::Parameter>(ov::element::i64, inputDynamicShapes.back()));
functionParams.back()->set_friendly_name("axesMapping");
}
}
functionParams.front()->set_friendly_name("data");
auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<ov::op::v0::Parameter>(functionParams));
std::shared_ptr<ov::op::v3::Broadcast> broadcastOp;
if (mode == ov::op::BroadcastType::EXPLICIT) {
std::shared_ptr<ov::Node> targetShapeOp;
std::shared_ptr<ov::Node> axesMappingOp;
if (isTargetShapeConst) {
targetShapeOp = ov::op::v0::Constant::create(ov::element::i64, {targetShapeRank}, targetShape);
} else {
targetShapeOp = functionParams[1];
}
if (isAxesMapConst) {
axesMappingOp = ov::op::v0::Constant::create(ov::element::i64, {axesMappingRank}, axesMapping);
} else {
axesMappingOp = functionParams.size() > 2 ? functionParams[2] : functionParams[1];
}
broadcastOp = std::make_shared<ov::op::v3::Broadcast>(paramOuts[0],
targetShapeOp,
axesMappingOp,
mode);
} else if (mode == ov::op::BroadcastType::NUMPY) {
if (isTargetShapeConst) {
auto targetShapeConst = ov::op::v0::Constant::create(ov::element::i64, {targetShapeRank}, targetShape);
broadcastOp = std::make_shared<ov::op::v3::Broadcast>(paramOuts[0],
targetShapeConst,
mode);
} else {
broadcastOp = std::make_shared<ov::op::v3::Broadcast>(paramOuts[0],
paramOuts[1],
mode);
}
}
auto makeFunction = [](ParameterVector &params, const std::shared_ptr<Node> &lastNode) {
ResultVector results;
for (int i = 0; i < lastNode->get_output_size(); i++)
results.push_back(std::make_shared<opset1::Result>(lastNode->output(i)));
return std::make_shared<Function>(results, params, "BroadcastLayerGPUTest");
};
function = makeFunction(functionParams, broadcastOp);
}
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (size_t i = 0lu; i < funcInputs.size(); i++) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
if (funcInput.get_node()->get_friendly_name() == "targetShape") {
tensor = ov::Tensor{ov::element::i64, targetInputStaticShapes[i]};
auto data = tensor.data<ov::element_type_traits<ov::element::i64>::value_type>();
for (size_t i = 0lu; i < targetShape.size(); i++) {
data[i] = targetShape[i];
}
} else if (funcInput.get_node()->get_friendly_name() == "axesMapping") {
tensor = ov::Tensor{ov::element::i64, targetInputStaticShapes[i]};
auto data = tensor.data<ov::element_type_traits<ov::element::i64>::value_type>();
for (size_t i = 0lu; i < axesMapping.size(); i++) {
data[i] = axesMapping[i];
}
} else {
if (funcInput.get_element_type().is_real()) {
tensor = ov::test::utils::create_and_fill_tensor(
funcInput.get_element_type(), targetInputStaticShapes[i], 10, 0, 1000);
} else {
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i]);
}
}
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
};
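Note: when the target shape is not constant, the test feeds it as a runtime i64 Parameter of shape {rank}, and generate_inputs() fills it with the concrete values per iteration. Minimal sketch of that input (standalone, not the fixture itself):

#include <memory>
#include <openvino/op/parameter.hpp>

// A non-constant target shape is a 1d i64 tensor whose length is the target rank.
std::shared_ptr<ov::op::v0::Parameter> make_target_shape_param(int64_t rank) {
    return std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{rank});
}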
TEST_P(BroadcastLayerGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace {
const std::vector<ov::element::Type_t> inputPrecisionsFloat = {
ov::element::f32,
};
const std::vector<ov::element::Type_t> inputPrecisionsInt = {
ov::element::i32,
};
const std::vector<std::vector<bool>> inputConstants = {
{true, true},
{false, true},
#if 0 // axes map input doesn't support a Parameter (non-constant) input
{true, false},
{false, false},
#endif
};
// ==============================================================================
// 1D
const std::vector<std::vector<InputShape>> dynamicInputShapes1D_explicit = {
{
{ {-1}, {{7}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_1d_explicit_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes1D_explicit),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{4, 7, 3}, {2, 7, 4, 3, 6}}),
::testing::Values(std::vector<int64_t>{1}),
::testing::Values(ov::op::BroadcastType::EXPLICIT),
::testing::ValuesIn(inputPrecisionsFloat),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInputShapes1D = {
{
{ {-1}, {{1}, {7}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_1d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes1D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{7}, {2, 4, 7}, {2, 3, 4, 7}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
// ==============================================================================
// 2D
const std::vector<std::vector<InputShape>> dynamicInputShapes2D_explicit = {
{
{ {-1, -1}, {{3, 5}} }
}
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_2d_explicit_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes2D_explicit),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{3, 4, 5}, {3, 6, 5, 7}}),
::testing::Values(std::vector<int64_t>{0, 2}),
::testing::Values(ov::op::BroadcastType::EXPLICIT),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInputShapes2D = {
{
{ {-1, -1}, {{3, 1}, {3, 5}} }
}
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_2d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes2D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{3, 5}, {2, 3, 5}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsFloat),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
// ==============================================================================
// 3D
const std::vector<std::vector<InputShape>> dynamicInputShapes3D_explicit = {
{
{ {-1, -1, -1}, {{4, 5, 6}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_3d_explicit_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes3D_explicit),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{4, 5, 6}, {4, 5, 6, 2, 3}}),
::testing::Values(std::vector<int64_t>{0, 1, 2}),
::testing::Values(ov::op::BroadcastType::EXPLICIT),
::testing::ValuesIn(inputPrecisionsFloat),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInputShapes3D = {
{
{ {-1, -1, -1}, {{4, 5, 1}, {1, 5, 1}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_3d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes3D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{4, 5, 6}, {2, 4, 5, 1}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
// ==============================================================================
// 4D
const std::vector<std::vector<InputShape>> dynamicInputShapes4D_explicit = {
{
{ {-1, -1, -1, -1}, {{1, 16, 1, 7}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_4d_explicit_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes4D_explicit),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{1, 16, 2, 1, 7}, {1, 16, 2, 1, 7, 3}}),
::testing::Values(std::vector<int64_t>{0, 1, 3, 4}),
::testing::Values(ov::op::BroadcastType::EXPLICIT),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInputShapes4D = {
{
{ {-1, -1, -1, -1}, {{2, 1, 1, 3}, {1, 4, 1, 3}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_4d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes4D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{2, 4, 1, 3}, {3, 2, 2, 4, 1, 3}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsFloat),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
// ==============================================================================
// 5D
const std::vector<std::vector<InputShape>> dynamicInputShapes5D_explicit = {
{
{ {-1, -1, -1, -1, -1}, {{2, 3, 4, 5, 6}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_5d_explicit_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes5D_explicit),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{2, 3, 4, 5, 6}}),
::testing::Values(std::vector<int64_t>{0, 1, 2, 3, 4}),
::testing::Values(ov::op::BroadcastType::EXPLICIT),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInputShapes5D = {
{
{ {-1, -1, -1, -1, -1}, {{8, 1, 1, 7, 1}, {8, 4, 1, 7, 3}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_5d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes5D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{8, 4, 1, 7, 3}, {8, 4, 5, 7, 3}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsFloat),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
// ==============================================================================
// 6D
const std::vector<std::vector<InputShape>> dynamicInputShapes6D = {
{
{ {-1, -1, -1, -1, -1, -1}, {{8, 1, 1, 7, 1, 3}, {8, 4, 1, 7, 16, 3}} }
},
};
INSTANTIATE_TEST_CASE_P(smoke_broadcast_6d_numpy_compareWithRefs_dynamic,
BroadcastLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(dynamicInputShapes6D),
::testing::ValuesIn(std::vector<std::vector<int64_t>>{{8, 4, 1, 7, 16, 3}, {8, 4, 5, 7, 16, 3}}),
::testing::Values(std::vector<int64_t>{}),
::testing::Values(ov::op::BroadcastType::NUMPY),
::testing::ValuesIn(inputPrecisionsInt),
::testing::ValuesIn(inputConstants),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerGPUTest::getTestCaseName);
} // namespace
} // namespace GPULayerTestsDefinitions