[ONNX] Add support for ONNX ConstantFill op (#5203)
This commit is contained in:
parent
891cf56255
commit
7ac7215924
@ -1,12 +1,42 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
|
||||
namespace ONNX_NAMESPACE
|
||||
{
|
||||
enum TensorProto_DataType;
|
||||
}
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_common
|
||||
{
|
||||
/// \brief Retuns size of an ONNX data type in bytes.
|
||||
///
|
||||
/// \param onnx_type Number assigned to an ONNX data type in the TensorProto_DataType enum.
|
||||
///
|
||||
size_t get_onnx_data_size(int32_t onnx_type);
|
||||
|
||||
/// \brief Retuns a nGraph data type corresponding to an ONNX type.
|
||||
///
|
||||
/// \param onnx_type An element of TensorProto_DataType enum which determines an ONNX type.
|
||||
///
|
||||
element::Type_t onnx_to_ng_data_type(const ONNX_NAMESPACE::TensorProto_DataType& onnx_type);
|
||||
|
||||
/// \brief Retuns an ONNX data type corresponding to a nGraph data type.
|
||||
///
|
||||
/// \param ng_type An element of element::Type_t enum class which determines a nGraph data
|
||||
/// type.
|
||||
///
|
||||
ONNX_NAMESPACE::TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type);
|
||||
|
||||
/// \brief Retuns true if a nGraph data type is mapped to an ONNX data type.
|
||||
///
|
||||
/// \param ng_type An element of element::Type_t enum class which determines a nGraph data
|
||||
/// type.
|
||||
///
|
||||
bool is_supported_ng_type(const element::Type_t& ng_type);
|
||||
|
||||
} // namespace onnx_editor
|
||||
} // namespace ngraph
|
||||
|
@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <algorithm>
|
||||
#include <onnx/onnx_pb.h>
|
||||
|
||||
#include "ngraph/except.hpp"
|
||||
@ -39,5 +40,53 @@ namespace ngraph
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
|
||||
#endif
|
||||
}
|
||||
namespace
|
||||
{
|
||||
using namespace ONNX_NAMESPACE;
|
||||
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
|
||||
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
|
||||
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
|
||||
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
|
||||
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
|
||||
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
|
||||
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
|
||||
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
|
||||
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
|
||||
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
|
||||
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
|
||||
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
|
||||
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
|
||||
{element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL}};
|
||||
}
|
||||
|
||||
element::Type_t onnx_to_ng_data_type(const TensorProto_DataType& onnx_type)
|
||||
{
|
||||
const auto result = std::find_if(
|
||||
NG_2_ONNX_TYPES.begin(),
|
||||
NG_2_ONNX_TYPES.end(),
|
||||
[&onnx_type](
|
||||
const std::pair<element::Type_t, ONNX_NAMESPACE::TensorProto_DataType>& pair) {
|
||||
return pair.second == onnx_type;
|
||||
});
|
||||
if (result == std::end(NG_2_ONNX_TYPES))
|
||||
{
|
||||
throw ngraph_error(
|
||||
"unsupported element type: " +
|
||||
ONNX_NAMESPACE::TensorProto_DataType_Name(
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
|
||||
}
|
||||
return result->first;
|
||||
}
|
||||
|
||||
TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type)
|
||||
{
|
||||
return NG_2_ONNX_TYPES.at(ng_type);
|
||||
}
|
||||
|
||||
bool is_supported_ng_type(const element::Type_t& ng_type)
|
||||
{
|
||||
return NG_2_ONNX_TYPES.count(ng_type) > 0;
|
||||
}
|
||||
|
||||
} // namespace onnx_editor
|
||||
} // namespace ngraph
|
||||
|
@ -19,21 +19,6 @@ namespace
|
||||
{
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
|
||||
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
|
||||
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
|
||||
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
|
||||
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
|
||||
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
|
||||
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
|
||||
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
|
||||
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
|
||||
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
|
||||
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
|
||||
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
|
||||
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
|
||||
};
|
||||
|
||||
ValueInfoProto* find_graph_input(GraphProto& graph, const std::string& name)
|
||||
{
|
||||
for (int i = 0; i < graph.input_size(); ++i)
|
||||
@ -80,16 +65,17 @@ namespace
|
||||
}
|
||||
|
||||
auto* tensor_type = type_proto->mutable_tensor_type();
|
||||
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
|
||||
|
||||
if (onnx_common::is_supported_ng_type(elem_type))
|
||||
{
|
||||
tensor_type->set_elem_type(onnx_common::ng_to_onnx_data_type(elem_type));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("The input type for input '" + onnx_input.name() +
|
||||
"' cannot be set to: " + element::Type(elem_type).get_type_name() +
|
||||
". This type is not allowed in ONNX.");
|
||||
}
|
||||
else
|
||||
{
|
||||
tensor_type->set_elem_type(NG_2_ONNX_TYPES.at(elem_type));
|
||||
}
|
||||
}
|
||||
|
||||
void add_dim_to_onnx_shape(const Dimension& dim, ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
|
||||
@ -160,7 +146,7 @@ namespace
|
||||
ValueInfoProto* input)
|
||||
{
|
||||
const auto elem_type = values->get_element_type();
|
||||
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
|
||||
if (!onnx_common::is_supported_ng_type(elem_type))
|
||||
{
|
||||
throw ngraph_error("Initializer '" + name + "' type cannot be set to: " +
|
||||
element::Type(elem_type).get_type_name() +
|
||||
@ -170,7 +156,7 @@ namespace
|
||||
initializer.Clear();
|
||||
|
||||
initializer.set_name(name);
|
||||
initializer.set_data_type(NG_2_ONNX_TYPES.at(values->get_element_type()));
|
||||
initializer.set_data_type(onnx_common::ng_to_onnx_data_type(values->get_element_type()));
|
||||
|
||||
for (const auto& dim : values->get_shape())
|
||||
{
|
||||
|
66
ngraph/frontend/onnx_import/src/op/constant_fill.cpp
Normal file
66
ngraph/frontend/onnx_import/src/op/constant_fill.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <onnx/onnx_pb.h> // onnx types
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "onnx_common/utils.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace set_1
|
||||
{
|
||||
OutputVector constant_fill(const Node& node)
|
||||
{
|
||||
Output<ngraph::Node> target_shape;
|
||||
const auto fill_value = node.get_attribute_value<float>("value", 0.f);
|
||||
const auto dtype = node.get_attribute_value<int64_t>(
|
||||
"dtype", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
const auto ng_type = onnx_common::onnx_to_ng_data_type(
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(dtype));
|
||||
const auto const_val_to_fill =
|
||||
default_opset::Constant::create(ng_type, {}, {fill_value});
|
||||
const auto input_as_shape =
|
||||
node.get_attribute_value<int64_t>("input_as_shape", 1);
|
||||
if (input_as_shape == 1) // use the first input as target shape
|
||||
{
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
node.get_ng_inputs().size() > 0,
|
||||
"The input which determines output shape was not provided");
|
||||
target_shape = node.get_ng_inputs().at(0);
|
||||
if (node.has_attribute("extra_shape"))
|
||||
{
|
||||
const auto extra_shape =
|
||||
node.get_attribute_value<std::vector<int64_t>>("extra_shape");
|
||||
const auto extra_shape_const = default_opset::Constant::create(
|
||||
target_shape.get_element_type(), {extra_shape.size()}, extra_shape);
|
||||
target_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{target_shape, extra_shape_const}, 0);
|
||||
}
|
||||
}
|
||||
else // use shape attribute as target shape
|
||||
{
|
||||
const auto shape = node.get_attribute_value<std::vector<int64_t>>("shape");
|
||||
target_shape =
|
||||
default_opset::Constant::create(ng_type, {shape.size()}, shape);
|
||||
}
|
||||
|
||||
return {std::make_shared<default_opset::Broadcast>(const_val_to_fill,
|
||||
target_shape)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
28
ngraph/frontend/onnx_import/src/op/constant_fill.hpp
Normal file
28
ngraph/frontend/onnx_import/src/op/constant_fill.hpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "onnx_import/core/node.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace set_1
|
||||
{
|
||||
// ConstantFill is a deprecated experimental operator removed in ONNX 1.4
|
||||
OutputVector constant_fill(const Node& node);
|
||||
} // namespace set_1
|
||||
|
||||
} // namespace op
|
||||
|
||||
} // namespace onnx_import
|
||||
|
||||
} // namespace ngraph
|
@ -29,6 +29,7 @@
|
||||
#include "op/clip.hpp"
|
||||
#include "op/concat.hpp"
|
||||
#include "op/constant.hpp"
|
||||
#include "op/constant_fill.hpp"
|
||||
#include "op/constant_of_shape.hpp"
|
||||
#include "op/conv.hpp"
|
||||
// #include "op/conv_integer.hpp"
|
||||
@ -327,6 +328,7 @@ namespace ngraph
|
||||
REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose);
|
||||
REGISTER_OPERATOR("Cos", 1, cos);
|
||||
REGISTER_OPERATOR("Cosh", 1, cosh);
|
||||
REGISTER_OPERATOR("ConstantFill", 1, constant_fill);
|
||||
REGISTER_OPERATOR("CumSum", 1, cum_sum);
|
||||
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
|
||||
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
|
||||
|
59
ngraph/test/models/onnx/constant_fill_extra_shape.prototxt
Normal file
59
ngraph/test/models/onnx/constant_fill_extra_shape.prototxt
Normal file
@ -0,0 +1,59 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "target_shape"
|
||||
output: "output"
|
||||
op_type: "ConstantFill"
|
||||
attribute {
|
||||
name: "input_as_shape"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "value"
|
||||
i: 3
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "extra_shape"
|
||||
ints: 2
|
||||
ints: 1
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "target_shape"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 3
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
int64_data: 1
|
||||
int64_data: 2
|
||||
name: "target_shape"
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "target_shape"
|
||||
output: "output"
|
||||
op_type: "ConstantFill"
|
||||
attribute {
|
||||
name: "input_as_shape"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "target_shape"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 3
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
int64_data: 2
|
||||
int64_data: 3
|
||||
name: "target_shape"
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "target_shape"
|
||||
output: "output"
|
||||
op_type: "ConstantFill"
|
||||
attribute {
|
||||
name: "input_as_shape"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "dtype"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "value"
|
||||
i: 3
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "target_shape"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
initializer {
|
||||
dims: 3
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
int64_data: 1
|
||||
int64_data: 2
|
||||
name: "target_shape"
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 9
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "target_shape"
|
||||
output: "output"
|
||||
op_type: "ConstantFill"
|
||||
attribute {
|
||||
name: "input_as_shape"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "dtype"
|
||||
i: 6
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "value"
|
||||
i: 5
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "shape"
|
||||
ints: 2
|
||||
ints: 3
|
||||
ints: 4
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -4258,3 +4258,43 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss)
|
||||
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_default_value)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_default_value.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_expected_output<float>(Shape{1, 2, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_u8_type)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_u8_type.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_expected_output<uint8_t>(Shape{3, 1, 2}, {3, 3, 3, 3, 3, 3});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_extra_shape)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_extra_shape.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_expected_output<float>(Shape{3, 1, 2, 2, 1}, std::vector<float>(12, 3.0f));
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_shape_attribute)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_shape_attribute.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_expected_output<int32_t>(Shape{2, 3, 4}, std::vector<int32_t>(24, 5));
|
||||
test_case.run();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user