[TF FE] Move to TopK-11 operation and update downgrading TopK transformation (#16590)
* [TF FE] Move to TopK-11 operation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update downgrading transformation --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
@@ -18,11 +18,13 @@ ov::pass::ConvertTopK11ToTopK3::ConvertTopK11ToTopK3() {
|
||||
|
||||
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto topk_v11 = std::dynamic_pointer_cast<opset11::TopK>(m.get_match_root());
|
||||
if (!topk_v11 || topk_v11->get_stable() || transformation_callback(topk_v11)) {
|
||||
if (!topk_v11 || transformation_callback(topk_v11)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// downgrade only if the stable sort is NOT required
|
||||
// downgrade even if stable attribute is True
|
||||
// this is needed to provide backward-compatibility
|
||||
// and operation working in the plugins that have not yet added stable mode
|
||||
|
||||
const auto topk_v3 = std::make_shared<opset3::TopK>(topk_v11->input_value(0),
|
||||
topk_v11->input_value(1),
|
||||
|
||||
@@ -47,18 +47,34 @@ TEST_F(TransformationTestsF, ConvertTopK11ToTopK3) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3_fail) {
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64,
|
||||
true); // stable sort on
|
||||
topk->set_friendly_name("topk11");
|
||||
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3StableMode) {
|
||||
{
|
||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64,
|
||||
true);
|
||||
topk->set_friendly_name("topk11");
|
||||
|
||||
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
||||
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
||||
}
|
||||
|
||||
{
|
||||
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||
const auto k = std::make_shared<ov::opset3::Parameter>(ov::element::i8, ov::Shape{});
|
||||
const auto topk = std::make_shared<ov::opset3::TopK>(input,
|
||||
k,
|
||||
-2,
|
||||
ov::op::TopKMode::MAX,
|
||||
ov::op::TopKSortType::SORT_VALUES,
|
||||
ov::element::i64);
|
||||
topk->set_friendly_name("topk11");
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user