TopK v11 -> v3 downgrade transformation (#16339)

This commit is contained in:
Tomasz Dołbniak
2023-03-17 13:40:56 +01:00
committed by GitHub
parent 249d57f37e
commit a99a5057e2
4 changed files with 132 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ov {
namespace pass {
/**
* @ingroup ie_transformation_common_api
* @brief Converts TopK version 11 to TopK version 3 if TopK 11 stable attribute is set to false
*/
class TRANSFORMATIONS_API ConvertTopK11ToTopK3 : public MatcherPass {
public:
OPENVINO_RTTI("ConvertTopK11ToTopK3", "0");
ConvertTopK11ToTopK3();
};
} // namespace pass
} // namespace ov

View File

@@ -92,6 +92,7 @@
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
#include "transformations/op_conversions/convert_subtract.hpp"
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
#include "transformations/op_conversions/convert_xor_to_logical_xor.hpp"
#include "transformations/op_conversions/detection_output_downgrade.hpp"
#include "transformations/op_conversions/detection_output_upgrade.hpp"
@@ -209,6 +210,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
REGISTER_PASS(manager, ConvertROIAlign9To3)
REGISTER_PASS(manager, ConvertMulticlassNms8ToMulticlassNms9)
REGISTER_PASS(manager, ConvertXorToLogicalXor)
REGISTER_PASS(manager, ConvertTopK11ToTopK3)
auto fq_fusions = manager.register_pass<GraphRewrite>();
ADD_MATCHER(fq_fusions, FakeQuantizeMulFusion)

View File

@@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset3.hpp>
#include "itt.hpp"
ov::pass::ConvertTopK11ToTopK3::ConvertTopK11ToTopK3() {
MATCHER_SCOPE(ConvertTopK11ToTopK3);
const auto topk_v11_pattern = pattern::wrap_type<opset11::TopK>();
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto topk_v11 = std::dynamic_pointer_cast<opset11::TopK>(m.get_match_root());
if (!topk_v11 || topk_v11->get_stable() || transformation_callback(topk_v11)) {
return false;
}
// downgrade only if the stable sort is NOT required
const auto topk_v3 = std::make_shared<opset3::TopK>(topk_v11->input_value(0),
topk_v11->input_value(1),
topk_v11->get_axis(),
topk_v11->get_mode(),
topk_v11->get_sort_type(),
topk_v11->get_index_element_type());
topk_v3->set_friendly_name(topk_v11->get_friendly_name());
copy_runtime_info(topk_v11, topk_v3);
replace_node(topk_v11, topk_v3);
return true;
};
auto m = std::make_shared<pattern::Matcher>(topk_v11_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@@ -0,0 +1,64 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/pass/manager.hpp>
#include <transformations/op_conversions/convert_topk11_downgrade.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3) {
{
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
const auto topk = std::make_shared<ov::opset11::TopK>(input,
k,
-2,
ov::op::TopKMode::MAX,
ov::op::TopKSortType::SORT_VALUES,
ov::element::i64,
false);
topk->set_friendly_name("topk11");
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
}
{
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
const auto k = std::make_shared<ov::opset3::Parameter>(ov::element::i8, ov::Shape{});
const auto topk = std::make_shared<ov::opset3::TopK>(input,
k,
-2,
ov::op::TopKMode::MAX,
ov::op::TopKSortType::SORT_VALUES,
ov::element::i64);
topk->set_friendly_name("topk11");
function_ref = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
}
}
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3_fail) {
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
const auto topk = std::make_shared<ov::opset11::TopK>(input,
k,
-2,
ov::op::TopKMode::MAX,
ov::op::TopKSortType::SORT_VALUES,
ov::element::i64,
true); // stable sort on
topk->set_friendly_name("topk11");
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
}