[ONNX] Add support for ONNX ConstantFill op (#5203)

Mateusz Bencer 2021-04-14 16:11:21 +02:00 committed by GitHub
parent 891cf56255
commit 7ac7215924
11 changed files with 432 additions and 22 deletions

@ -1,12 +1,42 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/type/element_type.hpp"
namespace ONNX_NAMESPACE
{
enum TensorProto_DataType;
}
namespace ngraph
{
namespace onnx_common
{
/// \brief Returns the size of an ONNX data type in bytes.
///
/// \param onnx_type Number assigned to an ONNX data type in the TensorProto_DataType enum.
///
size_t get_onnx_data_size(int32_t onnx_type);
/// \brief Returns the nGraph data type corresponding to an ONNX type.
///
/// \param onnx_type An element of the TensorProto_DataType enum which determines an ONNX type.
///
element::Type_t onnx_to_ng_data_type(const ONNX_NAMESPACE::TensorProto_DataType& onnx_type);
/// \brief Returns the ONNX data type corresponding to an nGraph data type.
///
/// \param ng_type An element of the element::Type_t enum class which determines an nGraph data
/// type.
///
ONNX_NAMESPACE::TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type);
/// \brief Returns true if an nGraph data type is mapped to an ONNX data type.
///
/// \param ng_type An element of the element::Type_t enum class which determines an nGraph data
/// type.
///
bool is_supported_ng_type(const element::Type_t& ng_type);
} // namespace onnx_common
} // namespace ngraph
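
Editor's note: a minimal usage sketch of the new ngraph::onnx_common helpers (not part of the patch; the include path is taken from the importer sources below, and the wrapping function is illustrative). Centralizing the lookup table in onnx_common gives the ONNX editor and the importer a single source of truth for type conversions.

// Sketch only -- assumes the onnx_common/utils.hpp header added by this commit.
#include <onnx/onnx_pb.h>
#include "ngraph/type/element_type.hpp"
#include "onnx_common/utils.hpp"

void type_mapping_example()
{
    using namespace ngraph;
    // ONNX -> nGraph: TensorProto_DataType_FLOAT maps to element::Type_t::f32
    const auto ng_type =
        onnx_common::onnx_to_ng_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    // nGraph -> ONNX: guard with is_supported_ng_type() first, since types such as
    // element::Type_t::u1 have no ONNX counterpart
    if (onnx_common::is_supported_ng_type(ng_type))
    {
        const auto onnx_type = onnx_common::ng_to_onnx_data_type(ng_type);
        (void)onnx_type; // TensorProto_DataType_FLOAT
    }
}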

@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <onnx/onnx_pb.h>
#include "ngraph/except.hpp"
@ -39,5 +40,53 @@ namespace ngraph
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
#endif
}
namespace
{
using namespace ONNX_NAMESPACE;
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
{element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL}};
}
element::Type_t onnx_to_ng_data_type(const TensorProto_DataType& onnx_type)
{
const auto result = std::find_if(
NG_2_ONNX_TYPES.begin(),
NG_2_ONNX_TYPES.end(),
[&onnx_type](
const std::pair<element::Type_t, ONNX_NAMESPACE::TensorProto_DataType>& pair) {
return pair.second == onnx_type;
});
if (result == std::end(NG_2_ONNX_TYPES))
{
throw ngraph_error(
"unsupported element type: " +
ONNX_NAMESPACE::TensorProto_DataType_Name(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
}
return result->first;
}
TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type)
{
return NG_2_ONNX_TYPES.at(ng_type);
}
bool is_supported_ng_type(const element::Type_t& ng_type)
{
return NG_2_ONNX_TYPES.count(ng_type) > 0;
}
} // namespace onnx_common
} // namespace ngraph

@ -19,21 +19,6 @@ namespace
{
using namespace ONNX_NAMESPACE;
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
};
ValueInfoProto* find_graph_input(GraphProto& graph, const std::string& name)
{
for (int i = 0; i < graph.input_size(); ++i)
@ -80,16 +65,17 @@ namespace
}
auto* tensor_type = type_proto->mutable_tensor_type();
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
if (onnx_common::is_supported_ng_type(elem_type))
{
tensor_type->set_elem_type(onnx_common::ng_to_onnx_data_type(elem_type));
}
else
{
throw ngraph_error("The input type for input '" + onnx_input.name() +
"' cannot be set to: " + element::Type(elem_type).get_type_name() +
". This type is not allowed in ONNX.");
}
else
{
tensor_type->set_elem_type(NG_2_ONNX_TYPES.at(elem_type));
}
}
void add_dim_to_onnx_shape(const Dimension& dim, ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
@ -160,7 +146,7 @@ namespace
ValueInfoProto* input)
{
const auto elem_type = values->get_element_type();
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
if (!onnx_common::is_supported_ng_type(elem_type))
{
throw ngraph_error("Initializer '" + name + "' type cannot be set to: " +
element::Type(elem_type).get_type_name() +
@ -170,7 +156,7 @@ namespace
initializer.Clear();
initializer.set_name(name);
initializer.set_data_type(NG_2_ONNX_TYPES.at(values->get_element_type()));
initializer.set_data_type(onnx_common::ng_to_onnx_data_type(values->get_element_type()));
for (const auto& dim : values->get_shape())
{

@ -0,0 +1,66 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <onnx/onnx_pb.h> // onnx types
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "onnx_common/utils.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
OutputVector constant_fill(const Node& node)
{
Output<ngraph::Node> target_shape;
const auto fill_value = node.get_attribute_value<float>("value", 0.f);
const auto dtype = node.get_attribute_value<int64_t>(
"dtype", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
const auto ng_type = onnx_common::onnx_to_ng_data_type(
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(dtype));
const auto const_val_to_fill =
default_opset::Constant::create(ng_type, {}, {fill_value});
const auto input_as_shape =
node.get_attribute_value<int64_t>("input_as_shape", 1);
if (input_as_shape == 1) // use the first input as target shape
{
CHECK_VALID_NODE(
node,
node.get_ng_inputs().size() > 0,
"The input which determines output shape was not provided");
target_shape = node.get_ng_inputs().at(0);
if (node.has_attribute("extra_shape"))
{
const auto extra_shape =
node.get_attribute_value<std::vector<int64_t>>("extra_shape");
const auto extra_shape_const = default_opset::Constant::create(
target_shape.get_element_type(), {extra_shape.size()}, extra_shape);
target_shape = std::make_shared<default_opset::Concat>(
OutputVector{target_shape, extra_shape_const}, 0);
}
}
else // use shape attribute as target shape
{
const auto shape = node.get_attribute_value<std::vector<int64_t>>("shape");
target_shape =
default_opset::Constant::create(ng_type, {shape.size()}, shape);
}
return {std::make_shared<default_opset::Broadcast>(const_val_to_fill,
target_shape)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
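
Editor's note: the importer above lowers ConstantFill to a Constant fill value, an optional Concat that appends the extra_shape values to the target shape, and a final Broadcast. A standalone sketch of the equivalent subgraph, assuming the public opset7 API and illustrative shape/value inputs (default_opset inside the importer may resolve to a different opset):

// Sketch only -- equivalent construction for input_as_shape=1 with target shape {3, 1, 2},
// extra_shape=[2, 1] and the default fill value (dtype=FLOAT, value=0).
#include <memory>
#include "ngraph/opsets/opset7.hpp"

std::shared_ptr<ngraph::Node> constant_fill_equivalent()
{
    using namespace ngraph;
    const auto fill_value = opset7::Constant::create(element::f32, Shape{}, {0.f});
    const auto base_shape = opset7::Constant::create(element::i64, Shape{3}, {3, 1, 2});
    const auto extra_shape = opset7::Constant::create(element::i64, Shape{2}, {2, 1});
    // extra_shape values are appended to the shape provided by the first input
    const auto full_shape =
        std::make_shared<opset7::Concat>(OutputVector{base_shape, extra_shape}, 0);
    // result: a {3, 1, 2, 2, 1} tensor filled with 0.f
    return std::make_shared<opset7::Broadcast>(fill_value, full_shape);
}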

@ -0,0 +1,28 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
// ConstantFill is a deprecated experimental operator removed in ONNX 1.4
OutputVector constant_fill(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

@ -29,6 +29,7 @@
#include "op/clip.hpp"
#include "op/concat.hpp"
#include "op/constant.hpp"
#include "op/constant_fill.hpp"
#include "op/constant_of_shape.hpp"
#include "op/conv.hpp"
// #include "op/conv_integer.hpp"
@ -327,6 +328,7 @@ namespace ngraph
REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose);
REGISTER_OPERATOR("Cos", 1, cos);
REGISTER_OPERATOR("Cosh", 1, cosh);
REGISTER_OPERATOR("ConstantFill", 1, constant_fill);
REGISTER_OPERATOR("CumSum", 1, cum_sum);
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);

@ -0,0 +1,59 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "target_shape"
output: "output"
op_type: "ConstantFill"
attribute {
name: "input_as_shape"
i: 1
type: INT
}
attribute {
name: "value"
i: 3
type: INT
}
attribute {
name: "extra_shape"
ints: 2
ints: 1
type: INTS
}
}
input {
name: "target_shape"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
}
}
}
}
initializer {
dims: 3
data_type: 7
int64_data: 3
int64_data: 1
int64_data: 2
name: "target_shape"
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}
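
Editor's note: the numeric elem_type and data_type codes in this and the following test models are values of the ONNX TensorProto_DataType enum. A small sanity-check snippet (values as defined in the ONNX protobuf schema; not part of the patch):

// Codes used by the ConstantFill test models in this commit.
#include <onnx/onnx_pb.h>

static_assert(ONNX_NAMESPACE::TensorProto_DataType_FLOAT == 1, "float tensors");
static_assert(ONNX_NAMESPACE::TensorProto_DataType_UINT8 == 2, "u8 tensors / dtype attribute");
static_assert(ONNX_NAMESPACE::TensorProto_DataType_INT32 == 6, "i32 dtype attribute");
static_assert(ONNX_NAMESPACE::TensorProto_DataType_INT64 == 7, "i64 shape initializers");
static_assert(ONNX_NAMESPACE::TensorProto_DataType_BOOL == 9, "declared output type in one model");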

@ -0,0 +1,48 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "target_shape"
output: "output"
op_type: "ConstantFill"
attribute {
name: "input_as_shape"
i: 1
type: INT
}
}
input {
name: "target_shape"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
initializer {
dims: 3
data_type: 7
int64_data: 1
int64_data: 2
int64_data: 3
name: "target_shape"
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

@ -0,0 +1,58 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "target_shape"
output: "output"
op_type: "ConstantFill"
attribute {
name: "input_as_shape"
i: 1
type: INT
}
attribute {
name: "dtype"
i: 2
type: INT
}
attribute {
name: "value"
i: 3
type: INT
}
}
input {
name: "target_shape"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 3
}
}
}
}
}
initializer {
dims: 3
data_type: 7
int64_data: 3
int64_data: 1
int64_data: 2
name: "target_shape"
}
output {
name: "output"
type {
tensor_type {
elem_type: 9
shape {
}
}
}
}
}
opset_import {
version: 13
}

@ -0,0 +1,44 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "target_shape"
output: "output"
op_type: "ConstantFill"
attribute {
name: "input_as_shape"
i: 0
type: INT
}
attribute {
name: "dtype"
i: 6
type: INT
}
attribute {
name: "value"
i: 5
type: INT
}
attribute {
name: "shape"
ints: 2
ints: 3
ints: 4
type: INTS
}
}
output {
name: "output"
type {
tensor_type {
elem_type: 6
shape {
}
}
}
}
}
opset_import {
version: 13
}

@ -4258,3 +4258,43 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss)
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_default_value)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_default_value.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<float>(Shape{1, 2, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_u8_type)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_u8_type.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<uint8_t>(Shape{3, 1, 2}, {3, 3, 3, 3, 3, 3});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_extra_shape)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_extra_shape.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<float>(Shape{3, 1, 2, 2, 1}, std::vector<float>(12, 3.0f));
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_shape_attribute)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_shape_attribute.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_expected_output<int32_t>(Shape{2, 3, 4}, std::vector<int32_t>(24, 5));
test_case.run();
}