ScatterElementsUpdate downgrade transformation (#18306)
This commit is contained in:
parent
f405ee2b9d
commit
60d5d57ece
@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts Pad v12 to Pad v1
|
||||
*/
|
||||
class TRANSFORMATIONS_API ConvertScatterElementsUpdate12ToScatterElementsUpdate3 : public MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConvertScatterElementsUpdate12ToScatterElementsUpdate3", "0");
|
||||
ConvertScatterElementsUpdate12ToScatterElementsUpdate3();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
@ -89,7 +89,7 @@
|
||||
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
|
||||
#include "transformations/op_conversions/convert_roi_align_v3_to_v9.hpp"
|
||||
#include "transformations/op_conversions/convert_roi_align_v9_to_v3.hpp"
|
||||
#include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
|
||||
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_space_to_depth.hpp"
|
||||
@ -215,6 +215,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
|
||||
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
|
||||
REGISTER_PASS(manager, ConvertInterpolate11ToInterpolate4)
|
||||
REGISTER_PASS(manager, ConvertPad12ToPad1)
|
||||
REGISTER_PASS(manager, ConvertScatterElementsUpdate12ToScatterElementsUpdate3)
|
||||
|
||||
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||
|
@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp"
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/op/scatter_elements_update.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3::
|
||||
ConvertScatterElementsUpdate12ToScatterElementsUpdate3() {
|
||||
MATCHER_SCOPE(ConvertScatterElementsUpdate12ToScatterElementsUpdate3);
|
||||
|
||||
const auto seu_v12_pattern = pattern::wrap_type<ov::op::v12::ScatterElementsUpdate>();
|
||||
|
||||
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto seu_v12 = std::dynamic_pointer_cast<ov::op::v12::ScatterElementsUpdate>(m.get_match_root());
|
||||
if (!seu_v12 || transformation_callback(seu_v12) ||
|
||||
seu_v12->get_reduction() != ov::op::v12::ScatterElementsUpdate::Reduction::NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto seu_v3 = std::make_shared<ov::op::v3::ScatterElementsUpdate>(seu_v12->input_value(0),
|
||||
seu_v12->input_value(1),
|
||||
seu_v12->input_value(2),
|
||||
seu_v12->input_value(3));
|
||||
|
||||
seu_v3->set_friendly_name(seu_v12->get_friendly_name());
|
||||
copy_runtime_info(seu_v12, seu_v3);
|
||||
replace_node(seu_v12, seu_v3);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(seu_v12_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset12.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_scatter_elements_update12_downgrade.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
namespace {
|
||||
using Reduction = ov::opset12::ScatterElementsUpdate::Reduction;
|
||||
|
||||
std::shared_ptr<ov::Model> create_v12_model(const Reduction reduction_type, const bool use_init_value) {
|
||||
const auto input = std::make_shared<ov::opset12::Parameter>(ov::element::f32, ov::Shape{1, 3, 100, 100});
|
||||
const auto indices = std::make_shared<ov::opset12::Parameter>(ov::element::i32, ov::Shape{1, 1, 5, 5});
|
||||
const auto updates = std::make_shared<ov::opset12::Parameter>(ov::element::f32, ov::Shape{1, 1, 5, 5});
|
||||
const auto axis = std::make_shared<ov::opset12::Parameter>(ov::element::i64, ov::Shape{});
|
||||
|
||||
const auto seu = std::make_shared<ov::opset12::ScatterElementsUpdate>(input,
|
||||
indices,
|
||||
updates,
|
||||
axis,
|
||||
reduction_type,
|
||||
use_init_value);
|
||||
|
||||
seu->set_friendly_name("seu12");
|
||||
|
||||
return std::make_shared<ov::Model>(seu->outputs(), ov::ParameterVector{input, indices, updates, axis});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> create_v3_model() {
|
||||
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::f32, ov::Shape{1, 3, 100, 100});
|
||||
const auto indices = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{1, 1, 5, 5});
|
||||
const auto updates = std::make_shared<ov::opset3::Parameter>(ov::element::f32, ov::Shape{1, 1, 5, 5});
|
||||
const auto axis = std::make_shared<ov::opset3::Parameter>(ov::element::i64, ov::Shape{});
|
||||
|
||||
const auto seu = std::make_shared<ov::opset3::ScatterElementsUpdate>(input, indices, updates, axis);
|
||||
|
||||
seu->set_friendly_name("seu3");
|
||||
|
||||
return std::make_shared<ov::Model>(seu->outputs(), ov::ParameterVector{input, indices, updates, axis});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_no_reduction_use_init_value) {
|
||||
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
|
||||
function = create_v12_model(Reduction::NONE, true);
|
||||
function_ref = create_v3_model();
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_no_reduction) {
|
||||
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
|
||||
function = create_v12_model(Reduction::NONE, false);
|
||||
function_ref = create_v3_model();
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_reduction_use_init_value) {
|
||||
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
|
||||
function = create_v12_model(Reduction::MEAN, true);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertScatterElementsUpdate12ToScatterElementsUpdate3_reduction) {
|
||||
manager.register_pass<ov::pass::ConvertScatterElementsUpdate12ToScatterElementsUpdate3>();
|
||||
function = create_v12_model(Reduction::PROD, false);
|
||||
}
|
Loading…
Reference in New Issue
Block a user