Revise Broadcast reference implementation (#2715)

* change tile reference implementation

* remove tile tests from interpreter manifest

* add repeats parameter to tile

* improve tile reference implementation

* add repeats parameter to tile reference call in tile evaluate method

* style-apply

* include <numeric>

* add unnamed namespace to helper functions. Change stdio.h to cstdio. Change input_rank to be a constant int

* pass the repeats parameter by const reference in the tile reference function

* change the createPitches function to use partial_sum instead of accumulate (see the sketch below)

* minor tweaks to the createPitches function

* style-apply

* fix function naming

* style-apply

* fix function name bug at call sites

* Add description of create_pitches function

* first version with debug logs

* reduce footprint

* single layer tests

* added more tests

* fixed handling of the bool type

* styles applied

* fix tile

* [ONLY DEBUG] print error scenario message

* fixed problem with e2e tests

* fixed casting of start_axis for numpy mode

Co-authored-by: pszmel <piotr.szmelczynski@intel.com>
Mateusz Bencer 2020-11-10 08:42:26 +01:00 committed by GitHub
parent 8d4f8c4edd
commit eeafc8e7dc
14 changed files with 420 additions and 301 deletions
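
The commit log above mentions reworking a createPitches helper in the tile reference implementation to use partial_sum instead of accumulate; the tile sources themselves are not shown in this diff. Below is a minimal sketch of that idea for a row-major layout (the name create_pitches and its signature are assumptions for illustration, not the commit's actual code):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

namespace
{
    // For a row-major shape, pitches[i] is the distance in elements between
    // consecutive coordinates along axis i, e.g. shape {2, 3, 4} -> {12, 4, 1}.
    std::vector<int64_t> create_pitches(const std::vector<int64_t>& shape)
    {
        std::vector<int64_t> pitches(shape.size(), 1);
        if (shape.size() > 1)
        {
            // A running partial_sum over multiplication yields the suffix
            // products (the pitches) in one right-to-left pass.
            std::partial_sum(shape.rbegin(),
                             shape.rend() - 1,
                             pitches.rbegin() + 1,
                             std::multiplies<int64_t>());
        }
        return pitches;
    }
}

With the pitches available, tile can translate an output coordinate into its source element with plain integer arithmetic instead of re-multiplying dimensions in the inner loop.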

View File

@@ -41,7 +41,11 @@ void ngraph::pass::ConvertBroadcast3::convert_broadcast3() {
     } else if (broadcast_type == op::BroadcastType::BIDIRECTIONAL) {
         auto constant_one = std::make_shared<ngraph::opset1::Constant>(input.get_element_type(), Shape({1}), std::vector<int>{1});
         auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, op::AutoBroadcastType::NUMPY);
-        last_node = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
+        if (input.get_element_type() == element::boolean) {
+            last_node = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
+        } else {
+            last_node = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
+        }
         ngraph::copy_runtime_info(broadcast, {last_node, broadcast_ones, constant_one});
     }

View File

@@ -0,0 +1,174 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/broadcast.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::I32,
InferenceEngine::Precision::BOOL
};
// NUMPY MODE
std::vector<std::vector<size_t>> inShapesNumpy = {
{3, 1},
{1, 4, 1}
};
std::vector<std::vector<size_t>> targetShapesNumpy = {
{2, 3, 6},
{1, 4, 4}
};
const auto numpyBroadcastParams1 = ::testing::Combine(
::testing::Values(targetShapesNumpy[0]),
::testing::Values(ngraph::AxisSet{}), //not used in numpy mode
::testing::Values(ngraph::op::BroadcastType::NUMPY),
::testing::Values(inShapesNumpy[0]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestNumpyBroadcast1,
BroadcastLayerTest,
numpyBroadcastParams1,
BroadcastLayerTest::getTestCaseName
);
const auto numpyBroadcastParams2 = ::testing::Combine(
::testing::Values(targetShapesNumpy[1]),
::testing::Values(ngraph::AxisSet{}), //not used in numpy mode
::testing::Values(ngraph::op::BroadcastType::NUMPY),
::testing::Values(inShapesNumpy[1]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestNumpyBroadcast2,
BroadcastLayerTest,
numpyBroadcastParams2,
BroadcastLayerTest::getTestCaseName
);
// BIDIRECTIONAL MODE
std::vector<std::vector<size_t>> inShapesBidi = {
{4, 1},
{1, 4, 1},
{4, 1, 1}
};
std::vector<std::vector<size_t>> targetShapesBidi = {
{2, 1, 4},
{1, 4, 4},
{1, 1, 2, 2}
};
const auto bidirectionalBroadcastParams1 = ::testing::Combine(
::testing::Values(targetShapesBidi[0]),
::testing::Values(ngraph::AxisSet{}), //not used in bidirectional mode
::testing::Values(ngraph::op::BroadcastType::BIDIRECTIONAL),
::testing::Values(inShapesBidi[0]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestBidirectionalBroadcast1,
BroadcastLayerTest,
bidirectionalBroadcastParams1,
BroadcastLayerTest::getTestCaseName
);
const auto bidirectionalBroadcastParams2 = ::testing::Combine(
::testing::Values(targetShapesBidi[1]),
::testing::Values(ngraph::AxisSet{}), //not used in bidirectional mode
::testing::Values(ngraph::op::BroadcastType::BIDIRECTIONAL),
::testing::Values(inShapesBidi[1]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestBidirectionalBroadcast2,
BroadcastLayerTest,
bidirectionalBroadcastParams2,
BroadcastLayerTest::getTestCaseName
);
const auto bidirectionalBroadcastParams3 = ::testing::Combine(
::testing::Values(targetShapesBidi[2]),
::testing::Values(ngraph::AxisSet{}), //not used in bidirectional mode
::testing::Values(ngraph::op::BroadcastType::BIDIRECTIONAL),
::testing::Values(inShapesBidi[2]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestBidirectionalBroadcast3,
BroadcastLayerTest,
bidirectionalBroadcastParams3,
BroadcastLayerTest::getTestCaseName
);
// EXPLICIT MODE
std::vector<std::vector<size_t>> inShapesExplicit = {
{3, 1},
{2, 4}
};
std::vector<std::vector<size_t>> targetShapesExplicit = {
{2, 3, 1},
{2, 3, 4}
};
std::vector<ngraph::AxisSet> axes = {
{1, 2},
{0, 2}
};
const auto explicitBroadcastParams1 = ::testing::Combine(
::testing::Values(targetShapesExplicit[0]),
::testing::Values(axes[0]),
::testing::Values(ngraph::op::BroadcastType::EXPLICIT),
::testing::Values(inShapesExplicit[0]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestExplicitBroadcast1,
BroadcastLayerTest,
explicitBroadcastParams1,
BroadcastLayerTest::getTestCaseName
);
const auto explicitBroadcastParams2 = ::testing::Combine(
::testing::Values(targetShapesExplicit[1]),
::testing::Values(axes[1]),
::testing::Values(ngraph::op::BroadcastType::EXPLICIT),
::testing::Values(inShapesExplicit[1]),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(
TestExplicitBroadcast2,
BroadcastLayerTest,
explicitBroadcastParams2,
BroadcastLayerTest::getTestCaseName
);
} // namespace

View File

@@ -0,0 +1,35 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace LayerTestsDefinitions {
using BroadcastParamsTuple = typename std::tuple<
InferenceEngine::SizeVector, // target shape
ngraph::AxisSet, // axes mapping
ngraph::op::BroadcastType, // broadcast mode
InferenceEngine::SizeVector, // Input shape
InferenceEngine::Precision, // Network precision
std::string>; // Device name
class BroadcastLayerTest : public testing::WithParamInterface<BroadcastParamsTuple>,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<BroadcastParamsTuple> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@@ -0,0 +1,49 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "single_layer_tests/broadcast.hpp"
namespace LayerTestsDefinitions {
std::string BroadcastLayerTest::getTestCaseName(const testing::TestParamInfo<BroadcastParamsTuple>& obj) {
InferenceEngine::SizeVector targetShape;
ngraph::AxisSet axesMapping;
ngraph::op::BroadcastType mode;
InferenceEngine::SizeVector inputShape;
InferenceEngine::Precision networkPrecision;
std::string deviceName;
std::tie(targetShape, axesMapping, mode, inputShape, networkPrecision, deviceName) = obj.param;
std::ostringstream result;
result << "targetShape=" << CommonTestUtils::vec2str(targetShape) << "_";
result << "axesMapping=" << CommonTestUtils::set2str(axesMapping) << "_";
result << "mode=" << mode << "_";
result << "inShape=" << CommonTestUtils::vec2str(inputShape) << "_";
result << "inNPrec=" << networkPrecision << "_";
result << "trgDev=" << deviceName;
return result.str();
}
void BroadcastLayerTest::SetUp() {
InferenceEngine::SizeVector targetShape;
ngraph::AxisSet axesMapping;
ngraph::op::BroadcastType mode;
InferenceEngine::SizeVector inputShape;
InferenceEngine::Precision networkPrecision;
std::tie(targetShape, axesMapping, mode, inputShape, networkPrecision, targetDevice) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(networkPrecision);
auto target_shape_const = ngraph::opset3::Constant::create(ngraph::element::i64, {targetShape.size()}, targetShape);
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
auto broadcast = ngraph::builder::makeBroadcast(params[0], target_shape_const, mode, axesMapping);
ngraph::ResultVector results{std::make_shared<ngraph::opset4::Result>(broadcast)};
function = std::make_shared<ngraph::Function>(results, params, "BroadcastInference");
}
TEST_P(BroadcastLayerTest, CompareWithRefs) {
Run();
}
} // namespace LayerTestsDefinitions

View File

@@ -70,6 +70,11 @@ std::shared_ptr<Node> makeConstant(const element::Type &type, const std::vector<
std::shared_ptr<ngraph::Node> makeInputLayer(const element::Type& type, ngraph::helpers::InputLayerType inputType,
const std::vector<size_t>& shape);
std::shared_ptr<ngraph::Node> makeBroadcast(const ngraph::Output<Node> &in,
const ngraph::Output<Node> &target_shape,
const ngraph::op::BroadcastType& mode,
const ngraph::AxisSet& axis_set = {});
std::shared_ptr<ngraph::Node> makeConvolution(const ngraph::Output<Node> &in,
const element::Type &type,
const std::vector<size_t> &filterSize,

View File

@@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <memory>
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeBroadcast(const ngraph::Output<Node> &in,
const ngraph::Output<Node> &target_shape,
const ngraph::op::BroadcastType& mode,
const ngraph::AxisSet& axisSet) {
if (mode == ngraph::op::BroadcastType::NONE) {
auto axisSetConst = ngraph::opset5::Constant::create(ngraph::element::i64, {axisSet.size()}, axisSet.to_vector());
return std::make_shared<ngraph::opset5::Broadcast>(in,
target_shape,
axisSetConst,
mode);
} else { // numpy/bidirectional modes
return std::make_shared<ngraph::opset5::Broadcast>(in,
target_shape,
mode);
}
}
} // namespace builder
} // namespace ngraph

View File

@@ -71,7 +71,6 @@ namespace ngraph
                 const std::pair<bool, AxisSet> pair_broadcast_axes,
                 const Shape output_shape) const;
-            template <element::Type_t ET>
             bool evaluate(const HostTensorPtr& arg0,
                           const HostTensorPtr& out,
                           const AxisSet& broadcast_axes) const;

View File

@@ -1,218 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <utility>
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/shape_util.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace runtime
{
namespace opt_kernel
{
template <typename T>
void broadcast_2d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[2];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
// clang-format off
out[index[0] * out_strides[0] +
index[1]] =
in[in_index];
// clang-format on
}
}
}
// #define PARALLEL
template <typename T>
void broadcast_3d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[3];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
// clang-format off
out[index[0] * out_strides[0] +
index[1] * out_strides[1] +
index[2]] =
in[in_index];
// clang-format on
}
}
}
}
template <typename T>
void broadcast_4d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[4];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
// clang-format off
out[index[0] * out_strides[0] +
index[1] * out_strides[1] +
index[2] * out_strides[2] +
index[3]] =
in[in_index];
// clang-format on
}
}
}
}
}
template <typename T>
void broadcast_5d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[5];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
// clang-format off
out[index[0] * out_strides[0] +
index[1] * out_strides[1] +
index[2] * out_strides[2] +
index[3] * out_strides[3] +
index[4]] =
in[in_index];
// clang-format on
}
}
}
}
}
}
template <typename T>
void broadcast_6d(
const T* in, T* out, const Shape& in_shape, const Shape& out_shape, size_t out_axis)
{
size_t index[6];
size_t& in_index = index[out_axis];
auto out_strides = row_major_strides(out_shape);
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
// clang-format off
out[index[0] * out_strides[0] +
index[1] * out_strides[1] +
index[2] * out_strides[2] +
index[3] * out_strides[3] +
index[4] * out_strides[4] +
index[5]] =
in[in_index];
// clang-format on
}
}
}
}
}
}
}
template <typename T>
void broadcast(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
if (is_scalar(in_shape))
{
for (size_t i = 0; i < shape_size(out_shape); ++i)
{
out[i] = in[0];
}
}
else if (in_shape.size() == 1)
{
size_t output_axis = 0;
for (size_t i = 0; i < out_shape.size(); i++)
{
if (broadcast_axes.count(i) == 0)
{
output_axis = i;
break;
}
}
switch (out_shape.size())
{
case 2: broadcast_2d<T>(in, out, in_shape, out_shape, output_axis); break;
case 3: broadcast_3d<T>(in, out, in_shape, out_shape, output_axis); break;
case 4: broadcast_4d<T>(in, out, in_shape, out_shape, output_axis); break;
case 5: broadcast_5d<T>(in, out, in_shape, out_shape, output_axis); break;
case 6: broadcast_6d<T>(in, out, in_shape, out_shape, output_axis); break;
default:
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
break;
}
}
else
{
runtime::reference::broadcast<T>(in, out, in_shape, out_shape, broadcast_axes);
}
}
}
}
}

View File

@@ -16,10 +16,8 @@
 #pragma once
 #include <cmath>
-#include "ngraph/coordinate_transform.hpp"
-#include "ngraph/shape_util.hpp"
+#include "ngraph/axis_set.hpp"
+#include "ngraph/shape.hpp"
 namespace ngraph
 {
@@ -27,42 +25,12 @@ namespace ngraph
     namespace runtime
     {
         namespace reference
        {
-            template <typename T>
-            void broadcast(const T* arg,
-                           T* out,
+            void broadcast(const char* arg,
+                           char* out,
                            const Shape& in_shape,
                            const Shape& out_shape,
-                           const AxisSet& broadcast_axes)
-            {
-                // Remove all broadcast axes from in_shape
-                Shape adjusted_in_shape;
-                for (auto length : in_shape)
-                {
-                    if (length != 1)
-                    {
-                        adjusted_in_shape.push_back(length);
-                    }
-                }
-                // Remove 1s from out_shape
-                AxisSet adjusted_axes(broadcast_axes);
-                for (uint64_t axis = 0; axis < out_shape.size(); ++axis)
-                {
-                    auto length = out_shape.at(axis);
-                    if (length == 1)
-                    {
-                        adjusted_axes.insert(axis);
-                    }
-                }
-                CoordinateTransform input_transform(adjusted_in_shape);
-                CoordinateTransform output_transform(out_shape);
-                for (const Coordinate& output_coord : output_transform)
-                {
-                    Coordinate input_coord = reduce(output_coord, adjusted_axes, false);
-                    out[output_transform.index(output_coord)] =
-                        arg[input_transform.index(input_coord)];
-                }
-            }
+                           const AxisSet& broadcast_axes,
+                           size_t elem_size);
        }
    }
 }

View File

@@ -197,11 +197,12 @@ namespace ngraph
                 if (!broadcast_axes.empty())
                 {
                     arg0_broadcast_vec.reserve(shape_size(arg0_br_target_shape));
-                    broadcast(arg0_update,
-                              arg0_broadcast_vec.data(),
+                    broadcast(reinterpret_cast<const char*>(arg0_update),
+                              reinterpret_cast<char*>(arg0_broadcast_vec.data()),
                               wip_arg0_shape,
                               arg0_br_target_shape,
-                              broadcast_axes);
+                              broadcast_axes,
+                              sizeof(T));
                     arg0_update = arg0_broadcast_vec.data();
                     wip_arg0_shape = arg0_br_target_shape;
@@ -216,11 +217,12 @@ namespace ngraph
                 if (!broadcast_axes.empty())
                 {
                     arg1_broadcast_vec.reserve(shape_size(arg1_br_target_shape));
-                    broadcast(arg1_update,
-                              arg1_broadcast_vec.data(),
+                    broadcast(reinterpret_cast<const char*>(arg1_update),
+                              reinterpret_cast<char*>(arg1_broadcast_vec.data()),
                               wip_arg1_shape,
                               arg1_br_target_shape,
-                              broadcast_axes);
+                              broadcast_axes,
+                              sizeof(T));
                     arg1_update = arg1_broadcast_vec.data();
                     wip_arg1_shape = arg1_br_target_shape;
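
Both call sites above repeat the same reinterpret_cast plus sizeof(T) pattern when invoking the now type-erased reference broadcast. A hypothetical typed wrapper (not part of this commit, shown only as a sketch) would capture it once:

// Hypothetical convenience wrapper over the char-based reference::broadcast
// introduced in this PR; illustrative only.
template <typename T>
void broadcast_typed(const T* arg,
                     T* out,
                     const Shape& in_shape,
                     const Shape& out_shape,
                     const AxisSet& broadcast_axes)
{
    runtime::reference::broadcast(reinterpret_cast<const char*>(arg),
                                  reinterpret_cast<char*>(out),
                                  in_shape,
                                  out_shape,
                                  broadcast_axes,
                                  sizeof(T));
}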

View File

@@ -0,0 +1,55 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/tile.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
void broadcast(const char* arg,
char* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes,
size_t elem_size)
{
const auto output_rank = std::max(in_shape.size(), out_shape.size());
Shape adjusted_in_shape = in_shape;
for (const auto& axis : broadcast_axes)
{
if (adjusted_in_shape.size() < output_rank)
{
adjusted_in_shape.insert(adjusted_in_shape.begin() + axis, 1);
}
}
Shape adjusted_out_shape = out_shape;
adjusted_out_shape.insert(
adjusted_out_shape.begin(), output_rank - adjusted_out_shape.size(), 1);
std::vector<int64_t> repeats(output_rank);
for (size_t i = 0; i < repeats.size(); ++i)
{
repeats[i] = adjusted_out_shape[i] / adjusted_in_shape[i];
}
return tile(arg, out, adjusted_in_shape, adjusted_out_shape, elem_size, repeats);
}
}
}
}
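
A usage sketch of the type-erased kernel above: broadcasting a {3, 1} float tensor to {2, 3, 6}, with output axis 0 as the broadcast axis. Per the implementation shown, the input shape is padded to {1, 3, 1}, repeats {2, 1, 6} are derived from the element-wise shape quotients, and the work is delegated to tile:

#include <vector>
#include "ngraph/runtime/reference/broadcast.hpp"

void broadcast_usage_example()
{
    using namespace ngraph;
    std::vector<float> in{1.0f, 2.0f, 3.0f}; // shape {3, 1}
    std::vector<float> out(2 * 3 * 6);       // shape {2, 3, 6}
    runtime::reference::broadcast(reinterpret_cast<const char*>(in.data()),
                                  reinterpret_cast<char*>(out.data()),
                                  Shape{3, 1},
                                  Shape{2, 3, 6},
                                  AxisSet{0}, // output axis 0 is newly added
                                  sizeof(float));
    // Row r of each {3, 6} plane now holds six copies of in[r], and the two
    // outer planes are identical.
}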

View File

@@ -92,7 +92,7 @@ void op::util::BroadcastBase::validate_target_shape_numpy(const PartialShape& ar
         return;
     }
     const auto arg_rank_length = arg_shape.rank().get_length();
-    auto start_axis = target_shape.size() - arg_rank_length;
+    const int64_t start_axis = target_shape.size() - arg_rank_length;
     NODE_VALIDATION_CHECK(this,
                           start_axis >= 0,
                           "Broadcast target_shape has smaller rank ",
@@ -357,18 +357,17 @@ std::pair<bool, AxisSet> op::util::BroadcastBase::get_broadcast_axes() const
     return std::make_pair(axes_known, broadcast_axes);
 }
-template <element::Type_t ET>
 bool op::util::BroadcastBase::evaluate(const HostTensorPtr& arg0,
                                        const HostTensorPtr& out,
                                        const AxisSet& broadcast_axes) const
 {
     OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::util::BroadcastBase::evaluate<ET>");
-    using T = typename element_type_traits<ET>::value_type;
-    runtime::reference::broadcast<T>((arg0->get_data_ptr<ET>()),
-                                     (out->get_data_ptr<ET>()),
-                                     arg0->get_shape(),
-                                     out->get_shape(),
-                                     broadcast_axes);
+    runtime::reference::broadcast(arg0->get_data_ptr<const char>(),
+                                  out->get_data_ptr<char>(),
+                                  arg0->get_shape(),
+                                  out->get_shape(),
+                                  broadcast_axes,
+                                  arg0->get_element_type().size());
     return true;
 }
@@ -475,37 +474,11 @@ bool op::util::BroadcastBase::evaluate_broadcast(const HostTensorPtr& arg0,
         // broadcast_axes not known deterministically
         return false;
     }
-    bool rc = true;
     Shape in_shape = arg0->get_shape();
     out->set_shape(output_shape);
     out->set_element_type(arg0->get_element_type());
-    switch (arg0->get_element_type())
-    {
-        TYPE_CASE(boolean)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(i8)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(i16)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(i32)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(i64)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(u8)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(u16)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(u32)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(u64)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(f16)(arg0, out, pair_broadcast_axes.second);
-        break;
-        TYPE_CASE(f32)(arg0, out, pair_broadcast_axes.second);
-        break;
-        default: rc = false; break;
-    }
-    return rc;
+    return evaluate(arg0, out, pair_broadcast_axes.second);
 }
Shape op::util::BroadcastBase::get_target_shape(const HostTensorPtr& input1) const

View File

@@ -247,6 +247,43 @@ TEST(eval, evaluate_broadcast_v3_bidirectional)
ASSERT_EQ(result_val, expec);
}
TEST(eval, evaluate_broadcast_v3_bidirectional_target_rank_smaller_than_input)
{
Shape shape_a{1, 1, 1, 1, 1, 1, 1, 1};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{4}, {1, 3, 1, 1});
auto bcast_v3 =
make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::BIDIRECTIONAL);
auto fun = make_shared<Function>(OutputVector{bcast_v3}, ParameterVector{A});
auto result = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(shape_a, {1.0f})}));
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_partial_shape(), (PartialShape{1, 1, 1, 1, 1, 3, 1, 1}));
auto result_val = read_vector<float>(result);
vector<float> expec{1.0f, 1.0f, 1.0f};
ASSERT_EQ(result_val, expec);
}
TEST(eval, evaluate_broadcast_v3_bidirectional_target_rank_smaller_than_input_2)
{
Shape shape_a{1, 3, 1};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
auto target_shape = op::Constant::create<int32_t>(element::i32, Shape{2}, {3, 1});
auto bcast_v3 =
make_shared<op::v3::Broadcast>(A, target_shape, op::BroadcastType::BIDIRECTIONAL);
auto fun = make_shared<Function>(OutputVector{bcast_v3}, ParameterVector{A});
auto result = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate(
{result}, {make_host_tensor<element::Type_t::f32>(Shape{1, 3, 1}, {1.0f, 2.0f, 3.0f})}));
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_partial_shape(), (PartialShape{1, 3, 1}));
auto result_val = read_vector<float>(result);
vector<float> expec{1.0f, 2.0f, 3.0f};
ASSERT_EQ(result_val, expec);
}
TEST(eval, evaluate_broadcast_v3_bidirectional_dyn)
{
Shape shape_a{4, 1};

View File

@@ -39,7 +39,14 @@ namespace opset1_downgrade
         {
             const auto const_filled_with_ones = make_shared<op::v1::Broadcast>(
                 op::Constant::create(data->get_element_type(), {}, {1}), target_shape);
-            replacement_node = make_shared<op::v1::Multiply>(data, const_filled_with_ones);
+            if (const_filled_with_ones->get_element_type() == element::boolean)
+            {
+                replacement_node = make_shared<op::v1::LogicalOr>(data, const_filled_with_ones);
+            }
+            else
+            {
+                replacement_node = make_shared<op::v1::Multiply>(data, const_filled_with_ones);
+            }
             break;
         }
         case op::BroadcastType::EXPLICIT: