[GPU] Broadcast shape infer support (#12588)

* [GPU] Align broadcast with nGraph operation

* [GPU] Broadcast shape infer support

Co-authored-by: Sergey Shlyapnikov <sergey.shlyapnikov@intel.com>
This commit is contained in:
Roman Lyamin
2022-08-24 18:27:15 +04:00
committed by GitHub
parent 3595f195f1
commit 19fd77e3d8
12 changed files with 286 additions and 19 deletions

View File

@@ -69,7 +69,7 @@ void validate_target_shape_numpy(const ov::Node* op, const T& arg_shape, const T
} }
const auto arg_rank_length = arg_shape.size(); const auto arg_rank_length = arg_shape.size();
const auto target_rank_length = target_shape.size(); const auto target_rank_length = target_shape.size();
const int64_t start_axis = target_rank_length - arg_rank_length; const auto start_axis = target_rank_length - arg_rank_length;
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
start_axis >= 0, start_axis >= 0,
"Broadcast target_shape has smaller rank ", "Broadcast target_shape has smaller rank ",
@@ -220,7 +220,7 @@ void broadcase_base_shape_infer(
// Validate axes_mapping // Validate axes_mapping
const auto& axes_shape = input_shapes[2]; const auto& axes_shape = input_shapes[2];
if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) { if (input_shape.rank().is_static() && target_shape.rank().is_static() && axes_shape.is_static()) {
auto input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size(); int64_t input_rank = (input_shape.size() == 0 && axes_shape[0].get_length() > 0) ? 1 : input_shape.size();
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
axes_shape[0].get_length() == input_rank, axes_shape[0].get_length() == input_rank,
"Broadcast axes_mapping shape ", "Broadcast axes_mapping shape ",

View File

@@ -5,6 +5,8 @@
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once #pragma once
#include "openvino/op/broadcast.hpp"
#include "primitive.hpp" #include "primitive.hpp"
#include <vector> #include <vector>
@@ -81,6 +83,56 @@ struct broadcast : public primitive_base<broadcast> {
broadcast_sizes(broadcast_sizes), broadcast_sizes(broadcast_sizes),
broadcast_axes(broadcast_axes) {} broadcast_axes(broadcast_axes) {}
/// @brief Constructs broadcast primitive / layer with static target_shape.
///
/// @param id An identifier of new primitive.
/// @param input An identifier of primitive which is an input for newly created
/// broadcast primitive.
/// @param target_shape The shape of the output tensor.
/// @param axes_mapping The axis positions (0-based) in the result that correspond
/// to input axes. 'Arg' tensor is broadcast along the
/// remaining axes.
/// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
/// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
/// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
/// @param broadcast_spec Broadcast specification to use for determining broadcast
/// axes. 'axes_mapping' should not be provided if mode other
/// than explicit (none) is used.
broadcast(const primitive_id& id,
const primitive_id& input,
const ov::Shape& target_shape,
const ngraph::AxisSet& axes_mapping,
const ov::op::BroadcastModeSpec& broadcast_spec = ov::op::BroadcastType::EXPLICIT,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {input}, ext_prim_id, output_padding),
target_shape(target_shape),
axes_mapping(axes_mapping),
broadcast_mode(broadcast_spec),
broadcast_sizes({}),
broadcast_axes({}) {}
/// @brief Constructs broadcast primitive / layer with dynamic target_shape.
broadcast(const primitive_id& id,
const primitive_id& input,
const primitive_id& target_shape_id,
const ngraph::AxisSet& axes_mapping,
const ov::op::BroadcastModeSpec& broadcast_spec = ov::op::BroadcastType::EXPLICIT,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {input, target_shape_id}, ext_prim_id, output_padding),
target_shape({}),
axes_mapping(axes_mapping),
broadcast_mode(broadcast_spec),
broadcast_sizes({}),
broadcast_axes({}) {}
/// @brief The shape of the output tensor.
ov::Shape target_shape;
/// @brief The axis positions (0-based) in the result that correspond to input axes.
ov::AxisSet axes_mapping;
/// @brief Broadcast mode to use for determining broadcast axes.
ov::op::BroadcastModeSpec broadcast_mode;
/// @brief Expected sizes of output from broadcast primitive. /// @brief Expected sizes of output from broadcast primitive.
tensor broadcast_sizes; tensor broadcast_sizes;
/// @brief Array of axes positions from output shape (0-based, from left to right) /// @brief Array of axes positions from output shape (0-based, from left to right)

View File

@@ -3,6 +3,7 @@
// //
#include "broadcast_inst.h" #include "broadcast_inst.h"
#include "broadcast_shape_inference.hpp"
#include "intel_gpu/runtime/error_handler.hpp" #include "intel_gpu/runtime/error_handler.hpp"
#include "json_object.h" #include "json_object.h"
@@ -23,7 +24,71 @@ layout broadcast_inst::calc_output_layout(broadcast_node const& node, kernel_imp
auto input_layout = impl_param.get_input_layout(); auto input_layout = impl_param.get_input_layout();
auto desc = impl_param.typed_desc<broadcast>(); auto desc = impl_param.typed_desc<broadcast>();
return {input_layout.data_type, input_layout.format, desc->broadcast_sizes}; if (!desc->target_shape.empty()) {
std::vector<tensor::value_type> dims_converted(desc->target_shape.begin(), desc->target_shape.end());
for (size_t i = dims_converted.size(); i < 4; i++)
dims_converted.push_back(1); // extend shape to 4d
return { input_layout.data_type,
input_layout.format,
tensor(format::get_default_format(dims_converted.size()), dims_converted) };
} else {
return { input_layout.data_type, input_layout.format, desc->broadcast_sizes };
}
}
template<typename ShapeType>
std::vector<layout> broadcast_inst::calc_output_layouts(broadcast_node const& /*node*/, const kernel_impl_params& impl_param) {
auto desc = impl_param.typed_desc<broadcast>();
auto input0_layout = impl_param.get_input_layout(0);
auto output_type = input0_layout.data_type;
if (impl_param.has_fused_primitives()) {
output_type = impl_param.get_fused_output_layout().data_type;
}
ov::op::v3::Broadcast op;
op.set_broadcast_spec(desc->broadcast_mode);
bool third_input_needed = desc->broadcast_mode == ov::op::BroadcastType::EXPLICIT;
auto target_shape = desc->target_shape;
ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>()
: ShapeType(ov::Shape{ target_shape.size() });
std::vector<ShapeType> output_shapes = {ShapeType{}};
std::vector<ShapeType> input_shapes = {
input0_layout.get<ShapeType>(),
pattern_shape
};
auto axes_mapping = desc->axes_mapping.to_vector();
ShapeType axes_mapping_shape = ov::Shape{axes_mapping.size()};
std::map<size_t, ngraph::HostTensorPtr> const_data;
if (third_input_needed) {
input_shapes.emplace_back(axes_mapping_shape);
auto axes_mapping_tensor = make_host_tensor({axes_mapping_shape, data_types::i64, format::bfyx},
static_cast<void*>(axes_mapping.data()));
const_data.emplace(2, axes_mapping_tensor);
}
auto& constant_mem = impl_param.memory_deps;
if (constant_mem.count(1)) {
auto target_shape_mem = constant_mem.at(1);
cldnn::mem_lock<uint8_t, mem_lock_type::read> target_shape_lock(target_shape_mem, impl_param.prog.get_stream());
const_data.emplace(1, make_host_tensor(target_shape_mem->get_layout(), target_shape_lock.data()));
ov::op::v3::shape_infer(&op, input_shapes, output_shapes, const_data);
} else {
auto target_shape_tensor = make_host_tensor({pattern_shape, data_types::i64, format::bfyx},
static_cast<void*>(target_shape.data()));
const_data.emplace(1, target_shape_tensor);
ov::op::v3::shape_infer(&op, input_shapes, output_shapes, const_data);
}
format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());
return { layout{output_shapes[0], output_type, output_format} };
} }
std::string broadcast_inst::to_string(broadcast_node const& node) { std::string broadcast_inst::to_string(broadcast_node const& node) {

View File

@@ -24,6 +24,7 @@ public:
support_padding_all(true); support_padding_all(true);
} }
program_node& input() const { return get_dependency(0); } program_node& input() const { return get_dependency(0); }
std::vector<size_t> get_shape_infer_dependencies() const override { return {1}; }
}; };
using broadcast_node = typed_program_node<broadcast>; using broadcast_node = typed_program_node<broadcast>;
@@ -33,6 +34,8 @@ class typed_primitive_inst<broadcast> : public typed_primitive_inst_base<broadca
using parent = typed_primitive_inst_base<broadcast>; using parent = typed_primitive_inst_base<broadcast>;
public: public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(broadcast_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(broadcast_node const& node, kernel_impl_params const& impl_param); static layout calc_output_layout(broadcast_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(broadcast_node const& node); static std::string to_string(broadcast_node const& node);
typed_primitive_inst(network& network, broadcast_node const& node); typed_primitive_inst(network& network, broadcast_node const& node);

View File

@@ -67,7 +67,7 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& node,
op.set_special_zero(prim->special_zero); op.set_special_zero(prim->special_zero);
ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>() ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>()
: ShapeType(ov::Shape{ prim->output_pattern.size() }); : ShapeType(ov::Shape{ prim->output_pattern.size() });
std::vector<ShapeType> output_shapes = {ShapeType()}; std::vector<ShapeType> output_shapes = {ShapeType()};
std::vector<ShapeType> input_shapes = { std::vector<ShapeType> input_shapes = {
input_layout.get<ShapeType>(), input_layout.get<ShapeType>(),

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "openvino/core/except.hpp"
#include "intel_gpu/plugin/program.hpp" #include "intel_gpu/plugin/program.hpp"
#include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/plugin/common_utils.hpp"
@@ -79,10 +80,26 @@ static void CreateCommonBroadcastOp(Program& p, const std::shared_ptr<ngraph::No
inputPrimitive = reshapeName; inputPrimitive = reshapeName;
} }
ov::op::BroadcastModeSpec mode = ov::op::BroadcastType::NONE;
if (auto broadcast_v3 = std::dynamic_pointer_cast<ngraph::op::v3::Broadcast>(op)) {
mode = broadcast_v3->get_broadcast_spec();
} else if (auto broadcast_v1 = std::dynamic_pointer_cast<ngraph::op::v1::Broadcast>(op)) {
switch (broadcast_v1->get_broadcast_spec().m_type) {
case ov::op::AutoBroadcastType::NONE: mode = ov::op::BroadcastType::NONE; break;
case ov::op::AutoBroadcastType::NUMPY: mode = ov::op::BroadcastType::NUMPY; break;
case ov::op::AutoBroadcastType::PDPD: mode = ov::op::BroadcastType::PDPD; break;
default:
throw ov::Exception("[GPU] Can't match Broadcast v1 mode with v3 version");
}
} else {
throw ov::Exception("[GPU] Can't cast Broadcast operation to any supported version");
}
auto broadcastPrim = cldnn::broadcast(layerName, auto broadcastPrim = cldnn::broadcast(layerName,
inputPrimitive, inputPrimitive,
tensor_from_dims(op->get_output_shape(0)), outputShape,
{}, axis_mapping,
mode,
op->get_friendly_name()); op->get_friendly_name());
p.AddPrimitive(broadcastPrim); p.AddPrimitive(broadcastPrim);

View File

@@ -0,0 +1,129 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/broadcast.hpp>
#include <intel_gpu/primitives/data.hpp>
#include "broadcast_inst.h"
#include "program_wrapper.h"
#include <cmath>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
namespace shape_infer_tests {
struct broadcast_test_params {
layout data_layout;
layout target_shape_layout;
ov::Shape target_shape_data;
ngraph::AxisSet axes_mapping_data;
ov::op::BroadcastModeSpec mode;
layout expected_layout;
};
class broadcast_test_two_inputs : public testing::TestWithParam<broadcast_test_params> { };
TEST_P(broadcast_test_two_inputs, shape_infer) {
auto p = GetParam();
auto& engine = get_test_engine();
auto data_layout_prim = std::make_shared<input_layout>("data", p.data_layout);
auto target_shape_layout_prim = std::make_shared<input_layout>("target_shape", p.target_shape_layout);
auto broadcast_prim = std::make_shared<broadcast>("output", "data", "target_shape", p.axes_mapping_data, p.mode);
cldnn::program prog(engine);
auto target_shape_mem = engine.allocate_memory(p.target_shape_layout);
set_values(target_shape_mem, p.target_shape_data);
auto& data_node = prog.get_or_create(data_layout_prim);
auto& target_shape_node = prog.get_or_create(target_shape_layout_prim);
auto& broadcast_node = prog.get_or_create(broadcast_prim);
program_wrapper::add_connection(prog, data_node, broadcast_node);
program_wrapper::add_connection(prog, target_shape_node, broadcast_node);
auto params = broadcast_node.get_kernel_impl_params();
params->memory_deps = {{1, target_shape_mem}};
auto res = broadcast_inst::calc_output_layouts<ov::PartialShape>(broadcast_node, *params);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], p.expected_layout);
}
INSTANTIATE_TEST_SUITE_P(smoke, broadcast_test_two_inputs,
testing::ValuesIn(std::vector<broadcast_test_params>{
{
layout{ov::PartialShape{16, 1, 1, 1}, data_types::f32, format::bfyx},
layout{ov::PartialShape{5}, data_types::i64, format::bfzyx}, {1, 16, 50, 50, 50},
{}, ov::op::BroadcastType::NUMPY,
layout{ov::PartialShape{1, 16, 50, 50, 50}, data_types::f32, format::bfzyx}
},
{
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
layout{ov::PartialShape{5}, data_types::i64, format::bfzyx}, {1, 16, 50, 50, 50},
{}, ov::op::BroadcastType::NUMPY,
layout{ov::PartialShape{1, 16, 50, 50, 50}, data_types::f32, format::bfzyx}
},
{
layout{ov::PartialShape{16}, data_types::f32, format::bfyx},
layout{ov::PartialShape{4}, data_types::i64, format::bfyx}, {1, 16, 50, 50},
{1}, ov::op::BroadcastType::EXPLICIT,
layout{ov::PartialShape{1, 16, 50, 50}, data_types::f32, format::bfyx}
}
}));
class broadcast_test_single_input : public testing::TestWithParam<broadcast_test_params> { };
TEST_P(broadcast_test_single_input, shape_infer) {
auto p = GetParam();
auto& engine = get_test_engine();
auto data_layout_prim = std::make_shared<input_layout>("data", p.data_layout);
auto broadcast_prim = std::make_shared<broadcast>("output", "data", p.target_shape_data, p.axes_mapping_data, p.mode);
cldnn::program prog(engine);
auto& data_node = prog.get_or_create(data_layout_prim);
auto& broadcast_node = prog.get_or_create(broadcast_prim);
program_wrapper::add_connection(prog, data_node, broadcast_node);
auto params = broadcast_node.get_kernel_impl_params();
auto res = broadcast_inst::calc_output_layouts<ov::PartialShape>(broadcast_node, *params);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], p.expected_layout);
}
INSTANTIATE_TEST_SUITE_P(smoke, broadcast_test_single_input,
testing::ValuesIn(std::vector<broadcast_test_params>{
{
layout{ov::PartialShape{16, 1, 1, 1}, data_types::f32, format::bfyx},
layout{ov::PartialShape{5}, data_types::i64, format::bfzyx}, {1, 16, 50, 50, 50},
{}, ov::op::BroadcastType::NUMPY,
layout{ov::PartialShape{1, 16, 50, 50, 50}, data_types::f32, format::bfzyx}
},
{
layout{ov::PartialShape::dynamic(4), data_types::f32, format::bfyx},
layout{ov::PartialShape{5}, data_types::i64, format::bfzyx}, {1, 16, 50, 50, 50},
{}, ov::op::BroadcastType::NUMPY,
layout{ov::PartialShape{1, 16, 50, 50, 50}, data_types::f32, format::bfzyx}
},
{
layout{ov::PartialShape{16}, data_types::f32, format::bfyx},
layout{ov::PartialShape{4}, data_types::i64, format::bfyx}, {1, 16, 50, 50},
{1}, ov::op::BroadcastType::EXPLICIT,
layout{ov::PartialShape{1, 16, 50, 50}, data_types::f32, format::bfyx}
}
}));
} // shape_infer_tests

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2018-2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
@@ -32,6 +32,7 @@ struct strided_slice_test_params {
std::vector<int64_t> end_mask; std::vector<int64_t> end_mask;
std::vector<int64_t> new_axis_mask; std::vector<int64_t> new_axis_mask;
std::vector<int64_t> shrink_axis_mask; std::vector<int64_t> shrink_axis_mask;
std::vector<int64_t> ellipsis_mask;
layout expected_layout; layout expected_layout;
}; };
@@ -55,7 +56,7 @@ TEST_P(strided_slice_test, shape_infer) {
p.end_mask, p.end_mask,
p.new_axis_mask, p.new_axis_mask,
p.shrink_axis_mask, p.shrink_axis_mask,
std::vector<int64_t>{}, p.ellipsis_mask,
ov::Shape{}); ov::Shape{});
cldnn::program prog(engine); cldnn::program prog(engine);
@@ -91,7 +92,7 @@ INSTANTIATE_TEST_SUITE_P(smoke, strided_slice_test,
layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {0, 0, 0}, layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {0, 0, 0},
layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {0, 1, 0}, layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {0, 1, 0},
layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {1, 1, 1}, layout{ov::PartialShape{3}, data_types::i64, format::bfyx}, {1, 1, 1},
{1, 0, 1}, {1, 0, 1}, {0, 0, 0}, {0, 0, 0}, {1, 0, 1}, {1, 0, 1}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0},
layout{ov::PartialShape{1, 1, 1024}, data_types::i64, format::bfyx} layout{ov::PartialShape{1, 1, 1024}, data_types::i64, format::bfyx}
}, },
})); }));

View File

@@ -894,7 +894,7 @@ TEST(broadcast_gpu, basic_error_wrong_b_axes_size) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {0, 1, 2, 3, 4})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {0, 1, 2, 3, 4}));
std::string msg_to_find = "Incorrect parameters configuration: broadcast_axes size should be less or equal 4."; std::string msg_to_find = "Incorrect parameters configuration: broadcast_axes size should be less or equal 4.";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -906,7 +906,7 @@ TEST(broadcast_gpu, basic_error_wrong_b_axis_value) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {0, 4})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {0, 4}));
std::string msg_to_find = "Incorrect parameters configuration: broadcast_axes index should be within broadcast_sizes range."; std::string msg_to_find = "Incorrect parameters configuration: broadcast_axes index should be within broadcast_sizes range.";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -918,7 +918,7 @@ TEST(broadcast_gpu, basic_error_duplicate_b_axis_values) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {0, 1, 1})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {0, 1, 1}));
std::string msg_to_find = "Incorrect parameters configuration: Duplicate axes numbers was found in broadcast_axes."; std::string msg_to_find = "Incorrect parameters configuration: Duplicate axes numbers was found in broadcast_axes.";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -930,7 +930,7 @@ TEST(broadcast_gpu, basic_error_wrong_input_dimension_0) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {1})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {1}));
std::string msg_to_find = "Input size on dimension number 0(=2) is not equal to: (=1)"; std::string msg_to_find = "Input size on dimension number 0(=2) is not equal to: (=1)";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -942,7 +942,7 @@ TEST(broadcast_gpu, basic_error_not_dividable_2x3x4x5_to_3x3x4x5) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {3, 3, 4, 5}, {})); topology.add(broadcast("output", "input", tensor{3, 3, 4, 5}, {}));
std::string msg_to_find = "Invalid broadcast size: not dividable by input size"; std::string msg_to_find = "Invalid broadcast size: not dividable by input size";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -954,7 +954,7 @@ TEST(broadcast_gpu, basic_error_not_dividable_3_to_2x3x4x5_w_b_axes_0x1x3) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {0, 1, 3})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {0, 1, 3}));
std::string msg_to_find = "Invalid broadcast size: not dividable by input size"; std::string msg_to_find = "Invalid broadcast size: not dividable by input size";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));
@@ -966,7 +966,7 @@ TEST(broadcast_gpu, basic_error_not_dividable_4x5_to_3x4x5_w_b_axes_1) {
topology topology; topology topology;
topology.add(input_layout("input", input->get_layout())); topology.add(input_layout("input", input->get_layout()));
topology.add(broadcast("output", "input", {2, 3, 4, 5}, {1})); topology.add(broadcast("output", "input", tensor{2, 3, 4, 5}, {1}));
std::string msg_to_find = "Invalid broadcast size: not dividable by input size"; std::string msg_to_find = "Invalid broadcast size: not dividable by input size";
EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find)); EXPECT_ANY_THROW(check_exception_massage(engine, topology, msg_to_find));