From 31efdfd00d67feed86a5ae40af458659ff7997a3 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Fri, 21 Apr 2023 09:35:04 +0200 Subject: [PATCH] [Transformations] BroadcastTransition transformation (#16861) --- .../broadcast_transition.hpp | 28 ++ .../broadcast_elementwise_fusion.cpp | 8 +- .../broadcast_transition.cpp | 87 +++++ .../broadcast_elementwise_fusion_test.cpp | 13 + .../broadcast_transition_test.cpp | 326 ++++++++++++++++++ src/core/tests/type_prop/broadcast.cpp | 15 + .../src/nodes/common/tile_broadcast_utils.cpp | 6 +- .../transformation_pipeline.cpp | 2 + .../subgraph_tests/src/broadcast_eltwise.cpp | 108 ++++++ 9 files changed, 588 insertions(+), 5 deletions(-) create mode 100644 src/common/transformations/include/transformations/common_optimizations/broadcast_transition.hpp create mode 100644 src/common/transformations/src/transformations/common_optimizations/broadcast_transition.cpp create mode 100644 src/common/transformations/tests/common_optimizations/broadcast_transition_test.cpp create mode 100644 src/plugins/intel_cpu/tests/functional/subgraph_tests/src/broadcast_eltwise.cpp diff --git a/src/common/transformations/include/transformations/common_optimizations/broadcast_transition.hpp b/src/common/transformations/include/transformations/common_optimizations/broadcast_transition.hpp new file mode 100644 index 00000000000..77f1641ab76 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/broadcast_transition.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API BroadcastTransition; + +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief BroadcastTransition transformation moves broadcast through binary eltwise operation + */ +class ov::pass::BroadcastTransition : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("BroadcastTransition", "0"); + BroadcastTransition(); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/broadcast_elementwise_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/broadcast_elementwise_fusion.cpp index 875f34a74c7..dbdcd38803f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/broadcast_elementwise_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/broadcast_elementwise_fusion.cpp @@ -54,13 +54,13 @@ bool can_eliminate_broadcast(const ngraph::Output& eltwise, // input_shape will be broadcast return false; } - } else if (input_shape[i_dim].is_dynamic() && broadcast_shape[i_dim].is_static() && - broadcast_shape[i_dim].get_length() != 1) { + } else if (input_shape[i_dim].is_dynamic() && broadcast_shape[b_dim].is_static() && + broadcast_shape[b_dim].get_length() != 1) { return false; - } else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_static() && + } else if (broadcast_shape[b_dim].is_dynamic() && input_shape[i_dim].is_static() && input_shape[i_dim].get_length() == 1) { return false; - } else if (broadcast_shape[i_dim].is_dynamic() && input_shape[i_dim].is_dynamic()) { + } else if (broadcast_shape[b_dim].is_dynamic() && input_shape[i_dim].is_dynamic()) { return false; } } diff --git a/src/common/transformations/src/transformations/common_optimizations/broadcast_transition.cpp b/src/common/transformations/src/transformations/common_optimizations/broadcast_transition.cpp new file mode 100644 index 00000000000..5e2a51dc350 --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/broadcast_transition.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/broadcast_transition.hpp" + +#include +#include +#include +#include +#include +#include + +#include "itt.hpp" +#include "transformations/utils/utils.hpp" + +ov::pass::BroadcastTransition::BroadcastTransition() { + MATCHER_SCOPE(BroadcastTransition); + auto bcast_m = pass::pattern::wrap_type(pass::pattern::consumers_count(1)); + auto eltwise_input_m = pass::pattern::any_input(pass::pattern::has_static_rank()); + auto eltwise_1 = pass::pattern::wrap_type({eltwise_input_m, bcast_m}); + auto eltwise_2 = pass::pattern::wrap_type({bcast_m, eltwise_input_m}); + auto eltwise_m = std::make_shared(OutputVector{eltwise_1, eltwise_2}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + const auto eltwise = ov::as_type_ptr(m.get_match_root()); + if (eltwise->get_autob().m_type != ov::op::AutoBroadcastType::NUMPY) { + return false; + } + + const auto bcast = ov::as_type_ptr(pattern_map.at(bcast_m).get_node_shared_ptr()); + const auto& bcast_type = bcast->get_broadcast_spec().m_type; + if (bcast_type != ov::op::BroadcastType::NUMPY && bcast_type != ov::op::BroadcastType::BIDIRECTIONAL) { + return false; + } + + const auto& eltwise_input = pattern_map.at(eltwise_input_m); + const auto& bcast_data = bcast->input_value(0); + // inputs order mustn't be changed because an eltwise might be not commutative + ov::OutputVector new_inputs{ + eltwise->get_input_node_ptr(0) == eltwise_input.get_node() ? eltwise_input : bcast_data, + eltwise->get_input_node_ptr(1) == bcast.get() ? bcast_data : eltwise_input}; + const auto new_eltwise = eltwise->clone_with_new_inputs(new_inputs); + ov::copy_runtime_info(eltwise, new_eltwise); + + auto target_shape = bcast->input_value(1); + const auto& target_shape_et = target_shape.get_element_type(); + + std::shared_ptr data_shape_path; + if (target_shape_et == ov::element::i32 || target_shape_et == ov::element::i64) { + data_shape_path = ov::op::util::make_try_fold(new_eltwise, target_shape_et); + ov::copy_runtime_info(eltwise, data_shape_path); + } else { + auto shapeof = ov::op::util::make_try_fold(new_eltwise); + data_shape_path = ov::op::util::make_try_fold(shapeof, target_shape_et); + ov::copy_runtime_info(eltwise, {shapeof, data_shape_path}); + } + + const size_t target_shape_rank = target_shape.get_partial_shape()[0].get_length(); + const size_t input_rank = new_eltwise->get_output_partial_shape(0).size(); + if (input_rank != target_shape_rank) { + auto align_rank = [&](const ov::Output& out, const size_t count) { + const auto constant = ov::opset10::Constant::create(target_shape_et, {count}, {1}); + const auto res = ov::op::util::make_try_fold(ov::OutputVector{constant, out}, 0); + ov::copy_runtime_info(out.get_node_shared_ptr(), {constant, res}); + return res; + }; + if (input_rank < target_shape_rank) { + data_shape_path = align_rank(data_shape_path, target_shape_rank - input_rank); + } else { + target_shape = align_rank(target_shape, input_rank - target_shape_rank); + } + } + const auto new_target_shape = ov::op::util::make_try_fold(data_shape_path, target_shape); + ov::copy_runtime_info(eltwise, new_target_shape); + + const auto new_bcast = std::make_shared(new_eltwise, new_target_shape); + new_bcast->set_friendly_name(eltwise->get_friendly_name()); + ov::copy_runtime_info(eltwise, {new_eltwise, new_bcast}); + ov::replace_node(eltwise, new_bcast); + return true; + }; + + auto m = std::make_shared(eltwise_m, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/tests/common_optimizations/broadcast_elementwise_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/broadcast_elementwise_fusion_test.cpp index ced067a80f4..7dc2c42a5ac 100644 --- a/src/common/transformations/tests/common_optimizations/broadcast_elementwise_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/broadcast_elementwise_fusion_test.cpp @@ -333,3 +333,16 @@ TEST_F(TransformationTestsF, BroadcastElementwiseFusionWithShapeOfNeg) { manager.register_pass(); } } + +TEST_F(TransformationTestsF, BroadcastElementwiseFusionDynShapesDifferentRanks) { + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, -1, -1, -1}); + auto target_shape = std::make_shared(ov::element::i32, ov::PartialShape{2}); + auto constant = ngraph::opset5::Constant::create(ov::element::f32, {}, {1.f}); + auto broadcast = std::make_shared(constant, target_shape); + auto elementwise = std::make_shared(input, broadcast); + function = std::make_shared(ov::NodeVector{elementwise}, ov::ParameterVector{input, target_shape}); + + manager.register_pass(); + } +} diff --git a/src/common/transformations/tests/common_optimizations/broadcast_transition_test.cpp b/src/common/transformations/tests/common_optimizations/broadcast_transition_test.cpp new file mode 100644 index 00000000000..f2b39c54331 --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/broadcast_transition_test.cpp @@ -0,0 +1,326 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +std::shared_ptr getOperation( + const ov::Output& in1, + const ov::Output& in2, + const std::string& operation_type, + const ov::op::AutoBroadcastType& eltwise_bcast_type = ov::op::AutoBroadcastType::NUMPY) { + if (operation_type == "Add") { + return std::make_shared(in1, in2, eltwise_bcast_type); + } else if (operation_type == "Multiply") { + return std::make_shared(in1, in2, eltwise_bcast_type); + } else if (operation_type == "Subtract") { + return std::make_shared(in1, in2, eltwise_bcast_type); + } else { + throw std::runtime_error("Unexpected operation type"); + } +} + +std::shared_ptr getOriginal( + const ov::element::Type& precision, + const ov::PartialShape& input_shape, + const ov::Shape& target_shape, + const ov::op::BroadcastType& bcast_mode, + const std::string& operation_type, + const size_t idx, + const ov::op::AutoBroadcastType& eltwise_bcast_type = ov::op::AutoBroadcastType::NUMPY) { + const auto input = std::make_shared(precision, input_shape); + const auto data_constant = ov::opset10::Constant::create(precision, {}, {1.f}); + const auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {target_shape.size()}, target_shape); + const auto bcast = std::make_shared(data_constant, target_shape_node, bcast_mode); + + const auto fst_in = idx == 0 ? bcast->output(0) : input->output(0); + const auto sec_in = idx == 1 ? bcast->output(0) : input->output(0); + const auto operation = getOperation(fst_in, sec_in, operation_type, eltwise_bcast_type); + return std::make_shared(operation, ov::ParameterVector{input}); +} + +std::shared_ptr getReference(const ov::element::Type& precision, + const ov::PartialShape& input_shape, + const ov::Shape& original_target_shape, + const std::string& operation_type, + const size_t idx) { + const auto input = std::make_shared(precision, input_shape); + const auto data_constant = ov::opset10::Constant::create(precision, {}, {1.f}); + + const auto fst_in = idx == 0 ? data_constant->output(0) : input->output(0); + const auto sec_in = idx == 1 ? data_constant->output(0) : input->output(0); + const auto operation = getOperation(fst_in, sec_in, operation_type, ov::op::AutoBroadcastType::NUMPY); + + const auto target_shape = [&]() { + auto new_shape = original_target_shape; + auto op_shape = operation->get_shape(); + while (new_shape.size() < op_shape.size()) + new_shape.insert(new_shape.begin(), 1); + while (op_shape.size() < new_shape.size()) + op_shape.insert(op_shape.begin(), 1); + + for (size_t i = 0; i < new_shape.size(); ++i) { + new_shape[i] = std::max(new_shape[i], op_shape[i]); + } + return new_shape; + }(); + + const auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {target_shape.size()}, target_shape); + const auto bcast = std::make_shared(operation, target_shape_node); + return std::make_shared(bcast, ov::ParameterVector{input}); +} + +using BroadcastTransitionParams = std::tuple; + +class StaticBroadcastTransitionTests : public testing::WithParamInterface, + public TransformationTestsF { +public: + StaticBroadcastTransitionTests() : TransformationTestsF() { + comparator.enable(FunctionsComparator::ATTRIBUTES); + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + ov::element::Type precision; + ov::Shape input_shape; + ov::Shape target_shape; + ov::op::BroadcastType bcast_mode; + std::string operation_type; + size_t idx; + std::tie(precision, input_shape, target_shape, bcast_mode, operation_type, idx) = obj.param; + + std::ostringstream result; + result << operation_type << "_prc=" << precision << "_IS=" << input_shape << "_TS=" << target_shape + << "_bcast_idx=" << idx << "_bcast_type=" << bcast_mode; + return result.str(); + } + +protected: + void SetUp() override { + TransformationTestsF::SetUp(); + ov::element::Type precision; + ov::Shape input_shape; + ov::Shape target_shape; + ov::op::BroadcastType bcast_mode; + std::string operation_type; + size_t idx; + std::tie(precision, input_shape, target_shape, bcast_mode, operation_type, idx) = GetParam(); + + manager.register_pass(); + model = getOriginal(precision, input_shape, target_shape, bcast_mode, operation_type, idx); + model_ref = getReference(precision, input_shape, target_shape, operation_type, idx); + } +}; + +TEST_P(StaticBroadcastTransitionTests, BroadcastTransition) {} + +namespace BroadcastTransitionTestsInstantiation { +std::vector input_shapes = { + {1, 3, 16, 16}, + {1, 3, 1, 16}, + {16, 16}, +}; + +std::vector target_shapes = { + {1, 3, 16, 1}, + {16, 16}, +}; + +std::vector bcast_modes = {ov::op::BroadcastType::NUMPY, ov::op::BroadcastType::BIDIRECTIONAL}; + +std::vector operation_types = {"Add", "Multiply", "Subtract"}; +std::vector bcast_input_idx = {0, 1}; + +INSTANTIATE_TEST_SUITE_P(TransformationTestsF, + StaticBroadcastTransitionTests, + ::testing::Combine(::testing::Values(ov::element::f32), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(target_shapes), + ::testing::ValuesIn(bcast_modes), + ::testing::ValuesIn(operation_types), + ::testing::ValuesIn(bcast_input_idx)), + StaticBroadcastTransitionTests::getTestCaseName); +} // namespace BroadcastTransitionTestsInstantiation + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_U32TargetShapePrecision) { + const auto data_precision = ov::element::f32; + const auto shape_precision = ov::element::u32; + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto bcast = std::make_shared(data_constant, target_shape); + const auto operation = getOperation(input, bcast, "Add"); + model = std::make_shared(operation, ov::ParameterVector{input, target_shape}); + } + manager.register_pass(); + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto operation = getOperation(input, data_constant, "Add"); + const auto shapeof = std::make_shared(operation); + const auto convert = std::make_shared(shapeof, shape_precision); + const auto max = std::make_shared(convert, target_shape); + const auto bcast = std::make_shared(operation, max); + model_ref = std::make_shared(bcast, ov::ParameterVector{input, target_shape}); + } +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_EqualRanks) { + const auto data_precision = ov::element::f32; + const auto shape_precision = ov::element::i32; + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto bcast = std::make_shared(data_constant, target_shape); + const auto operation = getOperation(input, bcast, "Add"); + model = std::make_shared(operation, ov::ParameterVector{input, target_shape}); + } + manager.register_pass(); + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto operation = getOperation(input, data_constant, "Add"); + const auto shapeof = std::make_shared(operation, shape_precision); + const auto max = std::make_shared(shapeof, target_shape); + const auto bcast = std::make_shared(operation, max); + model_ref = std::make_shared(bcast, ov::ParameterVector{input, target_shape}); + } +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_DataRankLessThanTarget) { + const auto data_precision = ov::element::f32; + const auto shape_precision = ov::element::i32; + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(2)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto bcast = std::make_shared(data_constant, target_shape); + const auto operation = getOperation(input, bcast, "Add"); + model = std::make_shared(operation, ov::ParameterVector{input, target_shape}); + } + manager.register_pass(); + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(2)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{4}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto operation = getOperation(input, data_constant, "Add"); + const auto shapeof = std::make_shared(operation, shape_precision); + const auto constant = ov::opset10::Constant::create(shape_precision, {2}, {1}); + const auto concat = std::make_shared(ov::OutputVector{constant, shapeof}, 0); + const auto max = std::make_shared(concat, target_shape); + const auto bcast = std::make_shared(operation, max); + model_ref = std::make_shared(bcast, ov::ParameterVector{input, target_shape}); + } +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Dynamic_DataRankGreaterThanTarget) { + const auto data_precision = ov::element::f32; + const auto shape_precision = ov::element::i32; + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{2}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto bcast = std::make_shared(data_constant, target_shape); + const auto operation = getOperation(input, bcast, "Add"); + model = std::make_shared(operation, ov::ParameterVector{input, target_shape}); + } + manager.register_pass(); + { + const auto input = std::make_shared(data_precision, ov::PartialShape::dynamic(4)); + const auto target_shape = std::make_shared(shape_precision, ov::PartialShape{2}); + + const auto data_constant = ov::opset10::Constant::create(data_precision, {}, {1.f}); + const auto operation = getOperation(input, data_constant, "Add"); + const auto shapeof = std::make_shared(operation, shape_precision); + const auto constant = ov::opset10::Constant::create(shape_precision, {2}, {1}); + const auto concat = std::make_shared(ov::OutputVector{constant, target_shape}, 0); + const auto max = std::make_shared(shapeof, concat); + const auto bcast = std::make_shared(operation, max); + model_ref = std::make_shared(bcast, ov::ParameterVector{input, target_shape}); + } +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_ExplicitEltwiseBcast) { + model = getOriginal(ov::element::f32, + ov::PartialShape{1, 3, 16, 16}, + ov::Shape{1, 3, 16, 16}, + ov::op::BroadcastType::NUMPY, + "Add", + 0, + ov::op::AutoBroadcastType::EXPLICIT); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_PDPDEltwiseBcast) { + model = getOriginal(ov::element::f32, + ov::PartialShape{1, 3, 16, 16}, + ov::Shape{1, 3, 16, 16}, + ov::op::BroadcastType::NUMPY, + "Add", + 0, + ov::op::AutoBroadcastType::PDPD); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_PDPDBcastType) { + const auto input = std::make_shared(ov::element::f32, ov::PartialShape{1, 3, 16, 16}); + + const auto data_constant = ov::opset10::Constant::create(ov::element::f32, {1, 1, 1}, {1.f}); + const auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {3}, {1, 16, 16}); + const ov::op::BroadcastModeSpec pdpd_spec(ov::op::BroadcastType::PDPD); + const auto bcast = std::make_shared(data_constant, target_shape_node, pdpd_spec); + const auto add = std::make_shared(input, bcast); + + model = std::make_shared(add, ov::ParameterVector{input}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_WithAxesMapping) { + const auto input = std::make_shared(ov::element::f32, ov::PartialShape{1, 3, 16, 16}); + const auto data_constant = ov::opset10::Constant::create(ov::element::f32, {16, 16}, {1.f}); + + const auto target_shape_node = ov::opset10::Constant::create(ov::element::i32, {3}, {1, 16, 16}); + const auto axes_node = ov::opset10::Constant::create(ov::element::i32, {2}, {1, 2}); + const auto bcast = std::make_shared(data_constant, target_shape_node, axes_node); + const auto add = std::make_shared(input, bcast); + + model = std::make_shared(add, ov::ParameterVector{input}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, BroadcastTransitionTests_Negative_DynamicRank) { + const auto input = std::make_shared(ov::element::f32, ov::PartialShape::dynamic()); + const auto data_constant = ov::opset10::Constant::create(ov::element::f32, {}, {1.f}); + + const auto target_shape_input = std::make_shared(ov::element::i32, ov::PartialShape{-1}); + const auto bcast = std::make_shared(data_constant, target_shape_input); + const auto add = std::make_shared(input, bcast); + + model = std::make_shared(add, ov::ParameterVector{input, target_shape_input}); + manager.register_pass(); +} diff --git a/src/core/tests/type_prop/broadcast.cpp b/src/core/tests/type_prop/broadcast.cpp index c218ebdc8e3..5b26ef60a41 100644 --- a/src/core/tests/type_prop/broadcast.cpp +++ b/src/core/tests/type_prop/broadcast.cpp @@ -181,6 +181,20 @@ TYPED_TEST_P(BroadcastTests, broadcast_axes_wrong_rank) { } } +TYPED_TEST_P(BroadcastTests, broadcast_target_shape_wrong_rank) { + auto arg = make_shared(element::f32, Shape{2, 4}); + auto bc_shape = make_shared(element::i64, Shape{}); + + try { + auto bc = make_shared(arg, bc_shape); + FAIL() << "Broadcast: axes target shape rank not detected"; + } catch (const NodeValidationFailure& error) { + EXPECT_HAS_SUBSTRING(error.what(), "Broadcast shape rank must be 1, but has"); + } catch (...) { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + TYPED_TEST_P(BroadcastTests, broadcast_fully_dynamic_target_shape) { auto arg = make_shared(element::f32, Shape{2, 4}); auto bc_shape = make_shared(element::i64, PartialShape::dynamic()); @@ -559,6 +573,7 @@ REGISTER_TYPED_TEST_SUITE_P(BroadcastTests, broadcast_fail_axes_map, broadcast_fail_axes_map_shape, broadcast_axes_wrong_rank, + broadcast_target_shape_wrong_rank, broadcast_fully_dynamic_target_shape, broadcast_dynamic_values_of_target_shape, broadcast_broadcast_shape_et_wrong, diff --git a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp index 74bb34309e5..ef5874567e8 100644 --- a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp @@ -4,6 +4,7 @@ #include "tile_broadcast_utils.h" +#include "cpu_convert.h" #include "cpu_memcpy.h" #include "ie_parallel.hpp" #include @@ -250,7 +251,10 @@ void TileBroadcastCommon::optimizedExecute(const MemoryPtr& srcMemory, const Mem auto srcData = reinterpret_cast(srcMemory->GetPtr()); auto dstData = reinterpret_cast(dstMemory->GetPtr()); - if (optimizedParams.srcStrides[5] == 0) { + if (srcMemory->getStaticDims() == dstMemory->getStaticDims()) { + const auto prc = dstMemory->getDesc().getPrecision(); + cpu_convert(srcData, dstData, prc, prc, optimizedParams.copySize / prc.size()); + } else if (optimizedParams.srcStrides[5] == 0) { if (optimizedParams.dstStrides[0] == optimizedParams.dims[5] * optimizedParams.dstStrides[5]) { size_t data_size = optimizedParams.dstStrides[5]; size_t elt_cnt = optimizedParams.dims[5]; diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 16146b31744..3f976007f00 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -18,6 +18,7 @@ // Common transformations #include "transformations/common_optimizations/add_fake_quantize_fusion.hpp" +#include "transformations/common_optimizations/broadcast_transition.hpp" #include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp" #include "transformations/common_optimizations/convert_quantize_dequantize.hpp" #include "transformations/common_optimizations/fq_mul_fusion.hpp" @@ -225,6 +226,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static(), fuse_type_to_convert}}; CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion); + CPU_REGISTER_PASS_COMMON(manager, ov::pass::BroadcastTransition); CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations); CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes); CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking); diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/broadcast_eltwise.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/broadcast_eltwise.cpp new file mode 100644 index 00000000000..56644e91526 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/broadcast_eltwise.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ngraph_functions/builders.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "test_utils/cpu_test_utils.hpp" + +using namespace ngraph; +using namespace ov::test; +using namespace CPUTestUtils; +using namespace InferenceEngine; + +namespace SubgraphTestsDefinitions { +using BroadcastEltwiseParams = std::tuple< + ElementType, // input precision + InputShape, // input shape + ov::Shape // target broadcast shape +>; + +class BroadcastEltwise : virtual public SubgraphBaseTest, + public CPUTestsBase, + public testing::WithParamInterface { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + ElementType input_precision; + InputShape input_shape; + ov::Shape target_shape; + std::tie(input_precision, input_shape, target_shape) = obj.param; + + std::ostringstream result; + result << "precision=" << input_precision << "IS=(" << CommonTestUtils::partialShape2str({input_shape.first}) << ")_TS=("; + for (const auto& item : input_shape.second) { + result << CommonTestUtils::vec2str(item) << "_"; + } + result << ")_target_shape=" << CommonTestUtils::vec2str(target_shape); + return result.str(); + } + +protected: + void SetUp() override { + ElementType input_precision; + InputShape input_shape; + std::tie(input_precision, input_shape, target_shape) = GetParam(); + targetDevice = CommonTestUtils::DEVICE_CPU; + + std::vector input_shapes{input_shape, {{}, {{target_shape.size()}}}}; + init_input_shapes(input_shapes); + + ov::element::TypeVector input_precisions{input_precision, ov::element::i64}; + const auto params = ngraph::builder::makeDynamicParams(input_precisions, inputDynamicShapes); + const auto bcast_data = ov::opset10::Constant::create(input_precision, {}, {1.f}); + const auto bcast = std::make_shared(bcast_data, params[1]); + const auto add = std::make_shared(params[0], bcast); + function = std::make_shared(add, params); + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + inputs.clear(); + const auto& funcInputs = function->inputs(); + auto data_tensor = ov::test::utils::create_and_fill_tensor(funcInputs[0].get_element_type(), targetInputStaticShapes[0]); + inputs.insert({funcInputs[0].get_node_shared_ptr(), data_tensor}); + + auto shape_tensor = ov::Tensor{ov::element::i64, targetInputStaticShapes[1]}; + auto data = shape_tensor.data::value_type>(); + for (size_t i = 0; i < target_shape.size(); i++) { + data[i] = target_shape[i]; + } + inputs.insert({funcInputs[1].get_node_shared_ptr(), shape_tensor}); + } + + ov::Shape target_shape; +}; + +TEST_P(BroadcastEltwise, smoke_CompareWithRefs) { + run(); + + const auto model = compiledModel.get_runtime_model(); + const auto last_node = model->get_result()->get_input_node_shared_ptr(0); + const auto& rt_info = last_node->get_rt_info(); + const auto layerType = rt_info.find("layerType")->second.as(); + EXPECT_EQ(layerType, "Broadcast"); +} + +namespace { +const std::vector input_shapes = { + {{-1, -1, -1, -1}, {{1, 3, 16, 16}}}, + {{-1, -1}, {{16, 16}}}, +}; + +const std::vector target_shapes = { + {1, 3, 16, 1}, + {16, 16}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise, + BroadcastEltwise, + ::testing::Combine(::testing::Values(ov::element::f32), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(target_shapes)), + BroadcastEltwise::getTestCaseName); +} // namespace +} // namespace SubgraphTestsDefinitions \ No newline at end of file