ScatterElementsUpdate downgrade transformation (#18306)

This commit is contained in:
Tomasz Dołbniak 2023-06-30 13:18:24 +02:00 committed by GitHub
parent f405ee2b9d
commit 60d5d57ece
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 143 additions and 1 deletions

View File

@ -0,0 +1,23 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ov {
namespace pass {
/**
* @ingroup ie_transformation_common_api
* @brief Converts Pad v12 to Pad v1
*/
class TRANSFORMATIONS_API ConvertScatterElementsUpdate12ToScatterElementsUpdate3 : public MatcherPass {
public:
OPENVINO_RTTI("ConvertScatterElementsUpdate12ToScatterElementsUpdate3", "0");
ConvertScatterElementsUpdate12ToScatterElementsUpdate3();
};
} // namespace pass
} // namespace ov

View File

@ -89,7 +89,7 @@
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
#include "transformations/op_conversions/convert_roi_align_v3_to_v9.hpp"
#include "transformations/op_conversions/convert_roi_align_v9_to_v3.hpp"
#include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
@ -215,6 +215,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
REGISTER_PASS(manager, ConvertPad12ToPad1)
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)

View File

@ -0,0 +1,40 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/op/scatter_elements_update.hpp>
#include "itt.hpp"
ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3::
ConvertScatterElementsUpdate12ToScatterElementsUpdate3() {
MATCHER_SCOPE(ConvertScatterElementsUpdate12ToScatterElementsUpdate3);
const auto seu_v12_pattern = pattern::wrap_type<ov::op::v12::ScatterElementsUpdate>();
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto seu_v12 = std::dynamic_pointer_cast<ov::op::v12::ScatterElementsUpdate>(m.get_match_root());
if (!seu_v12 || transformation_callback(seu_v12) ||
seu_v12->get_reduction() != ov::op::v12::ScatterElementsUpdate::Reduction::NONE) {
return false;
}
const auto seu_v3 = std::make_shared<ov::op::v3::ScatterElementsUpdate>(seu_v12->input_value(0),
seu_v12->input_value(1),
seu_v12->input_value(2),
seu_v12->input_value(3));
seu_v3->set_friendly_name(seu_v12->get_friendly_name());
copy_runtime_info(seu_v12, seu_v3);
replace_node(seu_v12, seu_v3);
return true;
};
auto m = std::make_shared<pattern::Matcher>(seu_v12_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@ -0,0 +1,78 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <openvino/opsets/opset12.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
namespace {
using Reduction = ov::opset12::ScatterElementsUpdate::Reduction;
std::shared_ptr<ov::Model> create_v12_model(const Reduction reduction_type, const bool use_init_value) {
const auto input = std::make_shared<ov::opset12::Parameter>(ov::element::f32, ov::Shape{1, 3, 100, 100});
const auto indices = std::make_shared<ov::opset12::Parameter>(ov::element::i32, ov::Shape{1, 1, 5, 5});
const auto updates = std::make_shared<ov::opset12::Parameter>(ov::element::f32, ov::Shape{1, 1, 5, 5});
const auto axis = std::make_shared<ov::opset12::Parameter>(ov::element::i64, ov::Shape{});
const auto seu = std::make_shared<ov::opset12::ScatterElementsUpdate>(input,
indices,
updates,
axis,
reduction_type,
use_init_value);
seu->set_friendly_name("seu12");
return std::make_shared<ov::Model>(seu->outputs(), ov::ParameterVector{input, indices, updates, axis});
}
std::shared_ptr<ov::Model> create_v3_model() {
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::f32, ov::Shape{1, 3, 100, 100});
const auto indices = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{1, 1, 5, 5});
const auto updates = std::make_shared<ov::opset3::Parameter>(ov::element::f32, ov::Shape{1, 1, 5, 5});
const auto axis = std::make_shared<ov::opset3::Parameter>(ov::element::i64, ov::Shape{});
const auto seu = std::make_shared<ov::opset3::ScatterElementsUpdate>(input, indices, updates, axis);
seu->set_friendly_name("seu3");
return std::make_shared<ov::Model>(seu->outputs(), ov::ParameterVector{input, indices, updates, axis});
}
} // namespace
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_no_reduction_use_init_value) {
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
function = create_v12_model(Reduction::NONE, true);
function_ref = create_v3_model();
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_no_reduction) {
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
function = create_v12_model(Reduction::NONE, false);
function_ref = create_v3_model();
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_reduction_use_init_value) {
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
function = create_v12_model(Reduction::MEAN, true);
}
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_reduction) {
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
function = create_v12_model(Reduction::PROD, false);
}