[ONNX] Add type conversion for Pow op inputs (#2589)

Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
This commit is contained in:
Mateusz Tabaka 2020-10-20 11:19:03 +02:00 committed by GitHub
parent c2394508c1
commit 8002b16eb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 297 additions and 25 deletions

View File

@ -16,11 +16,7 @@
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"
#include "onnx_import/default_opset.hpp"
namespace ngraph
{
@ -30,11 +26,7 @@ namespace ngraph
{
namespace set_1
{
// NOTE(review): diff residue — the inline definition below is the REMOVED
// pre-refactor version; it was replaced by the out-of-line declaration that
// follows (implementation moved to pow.cpp to add type-conversion logic).
inline OutputVector pow(const Node& node)
{
    // Old behavior: forwarded both inputs to Power unchanged, which failed
    // validation when base and exponent element types differed.
    return {std::make_shared<default_opset::Power>(node.get_ng_inputs().at(0),
    node.get_ng_inputs().at(1))};
}
/// \brief Converts an ONNX Pow node, inserting element-type conversions
///        so that both Power operands agree (output keeps the base's type).
OutputVector pow(const Node& node);
} // namespace set_1

View File

@ -0,0 +1,67 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/node.hpp"
#include "onnx_import/default_opset.hpp"
#include "onnx_import/op/pow.hpp"
namespace ngraph
{
    namespace onnx_import
    {
        namespace op
        {
            namespace set_1
            {
                /// \brief Converts an ONNX Pow node into an nGraph Power op.
                ///
                /// ONNX Pow (opset 12+) permits base and exponent of different
                /// element types, while v1::Power requires matching types.
                /// Convert nodes are inserted so that the computation loses as
                /// little precision as possible; the result always keeps the
                /// base's element type.
                OutputVector pow(const Node& node)
                {
                    const auto args = node.get_ng_inputs();
                    NGRAPH_CHECK(args.size() == 2,
                                 "Power operation requires 2 inputs. Got: ",
                                 args.size());

                    auto base = args[0];
                    auto exponent = args[1];
                    const auto base_type = base.get_element_type();
                    const auto exponent_type = exponent.get_element_type();

                    // Matching types: no conversion required.
                    if (exponent_type == base_type)
                    {
                        return {std::make_shared<default_opset::Power>(base, exponent)};
                    }

                    // Casting the exponent to the base type is safe when the
                    // exponent is integral, or when the base is a float at
                    // least as wide as the exponent.
                    const bool cast_exponent =
                        exponent_type.is_integral() ||
                        (base_type.is_real() &&
                         base_type.bitwidth() >= exponent_type.bitwidth());

                    if (cast_exponent)
                    {
                        exponent =
                            std::make_shared<default_opset::Convert>(exponent, base_type);
                        return {std::make_shared<default_opset::Power>(base, exponent)};
                    }

                    // Otherwise compute in the exponent's (wider/real) type and
                    // cast the result back so the output keeps the base's type.
                    base = std::make_shared<default_opset::Convert>(base, exponent_type);
                    const auto power =
                        std::make_shared<default_opset::Power>(base, exponent);
                    return {std::make_shared<default_opset::Convert>(power, base_type)};
                }
            } // namespace set_1
        }     // namespace op
    }         // namespace onnx_import
} // namespace ngraph

View File

@ -154,10 +154,6 @@ xfail_issue_38715 = xfail_test(reason="RuntimeError: While validating ONNX node
xfail_issue_38717 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
"GreaterOrEqual")
xfail_issue_38719 = xfail_test(reason="nGraph does not support the following ONNX operations: GatherND")
xfail_issue_38721 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Pow): z>': "
"While validating node 'v1::Power Power_<number>"
"(x[0]:f32{3}, y[0]:i64{3}) -> (dynamic?)' with friendly_name "
"'Power_<number>': Argument element types are inconsistent.")
xfail_issue_38722 = xfail_test(reason="RuntimeError: While validating ONNX nodes MatMulInteger"
"and QLinearMatMul"
"Input0 scale and input0 zero point shape must be same and 1")

View File

@ -68,7 +68,6 @@ from tests import (BACKEND_NAME,
xfail_issue_33589,
xfail_issue_38719,
xfail_issue_33535,
xfail_issue_38721,
xfail_issue_38722,
xfail_issue_38723,
xfail_issue_38724,
@ -189,7 +188,11 @@ tests_expected_to_fail = [
"OnnxBackendPyTorchConvertedModelTest.test_Embedding_sparse_cpu",
"OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
"OnnxBackendNodeModelTest.test_max_int64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float_cpu",
"OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
"OnnxBackendNodeModelTest.test_pow_types_int64_int64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_int_cpu",
"OnnxBackendNodeModelTest.test_min_int64_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
"OnnxBackendNodeModelTest.test_scatternd_cpu"),
@ -248,7 +251,8 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_min_uint32_cpu"),
(xfail_issue_36478,
"OnnxBackendNodeModelTest.test_max_uint64_cpu",
"OnnxBackendNodeModelTest.test_min_uint64_cpu"),
"OnnxBackendNodeModelTest.test_min_uint64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu"),
(xfail_issue_36437,
"OnnxBackendNodeModelTest.test_argmax_default_axis_example_cpu",
"OnnxBackendNodeModelTest.test_argmax_default_axis_random_cpu",
@ -273,7 +277,8 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu"),
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu"),
(xfail_issue_38088,
"OnnxBackendPyTorchConvertedModelTest.test_GLU_cpu"),
(xfail_issue_38089,
@ -598,15 +603,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_cpu",
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_cpu",
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_max_adjusted_cpu"),
(xfail_issue_38721,
"OnnxBackendNodeModelTest.test_pow_types_int_cpu",
"OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
"OnnxBackendNodeModelTest.test_pow_types_int32_float32_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
"OnnxBackendNodeModelTest.test_pow_types_float32_int32_cpu"),
(xfail_issue_38722,
"OnnxBackendNodeModelTest.test_matmulinteger_cpu",
"OnnxBackendNodeModelTest.test_qlinearmatmul_2D_cpu",

View File

@ -0,0 +1,60 @@
ir_version: 7
producer_name: "onnx-importer-test"
graph {
node {
input: "X"
input: "N"
output: "Y"
op_type: "Pow"
}
name: "test-model-pow"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "N"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
domain: ""
version: 12
}

View File

@ -0,0 +1,59 @@
ir_version: 7
producer_name: "onnx-importer-test"
graph {
node {
input: "X"
input: "N"
output: "Y"
op_type: "Pow"
}
name: "test-model-pow"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "N"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
domain: ""
version: 12
}

View File

@ -0,0 +1,60 @@
ir_version: 7
producer_name: "onnx-importer-test"
graph {
node {
input: "X"
input: "N"
output: "Y"
op_type: "Pow"
}
name: "test-model-pow"
input {
name: "X"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
input {
name: "N"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
domain: ""
version: 12
}

View File

@ -2300,6 +2300,48 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_constant)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_float32)
{
    // Matching f32/f32 types: Pow import should need no Convert nodes.
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_float32.prototxt"));

    auto tc = test::TestCase<TestEngine>(model);
    tc.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
    tc.add_input<float>({3.5f});               // exponent
    tc.add_expected_output<float>(Shape{1, 4}, {1.f, 11.313708f, 46.765373f, 128.f});
    tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_int32)
{
    // Integral exponent with f32 base: the exponent gets converted to f32,
    // and the result stays f32.
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_int32.prototxt"));

    auto tc = test::TestCase<TestEngine>(model);
    tc.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
    tc.add_input<int>({3});                    // exponent
    tc.add_expected_output<float>(Shape{1, 4}, {1.f, 8.f, 27.f, 64.f});
    tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_int32_float32)
{
    // Integral base with real exponent: computed in f32, then converted back
    // to i32 (fractional results truncate, e.g. 2^3.5 -> 11).
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/pow_int32_float32.prototxt"));

    auto tc = test::TestCase<TestEngine>(model);
    tc.add_input<int>({1, 2, 3, 4}); // base
    tc.add_input<float>({3.5f});     // exponent
    tc.add_expected_output<int>(Shape{1, 4}, {1, 11, 46, 128});
    tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reciprocal)
{
const auto function = onnx_import::import_onnx_model(