[GPU] Fix shape infer for 0d broadcast (#14744)

This commit is contained in:
Vladimir Paramuzov 2022-12-21 09:39:20 +04:00 committed by GitHub
parent 5714fdfe6b
commit 04c0dbbf60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletions

View File

@ -107,7 +107,7 @@ struct broadcast : public primitive_base<broadcast> {
target_shape(target_shape),
axes_mapping(axes_mapping),
broadcast_mode(broadcast_spec),
broadcast_sizes({}),
broadcast_sizes(target_shape.empty() ? tensor(1) : tensor(0)),
broadcast_axes({}) {}
/// @brief Constructs broadcast primitive / layer with dynamic target_shape.

View File

@ -25,6 +25,22 @@ const std::vector<InferenceEngine::Precision> inputTPrecisions = {
InferenceEngine::Precision::BOOL
};
// NUMPY MODE //////////////////////////////////////////
// 0D
std::vector<std::vector<size_t>> targetShapesNumpy0D = {
{},
};
INSTANTIATE_TEST_CASE_P(smoke_TestNumpyBroadcast0D,
BroadcastLayerTest,
::testing::Combine(::testing::ValuesIn(targetShapesNumpy0D),
::testing::Values(ngraph::AxisSet{}), // not used in numpy mode
::testing::Values(ngraph::op::BroadcastType::NUMPY),
::testing::Values(std::vector<size_t>{}),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GPU)),
BroadcastLayerTest::getTestCaseName);
// NUMPY MODE //////////////////////////////////////////
// 1D
std::vector<std::vector<size_t>> targetShapesNumpy1D = {