Fixed postprocessor for top K (#8780)
* Fixed postprocessor for top K * Changed fix
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
|
||||
#include <string>
|
||||
#include "ngraph_reader_tests.hpp"
|
||||
#include "openvino/runtime/core.hpp"
|
||||
|
||||
TEST_F(NGraphReaderTests, DISABLED_ReadTopKNetwork) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="Network" version="10">
|
||||
@@ -248,3 +250,75 @@ TEST_F(NGraphReaderTests, DISABLED_ReadTopKNetwork) {
|
||||
data[0] = 5;
|
||||
});
|
||||
}
|
||||
|
||||
TEST_F(NGraphReaderTests, ReadTopKV10Network) {
|
||||
std::string model = R"V0G0N(
|
||||
<net name="Network" version="10">
|
||||
<layers>
|
||||
<layer name="in1" type="Parameter" id="0" version="opset1">
|
||||
<data element_type="f32" shape="1,3,22,22"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="4" name="1345813459_const" type="Const" version="opset1">
|
||||
<data element_type="i64" offset="0" shape="" size="8"/>
|
||||
<output>
|
||||
<port id="1" precision="I64" />
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="topk" id="1" type="TopK" version="opset3">
|
||||
<data axis="2" index_element_type="i64" mode="max" sort="value"/>
|
||||
<input>
|
||||
<port id="1">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>22</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
<port id="2"/>
|
||||
</input>
|
||||
<output>
|
||||
<port id="3" precision="FP32">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>5</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
<port id="4" precision="I64">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>5</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="2" name="res" type="Result" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="I64">
|
||||
<dim>1</dim>
|
||||
<dim>3</dim>
|
||||
<dim>5</dim>
|
||||
<dim>22</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||
<edge from-layer="4" from-port="1" to-layer="1" to-port="2"/>
|
||||
<edge from-layer="1" from-port="4" to-layer="2" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
)V0G0N";
|
||||
|
||||
ov::runtime::Core core;
|
||||
ov::runtime::Tensor t(ov::element::i64, {1});
|
||||
t.data<int64_t>()[0] = 5;
|
||||
core.read_model(model, t);
|
||||
}
|
||||
|
||||
@@ -385,7 +385,7 @@ void PostStepsList::add_convert_impl(const element::Type& type) {
|
||||
if (t == element::Type{}) {
|
||||
t = ctxt.target_element_type();
|
||||
}
|
||||
if (t == node.get_node()->get_element_type()) {
|
||||
if (t == node.get_element_type()) {
|
||||
return std::make_tuple(node, false);
|
||||
}
|
||||
OPENVINO_ASSERT(
|
||||
|
||||
Reference in New Issue
Block a user