Fixed postprocessor for top K (#8780)

* Fixed postprocessor for top K

* Changed fix
This commit is contained in:
Ilya Churaev
2021-11-25 07:34:31 +03:00
committed by GitHub
parent 295c65e1fb
commit 7fcdff592c
2 changed files with 75 additions and 1 deletions

View File

@@ -4,6 +4,8 @@
#include <string>
#include "ngraph_reader_tests.hpp"
#include "openvino/runtime/core.hpp"
TEST_F(NGraphReaderTests, DISABLED_ReadTopKNetwork) {
std::string model = R"V0G0N(
<net name="Network" version="10">
@@ -248,3 +250,75 @@ TEST_F(NGraphReaderTests, DISABLED_ReadTopKNetwork) {
data[0] = 5;
});
}
TEST_F(NGraphReaderTests, ReadTopKV10Network) {
std::string model = R"V0G0N(
<net name="Network" version="10">
<layers>
<layer name="in1" type="Parameter" id="0" version="opset1">
<data element_type="f32" shape="1,3,22,22"/>
<output>
<port id="0" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer id="4" name="1345813459_const" type="Const" version="opset1">
<data element_type="i64" offset="0" shape="" size="8"/>
<output>
<port id="1" precision="I64" />
</output>
</layer>
<layer name="topk" id="1" type="TopK" version="opset3">
<data axis="2" index_element_type="i64" mode="max" sort="value"/>
<input>
<port id="1">
<dim>1</dim>
<dim>3</dim>
<dim>22</dim>
<dim>22</dim>
</port>
<port id="2"/>
</input>
<output>
<port id="3" precision="FP32">
<dim>1</dim>
<dim>3</dim>
<dim>5</dim>
<dim>22</dim>
</port>
<port id="4" precision="I64">
<dim>1</dim>
<dim>3</dim>
<dim>5</dim>
<dim>22</dim>
</port>
</output>
</layer>
<layer id="2" name="res" type="Result" version="opset1">
<input>
<port id="0" precision="I64">
<dim>1</dim>
<dim>3</dim>
<dim>5</dim>
<dim>22</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
<edge from-layer="4" from-port="1" to-layer="1" to-port="2"/>
<edge from-layer="1" from-port="4" to-layer="2" to-port="0"/>
</edges>
</net>
)V0G0N";
ov::runtime::Core core;
ov::runtime::Tensor t(ov::element::i64, {1});
t.data<int64_t>()[0] = 5;
core.read_model(model, t);
}

View File

@@ -385,7 +385,7 @@ void PostStepsList::add_convert_impl(const element::Type& type) {
if (t == element::Type{}) {
t = ctxt.target_element_type();
}
if (t == node.get_node()->get_element_type()) {
if (t == node.get_element_type()) {
return std::make_tuple(node, false);
}
OPENVINO_ASSERT(