Update TopKIE operation and transform to support dynamic shapes (#526)
* Update TopKIE operation and transform to support dynamic shapes * Fix TopKIE shape infer * Updated TopKIE infer function * Removed index_element_type; changed swtich with as_string<> method * Fixed ieFuncTests * Fixed convert_topk transformation * Updated convert_topk transformations * ngraph::copy_runtime_info(topk, new_ops);
This commit is contained in:
parent
c6e03d73d8
commit
d3923f2ce0
@ -1638,37 +1638,8 @@ CNNLayer::Ptr NodeConverter<ngraph::op::v1::TopK>::createLayer(const std::shared
|
||||
auto castedLayer = ngraph::as_type_ptr<ngraph::op::v1::TopK>(layer);
|
||||
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
|
||||
|
||||
auto mode = castedLayer->get_mode();
|
||||
std::string str_mode;
|
||||
switch (mode) {
|
||||
case ngraph::op::v1::TopK::Mode::MIN:
|
||||
str_mode = "min";
|
||||
break;
|
||||
case ngraph::op::v1::TopK::Mode::MAX:
|
||||
str_mode = "max";
|
||||
break;
|
||||
default:
|
||||
THROW_IE_EXCEPTION << "Unsupported TopK mode";
|
||||
}
|
||||
|
||||
auto sort = castedLayer->get_sort_type();
|
||||
std::string str_sort;
|
||||
switch (sort) {
|
||||
case ngraph::op::v1::TopK::SortType::NONE:
|
||||
str_sort = "none";
|
||||
break;
|
||||
case ngraph::op::v1::TopK::SortType::SORT_VALUES:
|
||||
str_sort = "value";
|
||||
break;
|
||||
case ngraph::op::v1::TopK::SortType::SORT_INDICES:
|
||||
str_sort = "index";
|
||||
break;
|
||||
default:
|
||||
THROW_IE_EXCEPTION << "Unsupported TopK sort type";
|
||||
}
|
||||
|
||||
res->params["mode"] = str_mode;
|
||||
res->params["sort"] = str_sort;
|
||||
res->params["mode"] = ngraph::as_string<ngraph::op::v1::TopK::Mode>(castedLayer->get_mode());;
|
||||
res->params["sort"] = ngraph::as_string<ngraph::op::v1::TopK::SortType>(castedLayer->get_sort_type());
|
||||
res->params["axis"] = asString(castedLayer->get_axis());
|
||||
|
||||
return res;
|
||||
@ -1682,8 +1653,8 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TopKIE>::createLayer(const std::shared_p
|
||||
auto castedLayer = ngraph::as_type_ptr<ngraph::op::TopKIE>(layer);
|
||||
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
|
||||
|
||||
res->params["mode"] = castedLayer->get_mode();
|
||||
res->params["sort"] = castedLayer->get_sort_type();
|
||||
res->params["mode"] = ngraph::as_string<ngraph::op::v1::TopK::Mode>(castedLayer->get_mode());;
|
||||
res->params["sort"] = ngraph::as_string<ngraph::op::v1::TopK::SortType>(castedLayer->get_sort_type());
|
||||
res->params["axis"] = asString(castedLayer->get_axis());
|
||||
|
||||
return res;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ie_api.h>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/topk.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
@ -19,24 +20,26 @@ public:
|
||||
static constexpr NodeTypeInfo type_info{"TopKIE", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
|
||||
TopKIE(const Output<Node> &data,
|
||||
const Output<Node> &k,
|
||||
TopKIE(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
const std::string& mode,
|
||||
const std::string& sort,
|
||||
const Shape& output_shape);
|
||||
const ngraph::op::TopKMode mode,
|
||||
const ngraph::op::TopKSortType sort);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
int64_t get_axis();
|
||||
std::string get_mode();
|
||||
std::string get_sort_type();
|
||||
Shape get_output_shape();
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
int64_t axis;
|
||||
std::string mode, sort_type;
|
||||
Shape output_shape;
|
||||
int64_t get_axis() { return m_axis;}
|
||||
|
||||
ngraph::op::TopKMode get_mode() { return m_mode; }
|
||||
|
||||
ngraph::op::TopKSortType get_sort_type() { return m_sort_type; }
|
||||
|
||||
private:
|
||||
int64_t m_axis;
|
||||
ngraph::op::TopKMode m_mode;
|
||||
ngraph::op::TopKSortType m_sort_type;
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -6,43 +6,51 @@
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::TopKIE::type_info;
|
||||
|
||||
op::TopKIE::TopKIE(const Output<Node>& data, const Output<Node>& k, const int64_t axis, const std::string& mode, const std::string& sort,
|
||||
const Shape& output_shape)
|
||||
: Op({data, k}), axis(axis), mode(mode), sort_type(sort), output_shape(output_shape) {
|
||||
|
||||
op::TopKIE::TopKIE(const ngraph::Output<ngraph::Node> &data, const ngraph::Output<ngraph::Node> &k, const int64_t axis, const ngraph::op::TopKMode mode,
|
||||
const ngraph::op::TopKSortType sort)
|
||||
: Op({data, k}), m_axis(axis), m_mode(mode), m_sort_type(sort) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::TopKIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
if (new_args.size() != 2) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
||||
return make_shared<TopKIE>(new_args.at(0), new_args.at(1), axis, mode, sort_type, output_shape);
|
||||
std::shared_ptr<Node> op::TopKIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<TopKIE>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort_type);
|
||||
}
|
||||
|
||||
void op::TopKIE::validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
set_output_type(1, element::i32, output_shape);
|
||||
}
|
||||
const auto& input_partial_shape = get_input_partial_shape(0);
|
||||
const auto input_rank = input_partial_shape.rank();
|
||||
|
||||
int64_t op::TopKIE::get_axis() {
|
||||
return axis;
|
||||
}
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank.is_dynamic() || input_rank.get_length() > 0,
|
||||
"Input rank must be greater than 0.");
|
||||
|
||||
std::string op::TopKIE::get_mode() {
|
||||
return mode;
|
||||
}
|
||||
const auto& k_partial_shape = get_input_partial_shape(1);
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, k_partial_shape.rank().compatible(1), "The 'K' input must be a 1D tensor.");
|
||||
|
||||
std::string op::TopKIE::get_sort_type() {
|
||||
return sort_type;
|
||||
}
|
||||
// Construct v1::TopK operation to calculate output shapes
|
||||
std::shared_ptr<Node> topk;
|
||||
if (auto k_const = std::dynamic_pointer_cast<opset1::Constant>(input_value(1).get_node_shared_ptr())) {
|
||||
const auto k = k_const->cast_vector<int64_t>();
|
||||
topk = std::make_shared<opset1::TopK>(input_value(0),
|
||||
opset1::Constant::create(element::i64, Shape{}, k),
|
||||
m_axis, m_mode, m_sort_type);
|
||||
} else {
|
||||
topk = std::make_shared<opset1::TopK>(input_value(0),
|
||||
std::make_shared<opset1::Squeeze>(input_value(1), opset1::Constant::create(element::i64, Shape{1}, {0})),
|
||||
m_axis, m_mode, m_sort_type);
|
||||
}
|
||||
|
||||
Shape op::TopKIE::get_output_shape() {
|
||||
return output_shape;
|
||||
}
|
||||
set_output_size(2);
|
||||
set_output_type(0, get_input_element_type(0), topk->get_output_partial_shape(0));
|
||||
set_output_type(1, element::i32, topk->get_output_partial_shape(1));
|
||||
}
|
@ -14,52 +14,47 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
void ngraph::pass::ConvertTopKToTopKIE::convert_topk_to_topk_ie() {
|
||||
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||
auto k = std::make_shared<pattern::op::Label>(element::i64, Shape{});
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input_0, k, 0, "min", "none");
|
||||
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::TopK>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
|
||||
auto topk = std::dynamic_pointer_cast<ngraph::opset1::TopK>(m.get_match_root());
|
||||
if (!topk) {
|
||||
auto topk = std::dynamic_pointer_cast<opset1::TopK>(m.get_match_root());
|
||||
if (!topk || topk->input(1).get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
if (topk->input(1).get_shape().size() == 1) {
|
||||
if (topk->input(1).get_partial_shape().rank().get_length() == 1) {
|
||||
return false;
|
||||
}
|
||||
auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(topk->input(1).get_source_output().get_node_shared_ptr(),
|
||||
opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
|
||||
std::string mode;
|
||||
switch (topk->get_mode()) {
|
||||
case ngraph::opset1::TopK::Mode::MAX:
|
||||
mode = "max";
|
||||
break;
|
||||
case ngraph::opset1::TopK::Mode::MIN:
|
||||
mode = "min";
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
std::string sort_type;
|
||||
switch (topk->get_sort_type()) {
|
||||
case ngraph::opset1::TopK::SortType::NONE:
|
||||
sort_type = "none";
|
||||
break;
|
||||
case ngraph::opset1::TopK::SortType::SORT_INDICES:
|
||||
sort_type = "index";
|
||||
break;
|
||||
case ngraph::opset1::TopK::SortType::SORT_VALUES:
|
||||
sort_type = "value";
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
// WA: if we replace TopK second input with Unsqueeze operation we will get dynamic shape until first CF pass
|
||||
// but due to not all legacy operations support dynamic input shapes and dynamic shape can break pipeline we
|
||||
// need to unsqueeze constant manually.
|
||||
Output<Node> unsqueezed_k;
|
||||
NodeVector new_ops;
|
||||
if (auto k_const = std::dynamic_pointer_cast<opset1::Constant>(topk->input_value(1).get_node_shared_ptr())) {
|
||||
auto k_value = k_const->cast_vector<int64_t>();
|
||||
unsqueezed_k = opset1::Constant::create(element::i64, Shape{1}, k_value);
|
||||
} else {
|
||||
unsqueezed_k = std::make_shared<opset1::Unsqueeze>(topk->input_value(1), opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(unsqueezed_k.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto new_topk = std::make_shared<ngraph::op::TopKIE>(topk->input(0).get_source_output(), unsqueezed_k, topk->get_axis(), mode,
|
||||
sort_type, topk->output(0).get_shape());
|
||||
new_topk->set_friendly_name(topk->get_friendly_name());
|
||||
ngraph::copy_runtime_info(topk, {unsqueezed_k, new_topk});
|
||||
ngraph::replace_node(topk, new_topk);
|
||||
auto topk_ie = std::make_shared<ngraph::op::TopKIE>(topk->input_value(0), unsqueezed_k, topk->get_axis(), topk->get_mode(),
|
||||
topk->get_sort_type());
|
||||
new_ops.push_back(topk_ie);
|
||||
|
||||
Output<Node> index_output;
|
||||
// insert Convert if index element type not equal to i32
|
||||
if (topk->get_index_element_type() == element::i32) {
|
||||
index_output = topk_ie->output(1);
|
||||
} else {
|
||||
index_output = std::make_shared<opset1::Convert>(topk_ie->output(1), topk->get_index_element_type());
|
||||
new_ops.push_back(index_output.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
topk_ie->set_friendly_name(topk->get_friendly_name());
|
||||
ngraph::copy_runtime_info(topk, new_ops);
|
||||
topk->output(0).replace(topk_ie->output(0));
|
||||
topk->output(1).replace(index_output);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -13,18 +13,10 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
void ngraph::pass::ConvertTopK3::convert_topk3() {
|
||||
auto input = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
auto k = ngraph::opset3::Constant::create(element::i64, Shape{}, {10});
|
||||
auto topk = std::make_shared<ngraph::opset3::TopK>(input, k, 0, "min", "value", element::i64);
|
||||
// this is a temporary workaround to avoid bug that TopK-3 does not have clone_with_new_inputs so the TopK-3 clone
|
||||
// generates TopK-1 operation
|
||||
auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(input, k, 0, "min", "value", element::i64);
|
||||
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset3::TopK>());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
std::shared_ptr<ngraph::op::v1::TopK> topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
|
||||
if (!topk) {
|
||||
topk = std::dynamic_pointer_cast<ngraph::opset1::TopK> (m.get_match_root());
|
||||
}
|
||||
auto topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
|
||||
if (!topk) {
|
||||
return false;
|
||||
}
|
||||
@ -51,6 +43,4 @@ void ngraph::pass::ConvertTopK3::convert_topk3() {
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopK3");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
auto m2 = std::make_shared<ngraph::pattern::Matcher>(topk_v1, "ConvertTopK3");
|
||||
this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
|
@ -0,0 +1,164 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph_ops/topk_ie.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
|
||||
#include "ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, ConvertTopKToTopKIEStatic) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
|
||||
//auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(k, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
|
||||
auto topk = std::make_shared<ngraph::op::TopKIE>(input, k, 1, ngraph::op::TopKMode::MIN,
|
||||
ngraph::op::TopKSortType::SORT_VALUES);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTopKToTopKIEDynamic1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 20, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 20, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
|
||||
//auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(k, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
|
||||
auto topk = std::make_shared<ngraph::op::TopKIE>(input, k, 1, ngraph::op::TopKMode::MIN,
|
||||
ngraph::op::TopKSortType::SORT_VALUES);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTopKToTopKIEDynamic2) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, DYN, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, DYN, 3});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
|
||||
//auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(k, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
|
||||
auto topk = std::make_shared<ngraph::op::TopKIE>(input, k, 1, ngraph::op::TopKMode::MIN,
|
||||
ngraph::op::TopKSortType::SORT_VALUES);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTopKToTopKIEDynamic3) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 20, DYN});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 20, DYN});
|
||||
auto k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
|
||||
//auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(k, ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
|
||||
auto topk = std::make_shared<ngraph::op::TopKIE>(input, k, 1, ngraph::op::TopKMode::MIN,
|
||||
ngraph::op::TopKSortType::SORT_VALUES);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertTopKToTopKIENegative) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
|
||||
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input, k});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopKToTopKIE().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
|
||||
auto k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto topk = std::make_shared<ngraph::opset1::TopK>(input, k, 1, "min", "value", ngraph::element::i32);
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input, k});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user