TopK v11 -> v3 downgrade transformation (#16339)
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts TopK version 11 to TopK version 3 if TopK 11 stable attribute is set to false
|
||||
*/
|
||||
class TRANSFORMATIONS_API ConvertTopK11ToTopK3 : public MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ConvertTopK11ToTopK3", "0");
|
||||
ConvertTopK11ToTopK3();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@@ -92,6 +92,7 @@
|
||||
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_space_to_depth.hpp"
|
||||
#include "transformations/op_conversions/convert_subtract.hpp"
|
||||
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
|
||||
#include "transformations/op_conversions/convert_xor_to_logical_xor.hpp"
|
||||
#include "transformations/op_conversions/detection_output_downgrade.hpp"
|
||||
#include "transformations/op_conversions/detection_output_upgrade.hpp"
|
||||
@@ -209,6 +210,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
|
||||
REGISTER_PASS(manager, ConvertROIAlign9To3)
|
||||
REGISTER_PASS(manager, ConvertMulticlassNms8ToMulticlassNms9)
|
||||
REGISTER_PASS(manager, ConvertXorToLogicalXor)
|
||||
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
|
||||
|
||||
auto fq_fusions = manager.register_pass<GraphRewrite>();
|
||||
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset11.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
ov::pass::ConvertTopK11ToTopK3::ConvertTopK11ToTopK3() {
|
||||
MATCHER_SCOPE(ConvertTopK11ToTopK3);
|
||||
|
||||
const auto topk_v11_pattern = pattern::wrap_type<opset11::TopK>();
|
||||
|
||||
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto topk_v11 = std::dynamic_pointer_cast<opset11::TopK>(m.get_match_root());
|
||||
if (!topk_v11 || topk_v11->get_stable() || transformation_callback(topk_v11)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// downgrade only if the stable sort is NOT required
|
||||
|
||||
const auto topk_v3 = std::make_shared<opset3::TopK>(topk_v11->input_value(0),
|
||||
topk_v11->input_value(1),
|
||||
topk_v11->get_axis(),
|
||||
topk_v11->get_mode(),
|
||||
topk_v11->get_sort_type(),
|
||||
topk_v11->get_index_element_type());
|
||||
|
||||
topk_v3->set_friendly_name(topk_v11->get_friendly_name());
|
||||
copy_runtime_info(topk_v11, topk_v3);
|
||||
replace_node(topk_v11, topk_v3);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(topk_v11_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/opsets/opset11.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_topk11_downgrade.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3) {
|
||||
{
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64,
|
||||
false);
|
||||
topk->set_friendly_name("topk11");
|
||||
|
||||
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
||||
}
|
||||
|
||||
{
|
||||
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset3::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset3::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64);
|
||||
topk->set_friendly_name("topk11");
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3_fail) {
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64,
|
||||
true); // stable sort on
|
||||
topk->set_friendly_name("topk11");
|
||||
|
||||
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
||||
}
|
||||
Reference in New Issue
Block a user