[Transformations] BroadcastTransition transformation (#16861)
This commit is contained in:
parent
70d80a750f
commit
31efdfd00d
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <openvino/pass/pattern/matcher.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API BroadcastTransition;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief BroadcastTransition transformation moves broadcast through binary eltwise operation
* @details The matcher targets a binary elementwise arithmetic operation that has a
*          single-consumer Broadcast on one of its inputs. On a successful match the
*          elementwise operation is applied to the not-yet-broadcast data and a new
*          Broadcast is placed after it. The pass only applies when the elementwise
*          operation uses NUMPY auto-broadcasting and the Broadcast mode is NUMPY or
*          BIDIRECTIONAL.
|
||||
*/
|
||||
class ov::pass::BroadcastTransition : public ov::pass::MatcherPass {
|
||||
public:
|
||||
// Type information for the pass: name "BroadcastTransition", version string "0".
OPENVINO_RTTI("BroadcastTransition", "0");
|
||||
// Builds the pattern and registers the matcher callback (defined in the .cpp file).
BroadcastTransition();
|
||||
};
|
@ -54,13 +54,13 @@ bool can_eliminate_broadcast(const ngraph::Output<ngraph::Node>& eltwise,
|
||||
// input_shape will be broadcast
|
||||
return false;
|
||||
}
|
||||
} else if (input_shape[i_dim].is_dynamic() && broadcast_shape[i_dim].is_static() &&
|
||||
broadcast_shape[i_dim].get_length() != 1) {
|
||||
} else if (input_shape[i_dim].is_dynamic() && broadcast_shape[b_dim].is_static() &&
|
||||
broadcast_shape[b_dim].get_length() != 1) {
|
||||
return false;
|
||||
} else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_static() &&
|
||||
} else if (broadcast_shape[b_dim].is_dynamic() && input_shape[i_dim].is_static() &&
|
||||
input_shape[i_dim].get_length() == 1) {
|
||||
return false;
|
||||
} else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_dynamic()) {
|
||||
} else if (broadcast_shape[b_dim].is_dynamic() && input_shape[i_dim].is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,87 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/broadcast_transition.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
/**
 * Builds the BroadcastTransition matcher.
 *
 * Pattern: a BinaryElementwiseArithmetic node with a single-consumer Broadcast
 * (opset1 or opset10) on either input and any static-rank value on the other input.
 *
 * Rewrite: Eltwise(x, Broadcast(data, target)) becomes
 *          Broadcast(Eltwise(x, data), Maximum(ShapeOf(Eltwise(x, data)), target)),
 * where the shorter of the two shape vectors is left-padded with ones so that
 * both operands of Maximum have the same length.
 *
 * The callback returns false (graph untouched) when:
 *  - the elementwise op does not use NUMPY auto-broadcasting;
 *  - the Broadcast mode is neither NUMPY nor BIDIRECTIONAL;
 *  - the broadcast data has dynamic rank, or the target shape input is not a
 *    fully static 1D shape (the rank alignment below needs both ranks at
 *    transformation time).
 */
ov::pass::BroadcastTransition::BroadcastTransition() {
    MATCHER_SCOPE(BroadcastTransition);
    auto bcast_m = pass::pattern::wrap_type<opset1::Broadcast, opset10::Broadcast>(pass::pattern::consumers_count(1));
    auto eltwise_input_m = pass::pattern::any_input(pass::pattern::has_static_rank());
    auto eltwise_1 = pass::pattern::wrap_type<op::util::BinaryElementwiseArithmetic>({eltwise_input_m, bcast_m});
    auto eltwise_2 = pass::pattern::wrap_type<op::util::BinaryElementwiseArithmetic>({bcast_m, eltwise_input_m});
    auto eltwise_m = std::make_shared<pass::pattern::op::Or>(OutputVector{eltwise_1, eltwise_2});

    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        const auto eltwise = ov::as_type_ptr<ov::op::util::BinaryElementwiseArithmetic>(m.get_match_root());
        if (!eltwise || eltwise->get_autob().m_type != ov::op::AutoBroadcastType::NUMPY) {
            return false;
        }

        // The pattern accepts both opset1 and opset10 Broadcast, so the node has to be accessed through
        // their common base class: a cast to opset10::Broadcast yields nullptr for an opset1 node.
        const auto bcast = ov::as_type_ptr<ov::op::util::BroadcastBase>(pattern_map.at(bcast_m).get_node_shared_ptr());
        if (!bcast) {
            return false;
        }
        const auto& bcast_type = bcast->get_broadcast_spec().m_type;
        if (bcast_type != ov::op::BroadcastType::NUMPY && bcast_type != ov::op::BroadcastType::BIDIRECTIONAL) {
            return false;
        }

        const auto& eltwise_input = pattern_map.at(eltwise_input_m);
        const auto bcast_data = bcast->input_value(0);
        auto target_shape = bcast->input_value(1);

        // Both ranks are queried below; bail out before creating any node if they are not known,
        // instead of throwing from Dimension::get_length() / PartialShape::size().
        const auto& target_shape_pshape = bcast->get_input_partial_shape(1);
        if (bcast_data.get_partial_shape().rank().is_dynamic() || target_shape_pshape.is_dynamic() ||
            target_shape_pshape.size() != 1) {
            return false;
        }
        const auto target_shape_rank = static_cast<size_t>(target_shape_pshape[0].get_length());

        // inputs order mustn't be changed because an eltwise might be not commutative
        ov::OutputVector new_inputs{
            eltwise->get_input_node_ptr(0) == eltwise_input.get_node() ? eltwise_input : bcast_data,
            eltwise->get_input_node_ptr(1) == bcast.get() ? bcast_data : eltwise_input};
        const auto new_eltwise = eltwise->clone_with_new_inputs(new_inputs);
        ov::copy_runtime_info(eltwise, new_eltwise);

        const auto& target_shape_et = target_shape.get_element_type();

        // Shape of the new eltwise output expressed in the element type of the original target shape.
        // ShapeOf is only asked for i32/i64 directly; any other type goes through a Convert.
        std::shared_ptr<ov::Node> data_shape_path;
        if (target_shape_et == ov::element::i32 || target_shape_et == ov::element::i64) {
            data_shape_path = ov::op::util::make_try_fold<opset10::ShapeOf>(new_eltwise, target_shape_et);
            ov::copy_runtime_info(eltwise, data_shape_path);
        } else {
            auto shapeof = ov::op::util::make_try_fold<opset10::ShapeOf>(new_eltwise);
            data_shape_path = ov::op::util::make_try_fold<ov::opset10::Convert>(shapeof, target_shape_et);
            ov::copy_runtime_info(eltwise, {shapeof, data_shape_path});
        }

        const size_t input_rank = new_eltwise->get_output_partial_shape(0).size();
        if (input_rank != target_shape_rank) {
            // Prepends `count` ones to a 1D shape vector so that both Maximum operands have equal length.
            auto align_rank = [&](const ov::Output<ov::Node>& out, const size_t count) {
                const auto constant = ov::opset10::Constant::create(target_shape_et, {count}, {1});
                const auto res = ov::op::util::make_try_fold<ov::opset10::Concat>(ov::OutputVector{constant, out}, 0);
                ov::copy_runtime_info(out.get_node_shared_ptr(), {constant, res});
                return res;
            };
            if (input_rank < target_shape_rank) {
                data_shape_path = align_rank(data_shape_path, target_shape_rank - input_rank);
            } else {
                target_shape = align_rank(target_shape, input_rank - target_shape_rank);
            }
        }

        // Element-wise maximum of the two aligned shapes is used as the new broadcast target.
        const auto new_target_shape = ov::op::util::make_try_fold<opset10::Maximum>(data_shape_path, target_shape);
        ov::copy_runtime_info(eltwise, new_target_shape);

        const auto new_bcast = std::make_shared<ov::opset10::Broadcast>(new_eltwise, new_target_shape);
        new_bcast->set_friendly_name(eltwise->get_friendly_name());
        ov::copy_runtime_info(eltwise, {new_eltwise, new_bcast});
        ov::replace_node(eltwise, new_bcast);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(eltwise_m, matcher_name);
    register_matcher(m, callback);
}
|
@ -333,3 +333,16 @@ TEST_F(TransformationTestsF, BroadcastElementwiseFusionWithShapeOfNeg) {
|
||||
manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
|
||||
}
|
||||
}
|
||||
|
||||
// Fully dynamic 4D data is added to a scalar constant that is broadcast to a runtime 2-element target shape.
// No reference model is assigned here; presumably the fixture then expects the graph to stay unchanged
// (confirm against TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastElementwiseFusionDynShapesDifferentRanks) {
    {
        const auto data_param =
            std::make_shared<ngraph::opset5::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, -1, -1});
        const auto shape_param = std::make_shared<ngraph::opset5::Parameter>(ov::element::i32, ov::PartialShape{2});
        const auto scalar_one = ngraph::opset5::Constant::create(ov::element::f32, {}, {1.f});
        const auto bcast_node = std::make_shared<ngraph::opset5::Broadcast>(scalar_one, shape_param);
        const auto add_node = std::make_shared<ngraph::opset5::Add>(data_param, bcast_node);
        function = std::make_shared<ov::Model>(ov::NodeVector{add_node}, ov::ParameterVector{data_param, shape_param});

        manager.register_pass<ov::pass::BroadcastElementwiseFusion>();
    }
}
|
||||
|
@ -0,0 +1,326 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <string>
|
||||
#include <transformations/common_optimizations/broadcast_transition.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
std::shared_ptr<ov::Node> getOperation(
|
||||
const ov::Output<ov::Node>& in1,
|
||||
const ov::Output<ov::Node>& in2,
|
||||
const std::string& operation_type,
|
||||
const ov::op::AutoBroadcastType& eltwise_bcast_type = ov::op::AutoBroadcastType::NUMPY) {
|
||||
if (operation_type == "Add") {
|
||||
return std::make_shared<ov::opset10::Add>(in1, in2, eltwise_bcast_type);
|
||||
} else if (operation_type == "Multiply") {
|
||||
return std::make_shared<ov::opset10::Multiply>(in1, in2, eltwise_bcast_type);
|
||||
} else if (operation_type == "Subtract") {
|
||||
return std::make_shared<ov::opset10::Subtract>(in1, in2, eltwise_bcast_type);
|
||||
} else {
|
||||
throw std::runtime_error("Unexpected operation type");
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> getOriginal(
|
||||
const ov::element::Type& precision,
|
||||
const ov::PartialShape& input_shape,
|
||||
const ov::Shape& target_shape,
|
||||
const ov::op::BroadcastType& bcast_mode,
|
||||
const std::string& operation_type,
|
||||
const size_t idx,
|
||||
const ov::op::AutoBroadcastType& eltwise_bcast_type = ov::op::AutoBroadcastType::NUMPY) {
|
||||
const auto input = std::make_shared<ov::opset10::Parameter>(precision, input_shape);
|
||||
const auto data_constant = ov::opset10::Constant::create(precision, {}, {1.f});
|
||||
const auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {target_shape.size()}, target_shape);
|
||||
const auto bcast = std::make_shared<ov::opset10::Broadcast>(data_constant, target_shape_node, bcast_mode);
|
||||
|
||||
const auto fst_in = idx == 0 ? bcast->output(0) : input->output(0);
|
||||
const auto sec_in = idx == 1 ? bcast->output(0) : input->output(0);
|
||||
const auto operation = getOperation(fst_in, sec_in, operation_type, eltwise_bcast_type);
|
||||
return std::make_shared<ov::Model>(operation, ov::ParameterVector{input});
|
||||
}
|
||||
|
||||
/// Builds the expected model: the elementwise operation consumes the scalar constant directly and a
/// Broadcast follows it. The Broadcast target is the element-wise maximum of the original target shape
/// and the elementwise output shape, both left-padded with ones to a common rank.
/// @param precision element type of the parameter and of the scalar constant
/// @param input_shape shape of the model parameter (the elementwise output shape must be static)
/// @param original_target_shape target shape of the Broadcast in the original model
/// @param operation_type name understood by getOperation
/// @param idx elementwise input index (0 or 1) that receives the scalar constant
std::shared_ptr<ov::Model> getReference(const ov::element::Type& precision,
                                        const ov::PartialShape& input_shape,
                                        const ov::Shape& original_target_shape,
                                        const std::string& operation_type,
                                        const size_t idx) {
    const auto param = std::make_shared<ov::opset10::Parameter>(precision, input_shape);
    const auto scalar_one = ov::opset10::Constant::create(precision, {}, {1.f});

    ov::Output<ov::Node> lhs = param->output(0);
    ov::Output<ov::Node> rhs = param->output(0);
    if (idx == 0) {
        lhs = scalar_one->output(0);
    }
    if (idx == 1) {
        rhs = scalar_one->output(0);
    }
    const auto eltwise = getOperation(lhs, rhs, operation_type, ov::op::AutoBroadcastType::NUMPY);

    // Left-pads a shape with ones until it has at least `rank` dimensions.
    const auto pad_front = [](ov::Shape shape, const size_t rank) {
        while (shape.size() < rank) {
            shape.insert(shape.begin(), 1);
        }
        return shape;
    };

    const ov::Shape& eltwise_out_shape = eltwise->get_shape();
    ov::Shape merged_shape = pad_front(original_target_shape, eltwise_out_shape.size());
    const ov::Shape aligned_eltwise_shape = pad_front(eltwise_out_shape, merged_shape.size());
    for (size_t dim = 0; dim < merged_shape.size(); ++dim) {
        if (merged_shape[dim] < aligned_eltwise_shape[dim]) {
            merged_shape[dim] = aligned_eltwise_shape[dim];
        }
    }

    const auto shape_const = ov::opset10::Constant::create(ov::element::i32, {merged_shape.size()}, merged_shape);
    const auto broadcast = std::make_shared<ov::opset10::Broadcast>(eltwise, shape_const);
    return std::make_shared<ov::Model>(broadcast, ov::ParameterVector{param});
}
|
||||
|
||||
using BroadcastTransitionParams = std::tuple<ov::element::Type, // precision
|
||||
ov::Shape, // input shape
|
||||
ov::Shape, // target shape
|
||||
ov::op::BroadcastType, // broadcast mode
|
||||
std::string, // operation type
|
||||
size_t // broadcast input index
|
||||
>;
|
||||
|
||||
class StaticBroadcastTransitionTests : public testing::WithParamInterface<BroadcastTransitionParams>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
StaticBroadcastTransitionTests() : TransformationTestsF() {
|
||||
comparator.enable(FunctionsComparator::ATTRIBUTES);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<BroadcastTransitionParams>& obj) {
|
||||
ov::element::Type precision;
|
||||
ov::Shape input_shape;
|
||||
ov::Shape target_shape;
|
||||
ov::op::BroadcastType bcast_mode;
|
||||
std::string operation_type;
|
||||
size_t idx;
|
||||
std::tie(precision, input_shape, target_shape, bcast_mode, operation_type, idx) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << operation_type << "_prc=" << precision << "_IS=" << input_shape << "_TS=" << target_shape
|
||||
<< "_bcast_idx=" << idx << "_bcast_type=" << bcast_mode;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TransformationTestsF::SetUp();
|
||||
ov::element::Type precision;
|
||||
ov::Shape input_shape;
|
||||
ov::Shape target_shape;
|
||||
ov::op::BroadcastType bcast_mode;
|
||||
std::string operation_type;
|
||||
size_t idx;
|
||||
std::tie(precision, input_shape, target_shape, bcast_mode, operation_type, idx) = GetParam();
|
||||
|
||||
manager.register_pass<ov::pass::BroadcastTransition>();
|
||||
model = getOriginal(precision, input_shape, target_shape, bcast_mode, operation_type, idx);
|
||||
model_ref = getReference(precision, input_shape, target_shape, operation_type, idx);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(StaticBroadcastTransitionTests, BroadcastTransition) {}
|
||||
|
||||
namespace BroadcastTransitionTestsInstantiation {
|
||||
std::vector<ov::Shape> input_shapes = {
|
||||
{1, 3, 16, 16},
|
||||
{1, 3, 1, 16},
|
||||
{16, 16},
|
||||
};
|
||||
|
||||
std::vector<ov::Shape> target_shapes = {
|
||||
{1, 3, 16, 1},
|
||||
{16, 16},
|
||||
};
|
||||
|
||||
std::vector<ov::op::BroadcastType> bcast_modes = {ov::op::BroadcastType::NUMPY, ov::op::BroadcastType::BIDIRECTIONAL};
|
||||
|
||||
std::vector<std::string> operation_types = {"Add", "Multiply", "Subtract"};
|
||||
std::vector<size_t> bcast_input_idx = {0, 1};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTestsF,
|
||||
StaticBroadcastTransitionTests,
|
||||
::testing::Combine(::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(input_shapes),
|
||||
::testing::ValuesIn(target_shapes),
|
||||
::testing::ValuesIn(bcast_modes),
|
||||
::testing::ValuesIn(operation_types),
|
||||
::testing::ValuesIn(bcast_input_idx)),
|
||||
StaticBroadcastTransitionTests::getTestCaseName);
|
||||
} // namespace BroadcastTransitionTestsInstantiation
|
||||
|
||||
// Dynamic 4D data with a u32 runtime target shape of equal rank: the expected graph converts the
// ShapeOf result to u32 before feeding it to Maximum.
TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_U32TargetShapePrecision) {
    const ov::element::Type data_et = ov::element::f32;
    const ov::element::Type shape_et = ov::element::u32;

    // Original: Add(data, Broadcast(1, target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(scalar_one, shape_param);
        const auto add = getOperation(data, broadcast, "Add");
        model = std::make_shared<ov::Model>(add, ov::ParameterVector{data, shape_param});
    }

    manager.register_pass<ov::pass::BroadcastTransition>();

    // Expected: Broadcast(Add(data, 1), Maximum(Convert(ShapeOf(Add)), target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto add = getOperation(data, scalar_one, "Add");
        const auto add_shape = std::make_shared<ov::opset10::ShapeOf>(add);
        const auto add_shape_u32 = std::make_shared<ov::opset10::Convert>(add_shape, shape_et);
        const auto merged_shape = std::make_shared<ov::opset10::Maximum>(add_shape_u32, shape_param);
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(add, merged_shape);
        model_ref = std::make_shared<ov::Model>(broadcast, ov::ParameterVector{data, shape_param});
    }
}
|
||||
|
||||
// Dynamic 4D data with an i32 runtime target shape of equal rank: ShapeOf is created with i32 output
// and no rank alignment is needed.
TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_EqualRanks) {
    const ov::element::Type data_et = ov::element::f32;
    const ov::element::Type shape_et = ov::element::i32;

    // Original: Add(data, Broadcast(1, target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(scalar_one, shape_param);
        const auto add = getOperation(data, broadcast, "Add");
        model = std::make_shared<ov::Model>(add, ov::ParameterVector{data, shape_param});
    }

    manager.register_pass<ov::pass::BroadcastTransition>();

    // Expected: Broadcast(Add(data, 1), Maximum(ShapeOf(Add), target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto add = getOperation(data, scalar_one, "Add");
        const auto add_shape = std::make_shared<ov::opset10::ShapeOf>(add, shape_et);
        const auto merged_shape = std::make_shared<ov::opset10::Maximum>(add_shape, shape_param);
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(add, merged_shape);
        model_ref = std::make_shared<ov::Model>(broadcast, ov::ParameterVector{data, shape_param});
    }
}
|
||||
|
||||
// Dynamic 2D data with a 4-element runtime target shape: the ShapeOf result is left-padded with two ones
// before Maximum.
TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_DataRankLessThanTarget) {
    const ov::element::Type data_et = ov::element::f32;
    const ov::element::Type shape_et = ov::element::i32;

    // Original: Add(data, Broadcast(1, target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(2));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(scalar_one, shape_param);
        const auto add = getOperation(data, broadcast, "Add");
        model = std::make_shared<ov::Model>(add, ov::ParameterVector{data, shape_param});
    }

    manager.register_pass<ov::pass::BroadcastTransition>();

    // Expected: Broadcast(Add(data, 1), Maximum(Concat({1, 1}, ShapeOf(Add)), target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(2));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{4});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto add = getOperation(data, scalar_one, "Add");
        const auto add_shape = std::make_shared<ov::opset10::ShapeOf>(add, shape_et);
        const auto leading_ones = ov::opset10::Constant::create(shape_et, {2}, {1});
        const auto padded_add_shape =
            std::make_shared<ov::opset10::Concat>(ov::OutputVector{leading_ones, add_shape}, 0);
        const auto merged_shape = std::make_shared<ov::opset10::Maximum>(padded_add_shape, shape_param);
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(add, merged_shape);
        model_ref = std::make_shared<ov::Model>(broadcast, ov::ParameterVector{data, shape_param});
    }
}
|
||||
|
||||
// Dynamic 4D data with a 2-element runtime target shape: the target shape is left-padded with two ones
// before Maximum.
TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_DataRankGreaterThanTarget) {
    const ov::element::Type data_et = ov::element::f32;
    const ov::element::Type shape_et = ov::element::i32;

    // Original: Add(data, Broadcast(1, target_shape)).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{2});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(scalar_one, shape_param);
        const auto add = getOperation(data, broadcast, "Add");
        model = std::make_shared<ov::Model>(add, ov::ParameterVector{data, shape_param});
    }

    manager.register_pass<ov::pass::BroadcastTransition>();

    // Expected: Broadcast(Add(data, 1), Maximum(ShapeOf(Add), Concat({1, 1}, target_shape))).
    {
        const auto data = std::make_shared<ov::opset10::Parameter>(data_et, ov::PartialShape::dynamic(4));
        const auto shape_param = std::make_shared<ov::opset10::Parameter>(shape_et, ov::PartialShape{2});

        const auto scalar_one = ov::opset10::Constant::create(data_et, {}, {1.f});
        const auto add = getOperation(data, scalar_one, "Add");
        const auto add_shape = std::make_shared<ov::opset10::ShapeOf>(add, shape_et);
        const auto leading_ones = ov::opset10::Constant::create(shape_et, {2}, {1});
        const auto padded_target =
            std::make_shared<ov::opset10::Concat>(ov::OutputVector{leading_ones, shape_param}, 0);
        const auto merged_shape = std::make_shared<ov::opset10::Maximum>(add_shape, padded_target);
        const auto broadcast = std::make_shared<ov::opset10::Broadcast>(add, merged_shape);
        model_ref = std::make_shared<ov::Model>(broadcast, ov::ParameterVector{data, shape_param});
    }
}
|
||||
|
||||
// Negative case: the Add uses EXPLICIT auto-broadcasting. No reference model is assigned; presumably the
// fixture then expects the graph to stay unchanged (confirm against TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_ExplicitEltwiseBcast) {
    const ov::PartialShape data_shape{1, 3, 16, 16};
    const ov::Shape bcast_target{1, 3, 16, 16};
    const size_t bcast_input_idx = 0;

    model = getOriginal(ov::element::f32,
                        data_shape,
                        bcast_target,
                        ov::op::BroadcastType::NUMPY,
                        "Add",
                        bcast_input_idx,
                        ov::op::AutoBroadcastType::EXPLICIT);
    manager.register_pass<ov::pass::BroadcastTransition>();
}
|
||||
|
||||
// Negative case: the Add uses PDPD auto-broadcasting. No reference model is assigned; presumably the
// fixture then expects the graph to stay unchanged (confirm against TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_PDPDEltwiseBcast) {
    const ov::PartialShape data_shape{1, 3, 16, 16};
    const ov::Shape bcast_target{1, 3, 16, 16};
    const size_t bcast_input_idx = 0;

    model = getOriginal(ov::element::f32,
                        data_shape,
                        bcast_target,
                        ov::op::BroadcastType::NUMPY,
                        "Add",
                        bcast_input_idx,
                        ov::op::AutoBroadcastType::PDPD);
    manager.register_pass<ov::pass::BroadcastTransition>();
}
|
||||
|
||||
// Negative case: the Broadcast itself works in PDPD mode. No reference model is assigned; presumably the
// fixture then expects the graph to stay unchanged (confirm against TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_PDPDBcastType) {
    const auto data = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 16, 16});

    const auto bcast_source = ov::opset10::Constant::create(ov::element::f32, {1, 1, 1}, {1.f});
    const auto shape_const = ov::opset10::Constant::create(ov::element::i32, {3}, {1, 16, 16});
    const ov::op::BroadcastModeSpec pdpd_mode(ov::op::BroadcastType::PDPD);
    const auto broadcast = std::make_shared<ov::opset10::Broadcast>(bcast_source, shape_const, pdpd_mode);
    const auto sum = std::make_shared<ov::opset10::Add>(data, broadcast);

    model = std::make_shared<ov::Model>(sum, ov::ParameterVector{data});
    manager.register_pass<ov::pass::BroadcastTransition>();
}
|
||||
|
||||
// Negative case: the Broadcast is built with an explicit axes-mapping input. No reference model is
// assigned; presumably the fixture then expects the graph to stay unchanged (confirm against
// TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_WithAxesMapping) {
    const auto data = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 16, 16});
    const auto bcast_source = ov::opset10::Constant::create(ov::element::f32, {16, 16}, {1.f});

    const auto shape_const = ov::opset10::Constant::create(ov::element::i32, {3}, {1, 16, 16});
    const auto axes_const = ov::opset10::Constant::create(ov::element::i32, {2}, {1, 2});
    const auto broadcast = std::make_shared<ov::opset10::Broadcast>(bcast_source, shape_const, axes_const);
    const auto sum = std::make_shared<ov::opset10::Add>(data, broadcast);

    model = std::make_shared<ov::Model>(sum, ov::ParameterVector{data});
    manager.register_pass<ov::pass::BroadcastTransition>();
}
|
||||
|
||||
// Negative case: the data input has dynamic rank and the target shape has a dynamic length. No reference
// model is assigned; presumably the fixture then expects the graph to stay unchanged (confirm against
// TransformationTestsF).
TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_DynamicRank) {
    const auto data = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
    const auto scalar_one = ov::opset10::Constant::create(ov::element::f32, {}, {1.f});

    const auto shape_param = std::make_shared<ov::opset10::Parameter>(ov::element::i32, ov::PartialShape{-1});
    const auto broadcast = std::make_shared<ov::opset10::Broadcast>(scalar_one, shape_param);
    const auto sum = std::make_shared<ov::opset10::Add>(data, broadcast);

    model = std::make_shared<ov::Model>(sum, ov::ParameterVector{data, shape_param});
    manager.register_pass<ov::pass::BroadcastTransition>();
}
|
@ -181,6 +181,20 @@ TYPED_TEST_P(BroadcastTests, broadcast_axes_wrong_rank) {
|
||||
}
|
||||
}
|
||||
|
||||
// A Broadcast whose target-shape input is a scalar (rank 0) must be rejected at construction with a
// NodeValidationFailure mentioning the required rank of 1.
TYPED_TEST_P(BroadcastTests, broadcast_target_shape_wrong_rank) {
    auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
    auto bc_shape = make_shared<op::Parameter>(element::i64, Shape{});

    try {
        auto bc = make_shared<TypeParam>(arg, bc_shape);
        // Reached only if construction did not throw; the message names the check that was missed.
        FAIL() << "Broadcast: wrong target shape rank not detected";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(), "Broadcast shape rank must be 1, but has");
    } catch (...) {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}
|
||||
|
||||
TYPED_TEST_P(BroadcastTests, broadcast_fully_dynamic_target_shape) {
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto bc_shape = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
|
||||
@ -559,6 +573,7 @@ REGISTER_TYPED_TEST_SUITE_P(BroadcastTests,
|
||||
broadcast_fail_axes_map,
|
||||
broadcast_fail_axes_map_shape,
|
||||
broadcast_axes_wrong_rank,
|
||||
broadcast_target_shape_wrong_rank,
|
||||
broadcast_fully_dynamic_target_shape,
|
||||
broadcast_dynamic_values_of_target_shape,
|
||||
broadcast_broadcast_shape_et_wrong,
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "tile_broadcast_utils.h"
|
||||
|
||||
#include "cpu_convert.h"
|
||||
#include "cpu_memcpy.h"
|
||||
#include "ie_parallel.hpp"
|
||||
#include <memory_desc/cpu_memory_desc_utils.h>
|
||||
@ -250,7 +251,10 @@ void TileBroadcastCommon::optimizedExecute(const MemoryPtr& srcMemory, const Mem
|
||||
auto srcData = reinterpret_cast<const char *>(srcMemory->GetPtr());
|
||||
auto dstData = reinterpret_cast<char *>(dstMemory->GetPtr());
|
||||
|
||||
if (optimizedParams.srcStrides[5] == 0) {
|
||||
if (srcMemory->getStaticDims() == dstMemory->getStaticDims()) {
|
||||
const auto prc = dstMemory->getDesc().getPrecision();
|
||||
cpu_convert(srcData, dstData, prc, prc, optimizedParams.copySize / prc.size());
|
||||
} else if (optimizedParams.srcStrides[5] == 0) {
|
||||
if (optimizedParams.dstStrides[0] == optimizedParams.dims[5] * optimizedParams.dstStrides[5]) {
|
||||
size_t data_size = optimizedParams.dstStrides[5];
|
||||
size_t elt_cnt = optimizedParams.dims[5];
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
// Common transformations
|
||||
#include "transformations/common_optimizations/add_fake_quantize_fusion.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_transition.hpp"
|
||||
#include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp"
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
|
||||
@ -225,6 +226,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
||||
type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static(), fuse_type_to_convert}};
|
||||
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::BroadcastTransition);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking);
|
||||
|
@ -0,0 +1,108 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include <common_test_utils/ov_tensor_utils.hpp>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ov::test;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
using BroadcastEltwiseParams = std::tuple<
|
||||
ElementType, // input precision
|
||||
InputShape, // input shape
|
||||
ov::Shape // target broadcast shape
|
||||
>;
|
||||
|
||||
class BroadcastEltwise : virtual public SubgraphBaseTest,
|
||||
public CPUTestsBase,
|
||||
public testing::WithParamInterface<BroadcastEltwiseParams> {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<BroadcastEltwiseParams>& obj) {
|
||||
ElementType input_precision;
|
||||
InputShape input_shape;
|
||||
ov::Shape target_shape;
|
||||
std::tie(input_precision, input_shape, target_shape) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "precision=" << input_precision << "IS=(" << CommonTestUtils::partialShape2str({input_shape.first}) << ")_TS=(";
|
||||
for (const auto& item : input_shape.second) {
|
||||
result << CommonTestUtils::vec2str(item) << "_";
|
||||
}
|
||||
result << ")_target_shape=" << CommonTestUtils::vec2str(target_shape);
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
ElementType input_precision;
|
||||
InputShape input_shape;
|
||||
std::tie(input_precision, input_shape, target_shape) = GetParam();
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
std::vector<InputShape> input_shapes{input_shape, {{}, {{target_shape.size()}}}};
|
||||
init_input_shapes(input_shapes);
|
||||
|
||||
ov::element::TypeVector input_precisions{input_precision, ov::element::i64};
|
||||
const auto params = ngraph::builder::makeDynamicParams(input_precisions, inputDynamicShapes);
|
||||
const auto bcast_data = ov::opset10::Constant::create(input_precision, {}, {1.f});
|
||||
const auto bcast = std::make_shared<ov::opset10::Broadcast>(bcast_data, params[1]);
|
||||
const auto add = std::make_shared<ov::opset10::Add>(params[0], bcast);
|
||||
function = std::make_shared<ov::Model>(add, params);
|
||||
}
|
||||
|
||||
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
|
||||
inputs.clear();
|
||||
const auto& funcInputs = function->inputs();
|
||||
auto data_tensor = ov::test::utils::create_and_fill_tensor(funcInputs[0].get_element_type(), targetInputStaticShapes[0]);
|
||||
inputs.insert({funcInputs[0].get_node_shared_ptr(), data_tensor});
|
||||
|
||||
auto shape_tensor = ov::Tensor{ov::element::i64, targetInputStaticShapes[1]};
|
||||
auto data = shape_tensor.data<ov::element_type_traits<ov::element::i64>::value_type>();
|
||||
for (size_t i = 0; i < target_shape.size(); i++) {
|
||||
data[i] = target_shape[i];
|
||||
}
|
||||
inputs.insert({funcInputs[1].get_node_shared_ptr(), shape_tensor});
|
||||
}
|
||||
|
||||
ov::Shape target_shape;
|
||||
};
|
||||
|
||||
// Runs the subgraph, then inspects the runtime model: the node feeding the result is expected to be
// reported with layer type "Broadcast".
TEST_P(BroadcastEltwise, smoke_CompareWithRefs) {
    run();

    const auto model = compiledModel.get_runtime_model();
    const auto last_node = model->get_result()->get_input_node_shared_ptr(0);
    const auto& rt_info = last_node->get_rt_info();

    // Fail with a readable message instead of dereferencing end() when the key is absent.
    const auto layer_type_it = rt_info.find("layerType");
    ASSERT_TRUE(layer_type_it != rt_info.end()) << "layerType is missing in the runtime info of the last node";

    const auto layerType = layer_type_it->second.as<std::string>();
    EXPECT_EQ(layerType, "Broadcast");
}
|
||||
|
||||
namespace {
|
||||
const std::vector<InputShape> input_shapes = {
|
||||
{{-1, -1, -1, -1}, {{1, 3, 16, 16}}},
|
||||
{{-1, -1}, {{16, 16}}},
|
||||
};
|
||||
|
||||
const std::vector<ov::Shape> target_shapes = {
|
||||
{1, 3, 16, 1},
|
||||
{16, 16},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise,
|
||||
BroadcastEltwise,
|
||||
::testing::Combine(::testing::Values(ov::element::f32),
|
||||
::testing::ValuesIn(input_shapes),
|
||||
::testing::ValuesIn(target_shapes)),
|
||||
BroadcastEltwise::getTestCaseName);
|
||||
} // namespace
|
||||
} // namespace SubgraphTestsDefinitions
|
Loading…
Reference in New Issue
Block a user