diff --git a/ngraph/frontend/onnx_common/include/onnx_common/utils.hpp b/ngraph/frontend/onnx_common/include/onnx_common/utils.hpp index bdf0aba4468..c9d28bede13 100644 --- a/ngraph/frontend/onnx_common/include/onnx_common/utils.hpp +++ b/ngraph/frontend/onnx_common/include/onnx_common/utils.hpp @@ -1,12 +1,42 @@ // Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "ngraph/type/element_type.hpp" + +namespace ONNX_NAMESPACE +{ + enum TensorProto_DataType; +} namespace ngraph { namespace onnx_common { + /// \brief Retuns size of an ONNX data type in bytes. + /// + /// \param onnx_type Number assigned to an ONNX data type in the TensorProto_DataType enum. + /// size_t get_onnx_data_size(int32_t onnx_type); + /// \brief Retuns a nGraph data type corresponding to an ONNX type. + /// + /// \param onnx_type An element of TensorProto_DataType enum which determines an ONNX type. + /// + element::Type_t onnx_to_ng_data_type(const ONNX_NAMESPACE::TensorProto_DataType& onnx_type); + + /// \brief Retuns an ONNX data type corresponding to a nGraph data type. + /// + /// \param ng_type An element of element::Type_t enum class which determines a nGraph data + /// type. + /// + ONNX_NAMESPACE::TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type); + + /// \brief Retuns true if a nGraph data type is mapped to an ONNX data type. + /// + /// \param ng_type An element of element::Type_t enum class which determines a nGraph data + /// type. + /// + bool is_supported_ng_type(const element::Type_t& ng_type); + } // namespace onnx_editor } // namespace ngraph diff --git a/ngraph/frontend/onnx_common/src/utils.cpp b/ngraph/frontend/onnx_common/src/utils.cpp index 998f5f4daa8..55e1a4e9288 100644 --- a/ngraph/frontend/onnx_common/src/utils.cpp +++ b/ngraph/frontend/onnx_common/src/utils.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include #include #include "ngraph/except.hpp" @@ -39,5 +40,53 @@ namespace ngraph static_cast(onnx_type))); #endif } + namespace + { + using namespace ONNX_NAMESPACE; + const std::map NG_2_ONNX_TYPES = { + {element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16}, + {element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16}, + {element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT}, + {element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE}, + {element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8}, + {element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16}, + {element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32}, + {element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64}, + {element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8}, + {element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16}, + {element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32}, + {element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64}, + {element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL}}; + } + + element::Type_t onnx_to_ng_data_type(const TensorProto_DataType& onnx_type) + { + const auto result = std::find_if( + NG_2_ONNX_TYPES.begin(), + NG_2_ONNX_TYPES.end(), + [&onnx_type]( + const std::pair& pair) { + return pair.second == onnx_type; + }); + if (result == std::end(NG_2_ONNX_TYPES)) + { + throw ngraph_error( + "unsupported element type: " + + ONNX_NAMESPACE::TensorProto_DataType_Name( + static_cast(onnx_type))); + } + return result->first; + } + + TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type) + { + return NG_2_ONNX_TYPES.at(ng_type); + } + + bool is_supported_ng_type(const element::Type_t& ng_type) + { + return NG_2_ONNX_TYPES.count(ng_type) > 0; + } + } // namespace onnx_editor } // namespace ngraph diff --git a/ngraph/frontend/onnx_editor/src/editor.cpp b/ngraph/frontend/onnx_editor/src/editor.cpp index d4b24300bea..ba435b5be4a 100644 --- a/ngraph/frontend/onnx_editor/src/editor.cpp +++ b/ngraph/frontend/onnx_editor/src/editor.cpp @@ -19,21 +19,6 @@ namespace { using namespace ONNX_NAMESPACE; - const std::map NG_2_ONNX_TYPES = { - {element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16}, - {element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16}, - {element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT}, - {element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE}, - {element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8}, - {element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16}, - {element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32}, - {element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64}, - {element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8}, - {element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16}, - {element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32}, - {element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64}, - }; - ValueInfoProto* find_graph_input(GraphProto& graph, const std::string& name) { for (int i = 0; i < graph.input_size(); ++i) @@ -80,16 +65,17 @@ namespace } auto* tensor_type = type_proto->mutable_tensor_type(); - if (NG_2_ONNX_TYPES.count(elem_type) == 0) + + if (onnx_common::is_supported_ng_type(elem_type)) + { + tensor_type->set_elem_type(onnx_common::ng_to_onnx_data_type(elem_type)); + } + else { throw ngraph_error("The input type for input '" + onnx_input.name() + "' cannot be set to: " + element::Type(elem_type).get_type_name() + ". This type is not allowed in ONNX."); } - else - { - tensor_type->set_elem_type(NG_2_ONNX_TYPES.at(elem_type)); - } } void add_dim_to_onnx_shape(const Dimension& dim, ONNX_NAMESPACE::TensorShapeProto& onnx_shape) @@ -160,7 +146,7 @@ namespace ValueInfoProto* input) { const auto elem_type = values->get_element_type(); - if (NG_2_ONNX_TYPES.count(elem_type) == 0) + if (!onnx_common::is_supported_ng_type(elem_type)) { throw ngraph_error("Initializer '" + name + "' type cannot be set to: " + element::Type(elem_type).get_type_name() + @@ -170,7 +156,7 @@ namespace initializer.Clear(); initializer.set_name(name); - initializer.set_data_type(NG_2_ONNX_TYPES.at(values->get_element_type())); + initializer.set_data_type(onnx_common::ng_to_onnx_data_type(values->get_element_type())); for (const auto& dim : values->get_shape()) { diff --git a/ngraph/frontend/onnx_import/src/op/constant_fill.cpp b/ngraph/frontend/onnx_import/src/op/constant_fill.cpp new file mode 100644 index 00000000000..14a9c1c6ef6 --- /dev/null +++ b/ngraph/frontend/onnx_import/src/op/constant_fill.cpp @@ -0,0 +1,66 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include // onnx types + +#include "default_opset.hpp" +#include "exceptions.hpp" +#include "ngraph/op/broadcast.hpp" +#include "ngraph/op/concat.hpp" +#include "ngraph/op/constant.hpp" +#include "onnx_common/utils.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace op + { + namespace set_1 + { + OutputVector constant_fill(const Node& node) + { + Output target_shape; + const auto fill_value = node.get_attribute_value("value", 0.f); + const auto dtype = node.get_attribute_value( + "dtype", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + const auto ng_type = onnx_common::onnx_to_ng_data_type( + static_cast(dtype)); + const auto const_val_to_fill = + default_opset::Constant::create(ng_type, {}, {fill_value}); + const auto input_as_shape = + node.get_attribute_value("input_as_shape", 1); + if (input_as_shape == 1) // use the first input as target shape + { + CHECK_VALID_NODE( + node, + node.get_ng_inputs().size() > 0, + "The input which determines output shape was not provided"); + target_shape = node.get_ng_inputs().at(0); + if (node.has_attribute("extra_shape")) + { + const auto extra_shape = + node.get_attribute_value>("extra_shape"); + const auto extra_shape_const = default_opset::Constant::create( + target_shape.get_element_type(), {extra_shape.size()}, extra_shape); + target_shape = std::make_shared( + OutputVector{target_shape, extra_shape_const}, 0); + } + } + else // use shape attribute as target shape + { + const auto shape = node.get_attribute_value>("shape"); + target_shape = + default_opset::Constant::create(ng_type, {shape.size()}, shape); + } + + return {std::make_shared(const_val_to_fill, + target_shape)}; + } + + } // namespace set_1 + } // namespace op + } // namespace onnx_import + +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/constant_fill.hpp b/ngraph/frontend/onnx_import/src/op/constant_fill.hpp new file mode 100644 index 00000000000..347dc98cd86 --- /dev/null +++ b/ngraph/frontend/onnx_import/src/op/constant_fill.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "ngraph/node.hpp" +#include "onnx_import/core/node.hpp" + +namespace ngraph +{ + namespace onnx_import + { + namespace op + { + namespace set_1 + { + // ConstantFill is a deprecated experimental operator removed in ONNX 1.4 + OutputVector constant_fill(const Node& node); + } // namespace set_1 + + } // namespace op + + } // namespace onnx_import + +} // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp index e2107c8ade3..419f1f99552 100644 --- a/ngraph/frontend/onnx_import/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx_import/src/ops_bridge.cpp @@ -29,6 +29,7 @@ #include "op/clip.hpp" #include "op/concat.hpp" #include "op/constant.hpp" +#include "op/constant_fill.hpp" #include "op/constant_of_shape.hpp" #include "op/conv.hpp" // #include "op/conv_integer.hpp" @@ -327,6 +328,7 @@ namespace ngraph REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose); REGISTER_OPERATOR("Cos", 1, cos); REGISTER_OPERATOR("Cosh", 1, cosh); + REGISTER_OPERATOR("ConstantFill", 1, constant_fill); REGISTER_OPERATOR("CumSum", 1, cum_sum); REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space); REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear); diff --git a/ngraph/test/models/onnx/constant_fill_extra_shape.prototxt b/ngraph/test/models/onnx/constant_fill_extra_shape.prototxt new file mode 100644 index 00000000000..c88ab27073e --- /dev/null +++ b/ngraph/test/models/onnx/constant_fill_extra_shape.prototxt @@ -0,0 +1,59 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "target_shape" + output: "output" + op_type: "ConstantFill" + attribute { + name: "input_as_shape" + i: 1 + type: INT + } + attribute { + name: "value" + i: 3 + type: INT + } + attribute { + name: "extra_shape" + ints: 2 + ints: 1 + type: INTS + } + } + input { + name: "target_shape" + type { + tensor_type { + elem_type: 2 + shape { + dim { + dim_value: 3 + } + } + } + } + } + initializer { + dims: 3 + data_type: 7 + int64_data: 3 + int64_data: 1 + int64_data: 2 + name: "target_shape" + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/constant_fill_input_as_shape_default_value.prototxt b/ngraph/test/models/onnx/constant_fill_input_as_shape_default_value.prototxt new file mode 100644 index 00000000000..30fa0dfa08a --- /dev/null +++ b/ngraph/test/models/onnx/constant_fill_input_as_shape_default_value.prototxt @@ -0,0 +1,48 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "target_shape" + output: "output" + op_type: "ConstantFill" + attribute { + name: "input_as_shape" + i: 1 + type: INT + } + } + input { + name: "target_shape" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 3 + } + } + } + } + } + initializer { + dims: 3 + data_type: 7 + int64_data: 1 + int64_data: 2 + int64_data: 3 + name: "target_shape" + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/constant_fill_input_as_shape_u8_type.prototxt b/ngraph/test/models/onnx/constant_fill_input_as_shape_u8_type.prototxt new file mode 100644 index 00000000000..e0312efefa5 --- /dev/null +++ b/ngraph/test/models/onnx/constant_fill_input_as_shape_u8_type.prototxt @@ -0,0 +1,58 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "target_shape" + output: "output" + op_type: "ConstantFill" + attribute { + name: "input_as_shape" + i: 1 + type: INT + } + attribute { + name: "dtype" + i: 2 + type: INT + } + attribute { + name: "value" + i: 3 + type: INT + } + } + input { + name: "target_shape" + type { + tensor_type { + elem_type: 2 + shape { + dim { + dim_value: 3 + } + } + } + } + } + initializer { + dims: 3 + data_type: 7 + int64_data: 3 + int64_data: 1 + int64_data: 2 + name: "target_shape" + } + output { + name: "output" + type { + tensor_type { + elem_type: 9 + shape { + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt b/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt new file mode 100644 index 00000000000..806f01ffd89 --- /dev/null +++ b/ngraph/test/models/onnx/constant_fill_shape_attribute.prototxt @@ -0,0 +1,44 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "target_shape" + output: "output" + op_type: "ConstantFill" + attribute { + name: "input_as_shape" + i: 0 + type: INT + } + attribute { + name: "dtype" + i: 6 + type: INT + } + attribute { + name: "value" + i: 5 + type: INT + } + attribute { + name: "shape" + ints: 2 + ints: 3 + ints: 4 + type: INTS + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 6 + shape { + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 3c49b1243f2..a8b0d7c90b4 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -4258,3 +4258,43 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss) test_case.add_expected_output(Shape{}, {-0.531306922435760498}); test_case.run(); } + +NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_default_value) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_default_value.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_expected_output(Shape{1, 2, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_u8_type) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_u8_type.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_expected_output(Shape{3, 1, 2}, {3, 3, 3, 3, 3, 3}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_extra_shape) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_extra_shape.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_expected_output(Shape{3, 1, 2, 2, 1}, std::vector(12, 3.0f)); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_shape_attribute) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_shape_attribute.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_expected_output(Shape{2, 3, 4}, std::vector(24, 5)); + test_case.run(); +}