[TF FE] Move to TopK-11 operation and update downgrading TopK transformation (#16590)
* [TF FE] Move to TopK-11 operation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update downgrading transformation --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
c33a3f87f0
commit
9a5a8f6abc
@ -18,11 +18,13 @@ ov::pass::ConvertTopK11ToTopK3::ConvertTopK11ToTopK3() {
|
|||||||
|
|
||||||
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
const matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
const auto topk_v11 = std::dynamic_pointer_cast<opset11::TopK>(m.get_match_root());
|
const auto topk_v11 = std::dynamic_pointer_cast<opset11::TopK>(m.get_match_root());
|
||||||
if (!topk_v11 || topk_v11->get_stable() || transformation_callback(topk_v11)) {
|
if (!topk_v11 || transformation_callback(topk_v11)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// downgrade only if the stable sort is NOT required
|
// downgrade even if stable attribute is True
|
||||||
|
// this is needed to provide backward-compatibility
|
||||||
|
// and operation working in the plugins that have not yet added stable mode
|
||||||
|
|
||||||
const auto topk_v3 = std::make_shared<opset3::TopK>(topk_v11->input_value(0),
|
const auto topk_v3 = std::make_shared<opset3::TopK>(topk_v11->input_value(0),
|
||||||
topk_v11->input_value(1),
|
topk_v11->input_value(1),
|
||||||
|
@ -47,18 +47,34 @@ TEST_F(TransformationTestsF, ConvertTopK11ToTopK3) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3_fail) {
|
TEST_F(TransformationTestsF, ConvertTopK11ToTopK3StableMode) {
|
||||||
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
{
|
||||||
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
const auto input = std::make_shared<ov::opset11::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||||
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
const auto k = std::make_shared<ov::opset11::Parameter>(ov::element::i8, ov::Shape{});
|
||||||
k,
|
const auto topk = std::make_shared<ov::opset11::TopK>(input,
|
||||||
-2,
|
k,
|
||||||
ov::op::TopKMode::MAX,
|
-2,
|
||||||
ov::op::TopKSortType::SORT_VALUES,
|
ov::op::TopKMode::MAX,
|
||||||
ov::element::i64,
|
ov::op::TopKSortType::SORT_VALUES,
|
||||||
true); // stable sort on
|
ov::element::i64,
|
||||||
topk->set_friendly_name("topk11");
|
true);
|
||||||
|
topk->set_friendly_name("topk11");
|
||||||
|
|
||||||
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
function = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||||
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
manager.register_pass<ov::pass::ConvertTopK11ToTopK3>();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const auto input = std::make_shared<ov::opset3::Parameter>(ov::element::i32, ov::Shape{2, 3, 4});
|
||||||
|
const auto k = std::make_shared<ov::opset3::Parameter>(ov::element::i8, ov::Shape{});
|
||||||
|
const auto topk = std::make_shared<ov::opset3::TopK>(input,
|
||||||
|
k,
|
||||||
|
-2,
|
||||||
|
ov::op::TopKMode::MAX,
|
||||||
|
ov::op::TopKSortType::SORT_VALUES,
|
||||||
|
ov::element::i64);
|
||||||
|
topk->set_friendly_name("topk11");
|
||||||
|
|
||||||
|
function_ref = std::make_shared<ov::Model>(topk->outputs(), ov::ParameterVector{input, k});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,10 +3,10 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "common_op_table.hpp"
|
#include "common_op_table.hpp"
|
||||||
#include "openvino/opsets/opset8.hpp"
|
#include "openvino/opsets/opset11.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ov::opset8;
|
using namespace ov::opset11;
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace frontend {
|
namespace frontend {
|
||||||
@ -23,10 +23,12 @@ OutputVector translate_top_k_base_op(const NodeContext& node, const ov::Output<o
|
|||||||
-1,
|
-1,
|
||||||
ov::op::v1::TopK::Mode::MAX,
|
ov::op::v1::TopK::Mode::MAX,
|
||||||
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES,
|
sorted ? TopK::SortType::SORT_VALUES : TopK::SortType::SORT_INDICES,
|
||||||
ov::element::i32);
|
ov::element::i32,
|
||||||
|
true);
|
||||||
set_node_name(node.get_name(), top_k);
|
set_node_name(node.get_name(), top_k);
|
||||||
return top_k->outputs();
|
return top_k->outputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
OutputVector translate_top_k_op(const NodeContext& node) {
|
OutputVector translate_top_k_op(const NodeContext& node) {
|
||||||
// retrieve k attribute
|
// retrieve k attribute
|
||||||
auto k = node.get_attribute<int64_t>("k");
|
auto k = node.get_attribute<int64_t>("k");
|
||||||
@ -39,7 +41,6 @@ OutputVector translate_top_k_v2_op(const NodeContext& node) {
|
|||||||
auto k_input = node.get_input(1);
|
auto k_input = node.get_input(1);
|
||||||
return translate_top_k_base_op(node, k_input, 1);
|
return translate_top_k_base_op(node, k_input, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
} // namespace frontend
|
} // namespace frontend
|
||||||
|
@ -36,13 +36,11 @@ class TestTopKV2(CommonTFLayerTest):
|
|||||||
dict(input_shape=[10], input_type=tf.float32, k=5, sorted=True, is_first_output=True, is_second_output=False),
|
dict(input_shape=[10], input_type=tf.float32, k=5, sorted=True, is_first_output=True, is_second_output=False),
|
||||||
dict(input_shape=[2, 3, 10], input_type=tf.int32, k=10, sorted=True, is_first_output=True,
|
dict(input_shape=[2, 3, 10], input_type=tf.int32, k=10, sorted=True, is_first_output=True,
|
||||||
is_second_output=False),
|
is_second_output=False),
|
||||||
# Currently, OpenVINO TopK supports only TensorFlow TopK with sorted=True and the first output
|
# Expect stable mode support by the CPU plugin. See 101503
|
||||||
# For other cases, we need to introduce new version of TopK in OpenVINO opset due to multiple misalignments
|
|
||||||
# described in 88024
|
|
||||||
pytest.param(dict(input_shape=[4, 12], input_type=tf.float32, k=10, sorted=True, is_first_output=True,
|
pytest.param(dict(input_shape=[4, 12], input_type=tf.float32, k=10, sorted=True, is_first_output=True,
|
||||||
is_second_output=True), marks=pytest.mark.xfail(reason="88024")),
|
is_second_output=True), marks=pytest.mark.xfail(reason="101503")),
|
||||||
pytest.param(dict(input_shape=[5, 10], input_type=tf.int32, k=8, sorted=False, is_first_output=True,
|
pytest.param(dict(input_shape=[5, 10], input_type=tf.int32, k=8, sorted=False, is_first_output=True,
|
||||||
is_second_output=True), marks=pytest.mark.xfail(reason="88024")),
|
is_second_output=True), marks=pytest.mark.xfail(reason="101503")),
|
||||||
]
|
]
|
||||||
|
|
||||||
@pytest.mark.parametrize("params", test_basic)
|
@pytest.mark.parametrize("params", test_basic)
|
||||||
|
Loading…
Reference in New Issue
Block a user