[Transforamtions] NonZero horizontal fusion: review leftovers (#16639)
* Review comments applied * codestyle * review comments applied
This commit is contained in:
parent
ca2265395d
commit
296c2d6603
@ -12,17 +12,17 @@
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API NonZeroFusion;
|
||||
class TRANSFORMATIONS_API NonZeroHorizontalFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief NonZeroFusion transformation makes horizontal fusion for equal NonZero layers
|
||||
* @brief NonZeroHorizontalFusion transformation makes horizontal fusion for equal NonZero layers
|
||||
*/
|
||||
class ov::pass::NonZeroFusion : public ov::pass::MatcherPass {
|
||||
class ov::pass::NonZeroHorizontalFusion : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("NonZeroFusion", "0");
|
||||
NonZeroFusion();
|
||||
OPENVINO_RTTI("NonZeroHorizontalFusion", "0");
|
||||
NonZeroHorizontalFusion();
|
||||
};
|
@ -39,7 +39,7 @@
|
||||
#include <transformations/common_optimizations/mul_fake_quantize_fusion.hpp>
|
||||
#include <transformations/common_optimizations/mvn_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nearest_neighbor_upsampling_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nonzero_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nonzero_horizontal_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||
#include <transformations/common_optimizations/normalize_l2_fusion.hpp>
|
||||
#include <transformations/common_optimizations/optimize_strided_slice.hpp>
|
||||
@ -203,7 +203,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
ADD_MATCHER(common_fusions, PReluFusion)
|
||||
ADD_MATCHER(common_fusions, DepthToSpaceFusion)
|
||||
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
|
||||
ADD_MATCHER(common_fusions, NonZeroFusion)
|
||||
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
|
||||
common_fusions->set_name("ov::pass::CommonFusions");
|
||||
|
||||
REGISTER_PASS(manager, BinarizeWeights)
|
||||
|
@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/nonzero_fusion.hpp"
|
||||
#include "transformations/common_optimizations/nonzero_horizontal_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
@ -12,8 +12,8 @@
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
ov::pass::NonZeroFusion::NonZeroFusion() {
|
||||
MATCHER_SCOPE(NonZeroFusion);
|
||||
ov::pass::NonZeroHorizontalFusion::NonZeroHorizontalFusion() {
|
||||
MATCHER_SCOPE(NonZeroHorizontalFusion);
|
||||
auto input_m = pass::pattern::any_input(ov::pass::pattern::consumers_more_than(1));
|
||||
auto nonzero_m = pass::pattern::wrap_type<ov::opset10::NonZero>({input_m});
|
||||
|
||||
@ -24,8 +24,9 @@ ov::pass::NonZeroFusion::NonZeroFusion() {
|
||||
|
||||
bool status = false;
|
||||
auto replace_if_nodes_match = [&](const ov::Input<ov::Node>& in) {
|
||||
auto cur_nonzero = ov::as_type_ptr<ov::opset10::NonZero>(in.get_node()->shared_from_this());
|
||||
if (cur_nonzero && cur_nonzero->get_output_type() == out_prc) {
|
||||
auto in_node = in.get_node();
|
||||
auto cur_nonzero = ov::as_type<ov::opset10::NonZero>(in_node);
|
||||
if (in_node != nonzero.get() && cur_nonzero && cur_nonzero->get_output_type() == out_prc) {
|
||||
status |= ov::replace_output_update_name(cur_nonzero->output(0), nonzero->output(0));
|
||||
}
|
||||
};
|
@ -7,7 +7,7 @@
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <string>
|
||||
#include <transformations/common_optimizations/nonzero_fusion.hpp>
|
||||
#include <transformations/common_optimizations/nonzero_horizontal_fusion.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -15,9 +15,9 @@ using namespace testing;
|
||||
|
||||
enum NonZeroType { I32, I64, NONE };
|
||||
|
||||
struct NonZeroFusionBuilder {
|
||||
NonZeroFusionBuilder() = default;
|
||||
NonZeroFusionBuilder(const std::vector<NonZeroType>& props) : branch_props(props) {}
|
||||
struct NonZeroHorizontalFusionBuilder {
|
||||
NonZeroHorizontalFusionBuilder() = default;
|
||||
NonZeroHorizontalFusionBuilder(const std::vector<NonZeroType>& props) : branch_props(props) {}
|
||||
|
||||
std::shared_ptr<ov::Model> getOriginal() {
|
||||
const auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape::dynamic(4));
|
||||
@ -71,14 +71,15 @@ struct NonZeroFusionBuilder {
|
||||
std::vector<NonZeroType> branch_props;
|
||||
};
|
||||
|
||||
class NonZeroFusionTests : public testing::WithParamInterface<std::vector<NonZeroType>>, public TransformationTestsF {
|
||||
class NonZeroHorizontalFusionTests : public testing::WithParamInterface<std::vector<NonZeroType>>,
|
||||
public TransformationTestsF {
|
||||
public:
|
||||
NonZeroFusionTests() : TransformationTestsF() {
|
||||
NonZeroHorizontalFusionTests() : TransformationTestsF() {
|
||||
comparator.enable(FunctionsComparator::CONSUMERS_COUNT);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<std::vector<NonZeroType>> obj) {
|
||||
const std::vector<NonZeroType> testValues = obj.param;
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<std::vector<NonZeroType>>& obj) {
|
||||
const std::vector<NonZeroType>& testValues = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "branch_props_{";
|
||||
for (const auto& value : testValues) {
|
||||
@ -101,20 +102,20 @@ public:
|
||||
protected:
|
||||
void SetUp() override {
|
||||
TransformationTestsF::SetUp();
|
||||
const auto branch_props = GetParam();
|
||||
builder = NonZeroFusionBuilder(branch_props);
|
||||
manager.register_pass<ov::pass::NonZeroFusion>();
|
||||
const auto& branch_props = GetParam();
|
||||
builder = NonZeroHorizontalFusionBuilder(branch_props);
|
||||
manager.register_pass<ov::pass::NonZeroHorizontalFusion>();
|
||||
}
|
||||
|
||||
NonZeroFusionBuilder builder;
|
||||
NonZeroHorizontalFusionBuilder builder;
|
||||
};
|
||||
|
||||
TEST_P(NonZeroFusionTests, NonZeroFusion) {
|
||||
TEST_P(NonZeroHorizontalFusionTests, NonZeroHorizontalFusion) {
|
||||
model = builder.getOriginal();
|
||||
model_ref = builder.getReference();
|
||||
}
|
||||
|
||||
namespace NonZeroFusionTestsInstantiation {
|
||||
namespace NonZeroHorizontalFusionTestsInstantiation {
|
||||
std::vector<std::vector<NonZeroType>> test_params{std::vector<NonZeroType>(5, I32),
|
||||
std::vector<NonZeroType>(5, I64),
|
||||
std::vector<NonZeroType>(2, NONE),
|
||||
@ -123,8 +124,8 @@ std::vector<std::vector<NonZeroType>> test_params{std::vector<NonZeroType>(5, I3
|
||||
{NONE, I64, NONE, I64, I32}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTestsF,
|
||||
NonZeroFusionTests,
|
||||
NonZeroHorizontalFusionTests,
|
||||
::testing::ValuesIn(test_params),
|
||||
NonZeroFusionTests::getTestCaseName);
|
||||
NonZeroHorizontalFusionTests::getTestCaseName);
|
||||
|
||||
} // namespace NonZeroFusionTestsInstantiation
|
||||
} // namespace NonZeroHorizontalFusionTestsInstantiation
|
Loading…
Reference in New Issue
Block a user