Add tests for ArgMin/ArgMax with float inputs (#1429)
This commit is contained in:
parent
7827490340
commit
2ac35247ea
55
ngraph/test/models/onnx/argmax_float.prototxt
Normal file
55
ngraph/test/models/onnx/argmax_float.prototxt
Normal file
@ -0,0 +1,55 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
output: "reduced"
|
||||
name: "node1"
|
||||
op_type: "ArgMax"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
doc_string: "ArgMax"
|
||||
domain: ""
|
||||
}
|
||||
name: "test"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "reduced"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 7
|
||||
}
|
55
ngraph/test/models/onnx/argmin_float.prototxt
Normal file
55
ngraph/test/models/onnx/argmin_float.prototxt
Normal file
@ -0,0 +1,55 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
output: "reduced"
|
||||
name: "node1"
|
||||
op_type: "ArgMin"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
doc_string: "ArgMin"
|
||||
domain: ""
|
||||
}
|
||||
name: "test"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "reduced"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 7
|
||||
}
|
@ -1626,6 +1626,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_int32)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_float)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_float.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
|
||||
test_case.add_expected_output<std::int64_t>({0, 3, 0});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_float.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
|
||||
test_case.add_expected_output<std::int64_t>({1, 1, 0, 2});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
|
||||
{
|
||||
auto function =
|
||||
|
@ -154,6 +154,10 @@ onnx_model_argmax_int32
|
||||
onnx_model_argmin_int32
|
||||
arg_max_dyn_shape
|
||||
|
||||
# Result mismatch
|
||||
onnx_model_argmax_float
|
||||
onnx_model_argmin_float
|
||||
|
||||
# Constant has zero dimension that is not allowable
|
||||
onnx_dyn_shapes_transpose
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user