Move BroadcastTransition to MOC (#19543)
* Move BroadcastTransition to MOC Broadcast that could be eliminated by BroadcastElementwiseFusion are moved down the graph (by BroadcastTransition registered in the plugins). That prevents BroadcastElementwiseFusion to eliminate them. Ticket: CVS-118899 * dont count const layers * remove virtual inheritance
This commit is contained in:
parent
e2b553302b
commit
a55b5381d3
@ -15,6 +15,7 @@
|
||||
#include "transformations/common_optimizations/batch_to_space_fusion.hpp"
|
||||
#include "transformations/common_optimizations/binarize_weights.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_transition.hpp"
|
||||
#include "transformations/common_optimizations/clamp_fusion.hpp"
|
||||
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
|
||||
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
|
||||
@ -155,6 +156,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
|
||||
REGISTER_PASS(manager, ConvertNmsGatherPathToUnsigned)
|
||||
REGISTER_PASS(manager, StridedSliceOptimization, m_use_shapes)
|
||||
REGISTER_PASS(manager, BroadcastElementwiseFusion)
|
||||
REGISTER_PASS(manager, BroadcastTransition)
|
||||
REGISTER_PASS(manager, PullThroughReduce)
|
||||
|
||||
// GRUCellFusion and SequenceFusion should be before NopElimination
|
||||
|
@ -19,7 +19,6 @@
|
||||
// Common transformations
|
||||
#include "transformations/common_optimizations/mark_precision_sensitive_shapeof_subgraphs.hpp"
|
||||
#include "transformations/common_optimizations/add_fake_quantize_fusion.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_transition.hpp"
|
||||
#include "transformations/fp16_compression/convert_compression_only_to_legacy.hpp"
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
|
||||
@ -258,7 +257,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
||||
type_to_fuse_map type_to_fuse = {{ov::opset10::Convert::get_type_info_static(), fuse_type_to_convert}};
|
||||
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::BroadcastTransition);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking);
|
||||
|
@ -0,0 +1,14 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "subgraph_tests/broadcast_eltwise_eliminated.hpp"
|
||||
|
||||
using namespace ov::test;
|
||||
|
||||
namespace {
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise, BroadcastEltwiseEliminated,
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU),
|
||||
BroadcastEltwiseEliminated::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -54,7 +54,6 @@
|
||||
#include "transformations/common_optimizations/wrap_interpolate_into_transposes.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/softmax_fusion.hpp"
|
||||
#include "transformations/common_optimizations/broadcast_transition.hpp"
|
||||
#include "transformations/common_optimizations/mvn_fusion.hpp"
|
||||
|
||||
#include "transformations/op_conversions/convert_depth_to_space.hpp"
|
||||
@ -213,7 +212,6 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
||||
manager.register_pass<ov::pass::MVNFusion>();
|
||||
// decompose MVNs that sre not supported in GPU, so that they will be marked as precision sensitive in ConvertPrecision
|
||||
manager.register_pass<ov::pass::MVN6Decomposition>();
|
||||
manager.register_pass<ov::pass::BroadcastTransition>();
|
||||
|
||||
const bool keep_precision_sensitive_in_fp32_1 = true;
|
||||
manager.register_pass<ov::pass::ConvertPrecision>(fp_convert_precision_map,
|
||||
|
@ -0,0 +1,14 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "subgraph_tests/broadcast_eltwise_eliminated.hpp"
|
||||
|
||||
using namespace ov::test;
|
||||
|
||||
namespace {
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BroadcastEltwise, BroadcastEltwiseEliminated,
|
||||
::testing::Values(ov::test::utils::DEVICE_GPU),
|
||||
BroadcastEltwiseEliminated::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/subgraph/broadcast_eltwise_eliminated.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
|
||||
TEST_P(BroadcastEltwiseEliminated, CompareWithRefs){
|
||||
run();
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
|
||||
class BroadcastEltwiseEliminated : public testing::WithParamInterface<const char*>,
|
||||
public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<const char*> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
void TearDown() override;
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -244,12 +244,13 @@ void SubgraphBaseTest::generate_inputs(const std::vector<ov::Shape>& targetInput
|
||||
ASSERT_NE(it, inputMap.end());
|
||||
for (size_t port = 0; port < nodePtr->get_input_size(); ++port) {
|
||||
if (nodePtr->get_input_node_ptr(port)->shared_from_this() == inputNode->shared_from_this()) {
|
||||
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape++)});
|
||||
inputs.insert({param, it->second(nodePtr, port, param->get_element_type(), *itTargetShape)});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
itTargetShape++;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/subgraph/broadcast_eltwise_eliminated.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
|
||||
std::string BroadcastEltwiseEliminated::getTestCaseName(const testing::TestParamInfo<const char*> &obj) {
|
||||
return "device=" + std::string(obj.param);
|
||||
}
|
||||
|
||||
void BroadcastEltwiseEliminated::SetUp() {
|
||||
targetDevice = GetParam();
|
||||
|
||||
ov::PartialShape shape{-1, 3, 10, 10};
|
||||
|
||||
InputShape input_shape = {shape, {Shape{1, 3, 10, 10}}};
|
||||
init_input_shapes({input_shape});
|
||||
|
||||
const auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, shape);
|
||||
const auto shapeof = std::make_shared<ov::op::v3::ShapeOf>(param);
|
||||
const auto constant = op::v0::Constant::create(element::f32, Shape{1}, {9});
|
||||
const auto bcast = std::make_shared<ov::op::v3::Broadcast>(constant, shapeof);
|
||||
const auto mul = std::make_shared<ov::op::v1::Multiply>(param, bcast);
|
||||
function = std::make_shared<ov::Model>(mul, ov::ParameterVector{param});
|
||||
}
|
||||
|
||||
void BroadcastEltwiseEliminated::TearDown() {
|
||||
const auto model = compiledModel.get_runtime_model();
|
||||
|
||||
int num_ops = 0;
|
||||
for (const auto& node : model->get_ordered_ops()) {
|
||||
const auto& rt_info = node->get_rt_info();
|
||||
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
|
||||
if (layer_type != "Reorder" && layer_type != "Const")
|
||||
num_ops++;
|
||||
EXPECT_NE(layer_type, "Broadcast");
|
||||
}
|
||||
ASSERT_EQ(num_ops, 3); // one Input, one Eltwise and one Output
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace ov
|
Loading…
Reference in New Issue
Block a user