Move BroadcastTransition to MOC (#19543)

* Move BroadcastTransition to MOC

Broadcasts that could be eliminated by BroadcastElementwiseFusion are moved down the graph
(by BroadcastTransition, registered in the plugins). That prevents BroadcastElementwiseFusion
from eliminating them.

Ticket: CVS-118899

* don't count const layers

* remove virtual inheritance
This commit is contained in:
Mateusz Tabaka 2023-09-08 09:05:54 +02:00 committed by GitHub
parent e2b553302b
commit a55b5381d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 117 additions and 5 deletions

View File

@ -15,6 +15,7 @@
#include "transformations/common_optimizations/batch_to_space_fusion.hpp"
#include "transformations/common_optimizations/binarize_weights.hpp"
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
#include "transformations/common_optimizations/broadcast_transition.hpp"
#include "transformations/common_optimizations/clamp_fusion.hpp"
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
@ -155,6 +156,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, ConvertNmsGatherPathToUnsigned)
REGISTER_PASS(manager, StridedSliceOptimization, m_use_shapes)
REGISTER_PASS(manager, BroadcastElementwiseFusion)
REGISTER_PASS(manager, BroadcastTransition)
REGISTER_PASS(manager, PullThroughReduce)
// GRUCellFusion and SequenceFusion should be before NopElimination

View File

@ -19,7 +19,6 @@
// Common transformations
#include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp"
#include "transformations/common_optimizations/add_fake_quantize_fusion.hpp"
#include "transformations/common_optimizations/broadcast_transition.hpp"
#include "transformations/fp16_compression/convert_compression_only_to_legacy.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
@ -258,7 +257,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static(), fuse_type_to_convert}};
CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::BroadcastTransition);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking);

View File

@ -0,0 +1,14 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "subgraph_tests/broadcast_eltwise_eliminated.hpp"
using namespace ov::test;
namespace {
// Instantiates the shared BroadcastEltwiseEliminated subgraph suite for the CPU
// device. The fixture builds a Multiply(Parameter, Broadcast) model and checks
// that no "Broadcast" layer survives in the compiled runtime model.
INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise, BroadcastEltwiseEliminated,
::testing::Values(ov::test::utils::DEVICE_CPU),
BroadcastEltwiseEliminated::getTestCaseName);
} // namespace

View File

@ -54,7 +54,6 @@
#include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/softmax_fusion.hpp"
#include "transformations/common_optimizations/broadcast_transition.hpp"
#include "transformations/common_optimizations/mvn_fusion.hpp"
#include "transformations/op_conversions/convert_depth_to_space.hpp"
@ -213,7 +212,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
manager.register_pass<ov::pass::MVNFusion>();
// decompose MVNs that are not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
manager.register_pass<ov::pass::MVN6Decomposition>();
manager.register_pass<ov::pass::BroadcastTransition>();
const bool keep_precision_sensitive_in_fp32_1 = true;
manager.register_pass<ov::pass::ConvertPrecision>(fp_convert_precision_map,

View File

@ -0,0 +1,14 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "subgraph_tests/broadcast_eltwise_eliminated.hpp"
using namespace ov::test;
namespace {
// Instantiates the shared BroadcastEltwiseEliminated subgraph suite for the GPU
// device. The fixture builds a Multiply(Parameter, Broadcast) model and checks
// that no "Broadcast" layer survives in the compiled runtime model.
INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise, BroadcastEltwiseEliminated,
::testing::Values(ov::test::utils::DEVICE_GPU),
BroadcastEltwiseEliminated::getTestCaseName);
} // namespace

View File

@ -0,0 +1,17 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/subgraph/broadcast_eltwise_eliminated.hpp"
namespace ov {
namespace test {
// Runs the shared BroadcastEltwiseEliminated subgraph flow. The post-run
// verification (no "Broadcast" layer left in the runtime model) is performed
// by the fixture, so the test body only needs to invoke run().
// Note: no trailing semicolon — TEST_P expands to a complete class/method
// definition, and a stray ';' triggers -Wextra-semi at namespace scope.
TEST_P(BroadcastEltwiseEliminated, CompareWithRefs) {
    run();
}
} // namespace test
} // namespace ov

View File

@ -0,0 +1,23 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
// Subgraph test parameterized by a device name (const char*). The fixture
// builds a Multiply(Parameter, Broadcast(Constant, ShapeOf)) model and, after
// compilation, verifies in TearDown() that the Broadcast was eliminated from
// the runtime model and only the expected ops remain.
class BroadcastEltwiseEliminated : public testing::WithParamInterface<const char*>,
public ov::test::SubgraphBaseTest {
public:
// Builds a readable test-name suffix from the device parameter, e.g. "device=CPU".
static std::string getTestCaseName(const testing::TestParamInfo<const char*> &obj);
protected:
// Sets targetDevice from the suite parameter and constructs the model.
void SetUp() override;
// Inspects the compiled runtime model after the test body has run.
void TearDown() override;
};
} // namespace test
} // namespace ov

View File

@ -244,12 +244,13 @@ void SubgraphBaseTest::generate_inputs(const std::vector<ov::Shape>& targetInput
ASSERT_NE(it, inputMap.end());
for (size_t port = 0; port < nodePtr->get_input_size(); ++port) {
if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) {
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape++)});
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape)});
break;
}
}
}
}
itTargetShape++;
}
}

View File

@ -0,0 +1,45 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/subgraph/broadcast_eltwise_eliminated.hpp"
namespace ov {
namespace test {
// Builds the test-name suffix from the device parameter, e.g. "device=CPU".
std::string BroadcastEltwiseEliminated::getTestCaseName(const testing::TestParamInfo<const char*> &obj) {
    std::string case_name{"device="};
    case_name += obj.param;
    return case_name;
}
// Constructs Multiply(Parameter, Broadcast(Constant, ShapeOf(Parameter))):
// the Broadcast target shape is data-dependent (via ShapeOf), which is the
// pattern whose elimination this test verifies in TearDown().
void BroadcastEltwiseEliminated::SetUp() {
// Device comes from the suite parameter (e.g. DEVICE_CPU / DEVICE_GPU).
targetDevice = GetParam();
// Dynamic batch in the model shape; a single static shape is used at runtime.
ov::PartialShape shape{-1, 3, 10, 10};
InputShape input_shape = {shape, {Shape{1, 3, 10, 10}}};
init_input_shapes({input_shape});
const auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, shape);
const auto shapeof = std::make_shared<ov::op::v3::ShapeOf>(param);
// Scalar-like constant {9} broadcast to the runtime shape of the input.
const auto constant = op::v0::Constant::create(element::f32, Shape{1}, {9});
const auto bcast = std::make_shared<ov::op::v3::Broadcast>(constant, shapeof);
const auto mul = std::make_shared<ov::op::v1::Multiply>(param, bcast);
function = std::make_shared<ov::Model>(mul, ov::ParameterVector{param});
}
void BroadcastEltwiseEliminated::TearDown() {
const auto model = compiledModel.get_runtime_model();
int num_ops = 0;
for (const auto& node : model->get_ordered_ops()) {
const auto& rt_info = node->get_rt_info();
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
if (layer_type != "Reorder" && layer_type != "Const")
num_ops++;
EXPECT_NE(layer_type, "Broadcast");
}
ASSERT_EQ(num_ops, 3); // one Input, one Eltwise and one Output
}
} // namespace test
} // namespace ov