Add tests for ArgMin/ArgMax with float inputs (#1429)
This commit is contained in:
parent
7827490340
commit
2ac35247ea
55
ngraph/test/models/onnx/argmax_float.prototxt
Normal file
55
ngraph/test/models/onnx/argmax_float.prototxt
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "reduced"
|
||||||
|
name: "node1"
|
||||||
|
op_type: "ArgMax"
|
||||||
|
attribute {
|
||||||
|
name: "keepdims"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "axis"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
doc_string: "ArgMax"
|
||||||
|
domain: ""
|
||||||
|
}
|
||||||
|
name: "test"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "reduced"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 7
|
||||||
|
}
|
55
ngraph/test/models/onnx/argmin_float.prototxt
Normal file
55
ngraph/test/models/onnx/argmin_float.prototxt
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "reduced"
|
||||||
|
name: "node1"
|
||||||
|
op_type: "ArgMin"
|
||||||
|
attribute {
|
||||||
|
name: "keepdims"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "axis"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
doc_string: "ArgMin"
|
||||||
|
domain: ""
|
||||||
|
}
|
||||||
|
name: "test"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "reduced"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 7
|
||||||
|
}
|
@ -1626,6 +1626,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_int32)
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_float)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_float.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
|
||||||
|
test_case.add_expected_output<std::int64_t>({0, 3, 0});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_float.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({4, 0.1, 2, 3, -3, 1, -0.9, 0, 1, 2, 3, 0});
|
||||||
|
test_case.add_expected_output<std::int64_t>({1, 1, 0, 2});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
|
||||||
{
|
{
|
||||||
auto function =
|
auto function =
|
||||||
|
@ -154,6 +154,10 @@ onnx_model_argmax_int32
|
|||||||
onnx_model_argmin_int32
|
onnx_model_argmin_int32
|
||||||
arg_max_dyn_shape
|
arg_max_dyn_shape
|
||||||
|
|
||||||
|
# Result mismatch
|
||||||
|
onnx_model_argmax_float
|
||||||
|
onnx_model_argmin_float
|
||||||
|
|
||||||
# Constant has zero dimension that is not allowable
|
# Constant has zero dimension that is not allowable
|
||||||
onnx_dyn_shapes_transpose
|
onnx_dyn_shapes_transpose
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user