[Transforamtions] NonZero horizontal fusion: review leftovers (#16639)

* Review comments applied

* codestyle

* review comments applied
This commit is contained in:
Vladislav Golubev 2023-04-11 13:42:43 +02:00 committed by GitHub
parent ca2265395d
commit 296c2d6603
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 29 deletions

View File

@ -12,17 +12,17 @@
namespace ov {
namespace pass {
class TRANSFORMATIONS_API NonZeroFusion;
class TRANSFORMATIONS_API NonZeroHorizontalFusion;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief NonZeroFusion transformation makes horizontal fusion for equal NonZero layers
* @brief NonZeroHorizontalFusion transformation makes horizontal fusion for equal NonZero layers
*/
class ov::pass::NonZeroFusion : public ov::pass::MatcherPass {
class ov::pass::NonZeroHorizontalFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("NonZeroFusion", "0");
NonZeroFusion();
OPENVINO_RTTI("NonZeroHorizontalFusion", "0");
NonZeroHorizontalFusion();
};

View File

@ -39,7 +39,7 @@
#include <transformations/common_optimizations/mul_fake_quantize_fusion.hpp>
#include <transformations/common_optimizations/mvn_fusion.hpp>
#include <transformations/common_optimizations/nearest_neighbor_upsampling_fusion.hpp>
#include <transformations/common_optimizations/nonzero_fusion.hpp>
#include <transformations/common_optimizations/nonzero_horizontal_fusion.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
#include <transformations/common_optimizations/optimize_strided_slice.hpp>
@ -203,7 +203,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
ADD_MATCHER(common_fusions, PReluFusion)
ADD_MATCHER(common_fusions, DepthToSpaceFusion)
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
ADD_MATCHER(common_fusions, NonZeroFusion)
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
common_fusions->set_name("ov::pass::CommonFusions");
REGISTER_PASS(manager, BinarizeWeights)

View File

@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/nonzero_fusion.hpp"
#include "transformations/common_optimizations/nonzero_horizontal_fusion.hpp"
#include <memory>
#include <openvino/opsets/opset10.hpp>
@ -12,8 +12,8 @@
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
ov::pass::NonZeroFusion::NonZeroFusion() {
MATCHER_SCOPE(NonZeroFusion);
ov::pass::NonZeroHorizontalFusion::NonZeroHorizontalFusion() {
MATCHER_SCOPE(NonZeroHorizontalFusion);
auto input_m = pass::pattern::any_input(ov::pass::pattern::consumers_more_than(1));
auto nonzero_m = pass::pattern::wrap_type<ov::opset10::NonZero>({input_m});
@ -24,8 +24,9 @@ ov::pass::NonZeroFusion::NonZeroFusion() {
bool status = false;
auto replace_if_nodes_match = [&](const ov::Input<ov::Node>& in) {
auto cur_nonzero = ov::as_type_ptr<ov::opset10::NonZero>(in.get_node()->shared_from_this());
if (cur_nonzero && cur_nonzero->get_output_type() == out_prc) {
auto in_node = in.get_node();
auto cur_nonzero = ov::as_type<ov::opset10::NonZero>(in_node);
if (in_node != nonzero.get() && cur_nonzero && cur_nonzero->get_output_type() == out_prc) {
status |= ov::replace_output_update_name(cur_nonzero->output(0), nonzero->output(0));
}
};

View File

@ -7,7 +7,7 @@
#include <memory>
#include <openvino/opsets/opset10.hpp>
#include <string>
#include <transformations/common_optimizations/nonzero_fusion.hpp>
#include <transformations/common_optimizations/nonzero_horizontal_fusion.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -15,9 +15,9 @@ using namespace testing;
enum NonZeroType { I32, I64, NONE };
struct NonZeroFusionBuilder {
NonZeroFusionBuilder() = default;
NonZeroFusionBuilder(const std::vector<NonZeroType>& props) : branch_props(props) {}
struct NonZeroHorizontalFusionBuilder {
NonZeroHorizontalFusionBuilder() = default;
NonZeroHorizontalFusionBuilder(const std::vector<NonZeroType>& props) : branch_props(props) {}
std::shared_ptr<ov::Model> getOriginal() {
const auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape::dynamic(4));
@ -71,14 +71,15 @@ struct NonZeroFusionBuilder {
std::vector<NonZeroType> branch_props;
};
class NonZeroFusionTests : public testing::WithParamInterface<std::vector<NonZeroType>>, public TransformationTestsF {
class NonZeroHorizontalFusionTests : public testing::WithParamInterface<std::vector<NonZeroType>>,
public TransformationTestsF {
public:
NonZeroFusionTests() : TransformationTestsF() {
NonZeroHorizontalFusionTests() : TransformationTestsF() {
comparator.enable(FunctionsComparator::CONSUMERS_COUNT);
}
static std::string getTestCaseName(testing::TestParamInfo<std::vector<NonZeroType>> obj) {
const std::vector<NonZeroType> testValues = obj.param;
static std::string getTestCaseName(const testing::TestParamInfo<std::vector<NonZeroType>>& obj) {
const std::vector<NonZeroType>& testValues = obj.param;
std::ostringstream result;
result << "branch_props_{";
for (const auto& value : testValues) {
@ -101,20 +102,20 @@ public:
protected:
void SetUp() override {
TransformationTestsF::SetUp();
const auto branch_props = GetParam();
builder = NonZeroFusionBuilder(branch_props);
manager.register_pass<ov::pass::NonZeroFusion>();
const auto& branch_props = GetParam();
builder = NonZeroHorizontalFusionBuilder(branch_props);
manager.register_pass<ov::pass::NonZeroHorizontalFusion>();
}
NonZeroFusionBuilder builder;
NonZeroHorizontalFusionBuilder builder;
};
TEST_P(NonZeroFusionTests, NonZeroFusion) {
TEST_P(NonZeroHorizontalFusionTests, NonZeroHorizontalFusion) {
model = builder.getOriginal();
model_ref = builder.getReference();
}
namespace NonZeroFusionTestsInstantiation {
namespace NonZeroHorizontalFusionTestsInstantiation {
std::vector<std::vector<NonZeroType>> test_params{std::vector<NonZeroType>(5, I32),
std::vector<NonZeroType>(5, I64),
std::vector<NonZeroType>(2, NONE),
@ -123,8 +124,8 @@ std::vector<std::vector<NonZeroType>> test_params{std::vector<NonZeroType>(5, I3
{NONE, I64, NONE, I64, I32}};
INSTANTIATE_TEST_SUITE_P(TransformationTestsF,
NonZeroFusionTests,
NonZeroHorizontalFusionTests,
::testing::ValuesIn(test_params),
NonZeroFusionTests::getTestCaseName);
NonZeroHorizontalFusionTests::getTestCaseName);
} // namespace NonZeroFusionTestsInstantiation
} // namespace NonZeroHorizontalFusionTestsInstantiation