Add tests for ArgMin/ArgMax with float inputs (#1429)

This commit is contained in:
Mateusz Tabaka 2020-07-27 12:40:27 +02:00 committed by GitHub
parent 7827490340
commit 2ac35247ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 0 deletions

View File

@ -0,0 +1,55 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
output: "reduced"
name: "node1"
op_type: "ArgMax"
attribute {
name: "keepdims"
i: 1
type: INT
}
attribute {
name: "axis"
i: 0
type: INT
}
doc_string: "ArgMax"
domain: ""
}
name: "test"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "reduced"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 7
}

View File

@ -0,0 +1,55 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
output: "reduced"
name: "node1"
op_type: "ArgMin"
attribute {
name: "keepdims"
i: 0
type: INT
}
attribute {
name: "axis"
i: 1
type: INT
}
doc_string: "ArgMin"
domain: ""
}
name: "test"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "reduced"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 7
}

View File

@ -1626,6 +1626,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_int32)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_float)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_float.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
test_case.add_expected_output<std::int64_t>({0, 3, 0});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_float.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
test_case.add_expected_output<std::int64_t>({1, 1, 0, 2});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
{
auto function =

View File

@ -154,6 +154,10 @@ onnx_model_argmax_int32
onnx_model_argmin_int32
arg_max_dyn_shape
# Result mismatch
onnx_model_argmax_float
onnx_model_argmin_float
# Constant has zero dimension that is not allowable
onnx_dyn_shapes_transpose