[ONNX] Add type conversion for Pow op inputs (#2589)
Co-authored-by: mitruska <katarzyna.mitrus@intel.com>
This commit is contained in:
parent
c2394508c1
commit
8002b16eb2
@ -16,11 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
#include "onnx_import/default_opset.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -30,11 +26,7 @@ namespace ngraph
|
||||
{
|
||||
namespace set_1
|
||||
{
|
||||
inline OutputVector pow(const Node& node)
|
||||
{
|
||||
return {std::make_shared<default_opset::Power>(node.get_ng_inputs().at(0),
|
||||
node.get_ng_inputs().at(1))};
|
||||
}
|
||||
OutputVector pow(const Node& node);
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
|
67
ngraph/frontend/onnx_import/src/op/pow.cpp
Normal file
67
ngraph/frontend/onnx_import/src/op/pow.cpp
Normal file
@ -0,0 +1,67 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/default_opset.hpp"
|
||||
#include "onnx_import/op/pow.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace set_1
|
||||
{
|
||||
OutputVector pow(const Node& node)
|
||||
{
|
||||
auto inputs = node.get_ng_inputs();
|
||||
NGRAPH_CHECK(inputs.size() == 2,
|
||||
"Power operation requires 2 inputs. Got: ",
|
||||
inputs.size());
|
||||
|
||||
auto base = inputs[0];
|
||||
auto exponent = inputs[1];
|
||||
auto base_type = inputs[0].get_element_type();
|
||||
auto exponent_type = inputs[1].get_element_type();
|
||||
if (exponent_type != base_type)
|
||||
{
|
||||
if (exponent_type.is_integral() ||
|
||||
(base_type.is_real() &&
|
||||
base_type.bitwidth() >= exponent_type.bitwidth()))
|
||||
{
|
||||
exponent =
|
||||
std::make_shared<default_opset::Convert>(exponent, base_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
base = std::make_shared<default_opset::Convert>(base, exponent_type);
|
||||
auto power = std::make_shared<default_opset::Power>(base, exponent);
|
||||
return {std::make_shared<default_opset::Convert>(power, base_type)};
|
||||
}
|
||||
}
|
||||
return {std::make_shared<default_opset::Power>(base, exponent)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
} // namespace op
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
@ -154,10 +154,6 @@ xfail_issue_38715 = xfail_test(reason="RuntimeError: While validating ONNX node
|
||||
xfail_issue_38717 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||
"GreaterOrEqual")
|
||||
xfail_issue_38719 = xfail_test(reason="nGraph does not support the following ONNX operations: GatherND")
|
||||
xfail_issue_38721 = xfail_test(reason="RuntimeError: While validating ONNX node '<Node(Pow): z>': "
|
||||
"While validating node 'v1::Power Power_<number>"
|
||||
"(x[0]:f32{3}, y[0]:i64{3}) -> (dynamic?)' with friendly_name "
|
||||
"'Power_<number>': Argument element types are inconsistent.")
|
||||
xfail_issue_38722 = xfail_test(reason="RuntimeError: While validating ONNX nodes MatMulInteger"
|
||||
"and QLinearMatMul"
|
||||
"Input0 scale and input0 zero point shape must be same and 1")
|
||||
|
@ -68,7 +68,6 @@ from tests import (BACKEND_NAME,
|
||||
xfail_issue_33589,
|
||||
xfail_issue_38719,
|
||||
xfail_issue_33535,
|
||||
xfail_issue_38721,
|
||||
xfail_issue_38722,
|
||||
xfail_issue_38723,
|
||||
xfail_issue_38724,
|
||||
@ -189,7 +188,11 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendPyTorchConvertedModelTest.test_Embedding_sparse_cpu",
|
||||
"OnnxBackendNodeModelTest.test_constantofshape_int_shape_zero_cpu",
|
||||
"OnnxBackendNodeModelTest.test_max_int64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int64_int64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int_cpu",
|
||||
"OnnxBackendNodeModelTest.test_min_int64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
|
||||
"OnnxBackendNodeModelTest.test_scatternd_cpu"),
|
||||
@ -248,7 +251,8 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_min_uint32_cpu"),
|
||||
(xfail_issue_36478,
|
||||
"OnnxBackendNodeModelTest.test_max_uint64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_min_uint64_cpu"),
|
||||
"OnnxBackendNodeModelTest.test_min_uint64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu"),
|
||||
(xfail_issue_36437,
|
||||
"OnnxBackendNodeModelTest.test_argmax_default_axis_example_cpu",
|
||||
"OnnxBackendNodeModelTest.test_argmax_default_axis_random_cpu",
|
||||
@ -273,7 +277,8 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
|
||||
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
|
||||
"OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
|
||||
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu"),
|
||||
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu"),
|
||||
(xfail_issue_38088,
|
||||
"OnnxBackendPyTorchConvertedModelTest.test_GLU_cpu"),
|
||||
(xfail_issue_38089,
|
||||
@ -598,15 +603,6 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_min_adjusted_cpu",
|
||||
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_cpu",
|
||||
"OnnxBackendNodeModelTest.test_dynamicquantizelinear_max_adjusted_cpu"),
|
||||
(xfail_issue_38721,
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int64_float32_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_int32_float32_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_uint64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_uint32_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_int64_cpu",
|
||||
"OnnxBackendNodeModelTest.test_pow_types_float32_int32_cpu"),
|
||||
(xfail_issue_38722,
|
||||
"OnnxBackendNodeModelTest.test_matmulinteger_cpu",
|
||||
"OnnxBackendNodeModelTest.test_qlinearmatmul_2D_cpu",
|
||||
|
60
ngraph/test/models/onnx/pow_float32_float32.prototxt
Normal file
60
ngraph/test/models/onnx/pow_float32_float32.prototxt
Normal file
@ -0,0 +1,60 @@
|
||||
ir_version: 7
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
input: "N"
|
||||
output: "Y"
|
||||
op_type: "Pow"
|
||||
}
|
||||
name: "test-model-lstm"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "N"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 12
|
||||
}
|
59
ngraph/test/models/onnx/pow_float32_int32.prototxt
Normal file
59
ngraph/test/models/onnx/pow_float32_int32.prototxt
Normal file
@ -0,0 +1,59 @@
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
input: "N"
|
||||
output: "Y"
|
||||
op_type: "Pow"
|
||||
}
|
||||
name: "test-model-lstm"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "N"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 12
|
||||
}
|
60
ngraph/test/models/onnx/pow_int32_float32.prototxt
Normal file
60
ngraph/test/models/onnx/pow_int32_float32.prototxt
Normal file
@ -0,0 +1,60 @@
|
||||
ir_version: 7
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
input: "N"
|
||||
output: "Y"
|
||||
op_type: "Pow"
|
||||
}
|
||||
name: "test-model-lstm"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "N"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 12
|
||||
}
|
@ -2300,6 +2300,48 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_constant)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_float32)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_float32.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
|
||||
test_case.add_input<float>({3.5f}); // exponent
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 4}, {1.f, 11.313708f, 46.765373f, 128.f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_int32)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/pow_float32_int32.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f}); // base
|
||||
test_case.add_input<int>({3}); // exponent
|
||||
|
||||
test_case.add_expected_output<float>(Shape{1, 4}, {1.f, 8.f, 27.f, 64.f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_int32_float32)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/pow_int32_float32.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<int>({1, 2, 3, 4}); // base
|
||||
test_case.add_input<float>({3.5f}); // exponent
|
||||
|
||||
test_case.add_expected_output<int>(Shape{1, 4}, {1, 11, 46, 128});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reciprocal)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
|
Loading…
Reference in New Issue
Block a user