Update TopKIE operation and transform to support dynamic shapes (#526)

* Update TopKIE operation and transform to support dynamic shapes

* Fix TopKIE shape infer

* Updated TopKIE infer function

* Removed index_element_type; replaced switch with as_string<> method

* Fixed ieFuncTests

* Fixed convert_topk transformation

* Updated convert_topk transformations

* ngraph::copy_runtime_info(topk, new_ops);
This commit is contained in:
Gleb Kazantaev 2020-05-26 01:19:38 +03:00 committed by GitHub
parent c6e03d73d8
commit d3923f2ce0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 250 additions and 119 deletions

View File

@ -1638,37 +1638,8 @@ CNNLayer::Ptr NodeConverter<ngraph::op::v1::TopK>::createLayer(const std::shared
auto castedLayer = ngraph::as_type_ptr<ngraph::op::v1::TopK>(layer);
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
auto mode = castedLayer->get_mode();
std::string str_mode;
switch (mode) {
case ngraph::op::v1::TopK::Mode::MIN:
str_mode = "min";
break;
case ngraph::op::v1::TopK::Mode::MAX:
str_mode = "max";
break;
default:
THROW_IE_EXCEPTION << "Unsupported TopK mode";
}
auto sort = castedLayer->get_sort_type();
std::string str_sort;
switch (sort) {
case ngraph::op::v1::TopK::SortType::NONE:
str_sort = "none";
break;
case ngraph::op::v1::TopK::SortType::SORT_VALUES:
str_sort = "value";
break;
case ngraph::op::v1::TopK::SortType::SORT_INDICES:
str_sort = "index";
break;
default:
THROW_IE_EXCEPTION << "Unsupported TopK sort type";
}
res->params["mode"] = str_mode;
res->params["sort"] = str_sort;
res->params["mode"] = ngraph::as_string<ngraph::op::v1::TopK::Mode>(castedLayer->get_mode());;
res->params["sort"] = ngraph::as_string<ngraph::op::v1::TopK::SortType>(castedLayer->get_sort_type());
res->params["axis"] = asString(castedLayer->get_axis());
return res;
@ -1682,8 +1653,8 @@ CNNLayer::Ptr NodeConverter<ngraph::op::TopKIE>::createLayer(const std::shared_p
auto castedLayer = ngraph::as_type_ptr<ngraph::op::TopKIE>(layer);
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
res->params["mode"] = castedLayer->get_mode();
res->params["sort"] = castedLayer->get_sort_type();
res->params["mode"] = ngraph::as_string<ngraph::op::v1::TopK::Mode>(castedLayer->get_mode());;
res->params["sort"] = ngraph::as_string<ngraph::op::v1::TopK::SortType>(castedLayer->get_sort_type());
res->params["axis"] = asString(castedLayer->get_axis());
return res;

View File

@ -10,6 +10,7 @@
#include <ie_api.h>
#include "ngraph/op/op.hpp"
#include "ngraph/op/topk.hpp"
namespace ngraph {
namespace op {
@ -19,24 +20,26 @@ public:
static constexpr NodeTypeInfo type_info{"TopKIE", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
TopKIE(const Output<Node> &data,
const Output<Node> &k,
TopKIE(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
const std::string& mode,
const std::string& sort,
const Shape& output_shape);
const ngraph::op::TopKMode mode,
const ngraph::op::TopKSortType sort);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
int64_t get_axis();
std::string get_mode();
std::string get_sort_type();
Shape get_output_shape();
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t axis;
std::string mode, sort_type;
Shape output_shape;
int64_t get_axis() { return m_axis;}
ngraph::op::TopKMode get_mode() { return m_mode; }
ngraph::op::TopKSortType get_sort_type() { return m_sort_type; }
private:
int64_t m_axis;
ngraph::op::TopKMode m_mode;
ngraph::op::TopKSortType m_sort_type;
};
} // namespace op

View File

@ -6,43 +6,51 @@
#include <memory>
#include <string>
#include <ngraph/opsets/opset1.hpp>
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::TopKIE::type_info;
op::TopKIE::TopKIE(const Output<Node>& data, const Output<Node>& k, const int64_t axis, const std::string& mode, const std::string& sort,
const Shape& output_shape)
: Op({data, k}), axis(axis), mode(mode), sort_type(sort), output_shape(output_shape) {
op::TopKIE::TopKIE(const ngraph::Output<ngraph::Node> &data, const ngraph::Output<ngraph::Node> &k, const int64_t axis, const ngraph::op::TopKMode mode,
const ngraph::op::TopKSortType sort)
: Op({data, k}), m_axis(axis), m_mode(mode), m_sort_type(sort) {
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::TopKIE::copy_with_new_args(const NodeVector& new_args) const {
if (new_args.size() != 2) {
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<TopKIE>(new_args.at(0), new_args.at(1), axis, mode, sort_type, output_shape);
std::shared_ptr<Node> op::TopKIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
check_new_args_count(this, new_args);
return make_shared<TopKIE>(new_args.at(0), new_args.at(1), m_axis, m_mode, m_sort_type);
}
void op::TopKIE::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), output_shape);
set_output_type(1, element::i32, output_shape);
}
const auto& input_partial_shape = get_input_partial_shape(0);
const auto input_rank = input_partial_shape.rank();
int64_t op::TopKIE::get_axis() {
return axis;
}
NODE_VALIDATION_CHECK(this,
input_rank.is_dynamic() || input_rank.get_length() > 0,
"Input rank must be greater than 0.");
std::string op::TopKIE::get_mode() {
return mode;
}
const auto& k_partial_shape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(
this, k_partial_shape.rank().compatible(1), "The 'K' input must be a 1D tensor.");
std::string op::TopKIE::get_sort_type() {
return sort_type;
}
// Construct v1::TopK operation to calculate output shapes
std::shared_ptr<Node> topk;
if (auto k_const = std::dynamic_pointer_cast<opset1::Constant>(input_value(1).get_node_shared_ptr())) {
const auto k = k_const->cast_vector<int64_t>();
topk = std::make_shared<opset1::TopK>(input_value(0),
opset1::Constant::create(element::i64, Shape{}, k),
m_axis, m_mode, m_sort_type);
} else {
topk = std::make_shared<opset1::TopK>(input_value(0),
std::make_shared<opset1::Squeeze>(input_value(1), opset1::Constant::create(element::i64, Shape{1}, {0})),
m_axis, m_mode, m_sort_type);
}
Shape op::TopKIE::get_output_shape() {
return output_shape;
}
set_output_size(2);
set_output_type(0, get_input_element_type(0), topk->get_output_partial_shape(0));
set_output_type(1, element::i32, topk->get_output_partial_shape(1));
}

View File

@ -14,52 +14,47 @@
#include <ngraph/rt_info.hpp>
void ngraph::pass::ConvertTopKToTopKIE::convert_topk_to_topk_ie() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto k = std::make_shared<pattern::op::Label>(element::i64, Shape{});
auto topk = std::make_shared<ngraph::opset1::TopK>(input_0, k, 0, "min", "none");
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{1}, pattern::has_class<opset1::TopK>());
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
auto topk = std::dynamic_pointer_cast<ngraph::opset1::TopK>(m.get_match_root());
if (!topk) {
auto topk = std::dynamic_pointer_cast<opset1::TopK>(m.get_match_root());
if (!topk || topk->input(1).get_partial_shape().rank().is_dynamic()) {
return false;
}
if (topk->input(1).get_shape().size() == 1) {
if (topk->input(1).get_partial_shape().rank().get_length() == 1) {
return false;
}
auto unsqueezed_k = std::make_shared<ngraph::opset1::Unsqueeze>(topk->input(1).get_source_output().get_node_shared_ptr(),
opset1::Constant::create(element::i64, Shape{1}, {0}));
std::string mode;
switch (topk->get_mode()) {
case ngraph::opset1::TopK::Mode::MAX:
mode = "max";
break;
case ngraph::opset1::TopK::Mode::MIN:
mode = "min";
break;
default:
return false;
}
std::string sort_type;
switch (topk->get_sort_type()) {
case ngraph::opset1::TopK::SortType::NONE:
sort_type = "none";
break;
case ngraph::opset1::TopK::SortType::SORT_INDICES:
sort_type = "index";
break;
case ngraph::opset1::TopK::SortType::SORT_VALUES:
sort_type = "value";
break;
default:
return false;
// WA: replacing the second TopK input with an Unsqueeze operation produces a dynamic shape until the first CF pass.
// Not all legacy operations support dynamic input shapes, and a dynamic shape can break the pipeline,
// so a constant K is unsqueezed manually.
Output<Node> unsqueezed_k;
NodeVector new_ops;
if (auto k_const = std::dynamic_pointer_cast<opset1::Constant>(topk->input_value(1).get_node_shared_ptr())) {
auto k_value = k_const->cast_vector<int64_t>();
unsqueezed_k = opset1::Constant::create(element::i64, Shape{1}, k_value);
} else {
unsqueezed_k = std::make_shared<opset1::Unsqueeze>(topk->input_value(1), opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(unsqueezed_k.get_node_shared_ptr());
}
auto new_topk = std::make_shared<ngraph::op::TopKIE>(topk->input(0).get_source_output(), unsqueezed_k, topk->get_axis(), mode,
sort_type, topk->output(0).get_shape());
new_topk->set_friendly_name(topk->get_friendly_name());
ngraph::copy_runtime_info(topk, {unsqueezed_k, new_topk});
ngraph::replace_node(topk, new_topk);
auto topk_ie = std::make_shared<ngraph::op::TopKIE>(topk->input_value(0), unsqueezed_k, topk->get_axis(), topk->get_mode(),
topk->get_sort_type());
new_ops.push_back(topk_ie);
Output<Node> index_output;
// insert Convert if index element type not equal to i32
if (topk->get_index_element_type() == element::i32) {
index_output = topk_ie->output(1);
} else {
index_output = std::make_shared<opset1::Convert>(topk_ie->output(1), topk->get_index_element_type());
new_ops.push_back(index_output.get_node_shared_ptr());
}
topk_ie->set_friendly_name(topk->get_friendly_name());
ngraph::copy_runtime_info(topk, new_ops);
topk->output(0).replace(topk_ie->output(0));
topk->output(1).replace(index_output);
return true;
};

View File

@ -13,18 +13,10 @@
#include <ngraph/rt_info.hpp>
void ngraph::pass::ConvertTopK3::convert_topk3() {
auto input = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
auto k = ngraph::opset3::Constant::create(element::i64, Shape{}, {10});
auto topk = std::make_shared<ngraph::opset3::TopK>(input, k, 0, "min", "value", element::i64);
// this is a temporary workaround to avoid bug that TopK-3 does not have clone_with_new_inputs so the TopK-3 clone
// generates TopK-1 operation
auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(input, k, 0, "min", "value", element::i64);
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset3::TopK>());
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
std::shared_ptr<ngraph::op::v1::TopK> topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
if (!topk) {
topk = std::dynamic_pointer_cast<ngraph::opset1::TopK> (m.get_match_root());
}
auto topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
if (!topk) {
return false;
}
@ -51,6 +43,4 @@ void ngraph::pass::ConvertTopK3::convert_topk3() {
auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopK3");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
auto m2 = std::make_shared<ngraph::pattern::Matcher>(topk_v1, "ConvertTopK3");
this->add_matcher(m2, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}

View File

@ -0,0 +1,164 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph_ops/topk_ie.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include "ngraph_test_utils.hpp"
using namespace testing;
// Fully static input shape, scalar constant K.
// The transformed graph is compared against a TopKIE node that consumes a 1D constant K.
TEST(TransformationTests, ConvertTopKToTopKIEStatic) {
    std::shared_ptr<ngraph::Function> actual;
    std::shared_ptr<ngraph::Function> expected;

    // Graph under test: opset1::TopK followed by the passes being checked.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
        const auto scalar_k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, scalar_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data};
        actual = std::make_shared<ngraph::Function>(results, parameters);

        ngraph::pass::InitNodeInfo().run_on_function(actual);
        ngraph::pass::ConvertTopKToTopKIE().run_on_function(actual);
        ASSERT_NO_THROW(check_rt_info(actual));
        ngraph::pass::ConstantFolding().run_on_function(actual);
    }

    // Reference graph: TopKIE fed directly by a constant K of shape {1}.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
        const auto k_1d = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
        const auto topk_ie = std::make_shared<ngraph::op::TopKIE>(data, k_1d, 1, ngraph::op::TopKMode::MIN,
                                                                  ngraph::op::TopKSortType::SORT_VALUES);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_ie->output(0)};
        const ngraph::ParameterVector parameters{data};
        expected = std::make_shared<ngraph::Function>(results, parameters);
    }

    const auto verdict = compare_functions(actual, expected);
    ASSERT_TRUE(verdict.first) << verdict.second;
}
// Input with a dynamic first dimension ({DYN, 20, 3}), scalar constant K, TopK along axis 1.
// The transformed graph is compared against a TopKIE node that consumes a 1D constant K.
TEST(TransformationTests, ConvertTopKToTopKIEDynamic1) {
    std::shared_ptr<ngraph::Function> actual;
    std::shared_ptr<ngraph::Function> expected;

    // Graph under test: opset1::TopK followed by the passes being checked.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 20, 3});
        const auto scalar_k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, scalar_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data};
        actual = std::make_shared<ngraph::Function>(results, parameters);

        ngraph::pass::InitNodeInfo().run_on_function(actual);
        ngraph::pass::ConvertTopKToTopKIE().run_on_function(actual);
        ASSERT_NO_THROW(check_rt_info(actual));
        ngraph::pass::ConstantFolding().run_on_function(actual);
    }

    // Reference graph: TopKIE fed directly by a constant K of shape {1}.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 20, 3});
        const auto k_1d = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
        const auto topk_ie = std::make_shared<ngraph::op::TopKIE>(data, k_1d, 1, ngraph::op::TopKMode::MIN,
                                                                  ngraph::op::TopKSortType::SORT_VALUES);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_ie->output(0)};
        const ngraph::ParameterVector parameters{data};
        expected = std::make_shared<ngraph::Function>(results, parameters);
    }

    const auto verdict = compare_functions(actual, expected);
    ASSERT_TRUE(verdict.first) << verdict.second;
}
// Input whose TopK axis (axis 1) is the dynamic dimension ({1, DYN, 3}), scalar constant K.
// The transformed graph is compared against a TopKIE node that consumes a 1D constant K.
TEST(TransformationTests, ConvertTopKToTopKIEDynamic2) {
    std::shared_ptr<ngraph::Function> actual;
    std::shared_ptr<ngraph::Function> expected;

    // Graph under test: opset1::TopK followed by the passes being checked.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, DYN, 3});
        const auto scalar_k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, scalar_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data};
        actual = std::make_shared<ngraph::Function>(results, parameters);

        ngraph::pass::InitNodeInfo().run_on_function(actual);
        ngraph::pass::ConvertTopKToTopKIE().run_on_function(actual);
        ASSERT_NO_THROW(check_rt_info(actual));
        ngraph::pass::ConstantFolding().run_on_function(actual);
    }

    // Reference graph: TopKIE fed directly by a constant K of shape {1}.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, DYN, 3});
        const auto k_1d = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
        const auto topk_ie = std::make_shared<ngraph::op::TopKIE>(data, k_1d, 1, ngraph::op::TopKMode::MIN,
                                                                  ngraph::op::TopKSortType::SORT_VALUES);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_ie->output(0)};
        const ngraph::ParameterVector parameters{data};
        expected = std::make_shared<ngraph::Function>(results, parameters);
    }

    const auto verdict = compare_functions(actual, expected);
    ASSERT_TRUE(verdict.first) << verdict.second;
}
// Input with a dynamic last dimension ({1, 20, DYN}), scalar constant K, TopK along axis 1.
// The transformed graph is compared against a TopKIE node that consumes a 1D constant K.
TEST(TransformationTests, ConvertTopKToTopKIEDynamic3) {
    std::shared_ptr<ngraph::Function> actual;
    std::shared_ptr<ngraph::Function> expected;

    // Graph under test: opset1::TopK followed by the passes being checked.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 20, DYN});
        const auto scalar_k = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{}, {10});
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, scalar_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data};
        actual = std::make_shared<ngraph::Function>(results, parameters);

        ngraph::pass::InitNodeInfo().run_on_function(actual);
        ngraph::pass::ConvertTopKToTopKIE().run_on_function(actual);
        ASSERT_NO_THROW(check_rt_info(actual));
        ngraph::pass::ConstantFolding().run_on_function(actual);
    }

    // Reference graph: TopKIE fed directly by a constant K of shape {1}.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 20, DYN});
        const auto k_1d = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {10});
        const auto topk_ie = std::make_shared<ngraph::op::TopKIE>(data, k_1d, 1, ngraph::op::TopKMode::MIN,
                                                                  ngraph::op::TopKSortType::SORT_VALUES);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_ie->output(0)};
        const ngraph::ParameterVector parameters{data};
        expected = std::make_shared<ngraph::Function>(results, parameters);
    }

    const auto verdict = compare_functions(actual, expected);
    ASSERT_TRUE(verdict.first) << verdict.second;
}
// Negative case: K is a Parameter with a fully dynamic partial shape.
// The reference graph is the same opset1::TopK, so the test passes only if the graph
// after the passes still compares equal to the untransformed one.
TEST(TransformationTests, ConvertTopKToTopKIENegative) {
    std::shared_ptr<ngraph::Function> actual;
    std::shared_ptr<ngraph::Function> expected;

    // Graph under test: opset1::TopK with a dynamic-shaped K, then the passes being checked.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
        const auto dynamic_k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, dynamic_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data, dynamic_k};
        actual = std::make_shared<ngraph::Function>(results, parameters);

        ngraph::pass::InitNodeInfo().run_on_function(actual);
        ngraph::pass::ConvertTopKToTopKIE().run_on_function(actual);
        ASSERT_NO_THROW(check_rt_info(actual));
        ngraph::pass::ConstantFolding().run_on_function(actual);
    }

    // Reference graph: identical opset1::TopK, built without running any pass.
    {
        const auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{15, 20, 3});
        const auto dynamic_k = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
        const auto topk_v1 = std::make_shared<ngraph::opset1::TopK>(data, dynamic_k, 1, "min", "value", ngraph::element::i32);

        // 'compare_functions' limitation: only one output is checked
        const ngraph::OutputVector results{topk_v1->output(0)};
        const ngraph::ParameterVector parameters{data, dynamic_k};
        expected = std::make_shared<ngraph::Function>(results, parameters);
    }

    const auto verdict = compare_functions(actual, expected);
    ASSERT_TRUE(verdict.first) << verdict.second;
}