[ONNX] Add support for ONNX ConstantFill op (#5203)
This commit is contained in:
parent
891cf56255
commit
7ac7215924
@ -1,12 +1,42 @@
|
|||||||
// Copyright (C) 2018-2021 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
#include "ngraph/type/element_type.hpp"
|
||||||
|
|
||||||
|
namespace ONNX_NAMESPACE
|
||||||
|
{
|
||||||
|
enum TensorProto_DataType;
|
||||||
|
}
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
{
|
{
|
||||||
namespace onnx_common
|
namespace onnx_common
|
||||||
{
|
{
|
||||||
|
/// \brief Retuns size of an ONNX data type in bytes.
|
||||||
|
///
|
||||||
|
/// \param onnx_type Number assigned to an ONNX data type in the TensorProto_DataType enum.
|
||||||
|
///
|
||||||
size_t get_onnx_data_size(int32_t onnx_type);
|
size_t get_onnx_data_size(int32_t onnx_type);
|
||||||
|
|
||||||
|
/// \brief Retuns a nGraph data type corresponding to an ONNX type.
|
||||||
|
///
|
||||||
|
/// \param onnx_type An element of TensorProto_DataType enum which determines an ONNX type.
|
||||||
|
///
|
||||||
|
element::Type_t onnx_to_ng_data_type(const ONNX_NAMESPACE::TensorProto_DataType& onnx_type);
|
||||||
|
|
||||||
|
/// \brief Retuns an ONNX data type corresponding to a nGraph data type.
|
||||||
|
///
|
||||||
|
/// \param ng_type An element of element::Type_t enum class which determines a nGraph data
|
||||||
|
/// type.
|
||||||
|
///
|
||||||
|
ONNX_NAMESPACE::TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type);
|
||||||
|
|
||||||
|
/// \brief Retuns true if a nGraph data type is mapped to an ONNX data type.
|
||||||
|
///
|
||||||
|
/// \param ng_type An element of element::Type_t enum class which determines a nGraph data
|
||||||
|
/// type.
|
||||||
|
///
|
||||||
|
bool is_supported_ng_type(const element::Type_t& ng_type);
|
||||||
|
|
||||||
} // namespace onnx_editor
|
} // namespace onnx_editor
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <onnx/onnx_pb.h>
|
#include <onnx/onnx_pb.h>
|
||||||
|
|
||||||
#include "ngraph/except.hpp"
|
#include "ngraph/except.hpp"
|
||||||
@ -39,5 +40,53 @@ namespace ngraph
|
|||||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
|
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
using namespace ONNX_NAMESPACE;
|
||||||
|
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
|
||||||
|
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
|
||||||
|
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
|
||||||
|
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
|
||||||
|
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
|
||||||
|
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
|
||||||
|
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
|
||||||
|
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
|
||||||
|
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
|
||||||
|
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
|
||||||
|
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
|
||||||
|
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
|
||||||
|
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
|
||||||
|
{element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL}};
|
||||||
|
}
|
||||||
|
|
||||||
|
element::Type_t onnx_to_ng_data_type(const TensorProto_DataType& onnx_type)
|
||||||
|
{
|
||||||
|
const auto result = std::find_if(
|
||||||
|
NG_2_ONNX_TYPES.begin(),
|
||||||
|
NG_2_ONNX_TYPES.end(),
|
||||||
|
[&onnx_type](
|
||||||
|
const std::pair<element::Type_t, ONNX_NAMESPACE::TensorProto_DataType>& pair) {
|
||||||
|
return pair.second == onnx_type;
|
||||||
|
});
|
||||||
|
if (result == std::end(NG_2_ONNX_TYPES))
|
||||||
|
{
|
||||||
|
throw ngraph_error(
|
||||||
|
"unsupported element type: " +
|
||||||
|
ONNX_NAMESPACE::TensorProto_DataType_Name(
|
||||||
|
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_type)));
|
||||||
|
}
|
||||||
|
return result->first;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorProto_DataType ng_to_onnx_data_type(const element::Type_t& ng_type)
|
||||||
|
{
|
||||||
|
return NG_2_ONNX_TYPES.at(ng_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_supported_ng_type(const element::Type_t& ng_type)
|
||||||
|
{
|
||||||
|
return NG_2_ONNX_TYPES.count(ng_type) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_editor
|
} // namespace onnx_editor
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
@ -19,21 +19,6 @@ namespace
|
|||||||
{
|
{
|
||||||
using namespace ONNX_NAMESPACE;
|
using namespace ONNX_NAMESPACE;
|
||||||
|
|
||||||
const std::map<element::Type_t, TensorProto_DataType> NG_2_ONNX_TYPES = {
|
|
||||||
{element::Type_t::bf16, TensorProto_DataType::TensorProto_DataType_BFLOAT16},
|
|
||||||
{element::Type_t::f16, TensorProto_DataType::TensorProto_DataType_FLOAT16},
|
|
||||||
{element::Type_t::f32, TensorProto_DataType::TensorProto_DataType_FLOAT},
|
|
||||||
{element::Type_t::f64, TensorProto_DataType::TensorProto_DataType_DOUBLE},
|
|
||||||
{element::Type_t::i8, TensorProto_DataType::TensorProto_DataType_INT8},
|
|
||||||
{element::Type_t::i16, TensorProto_DataType::TensorProto_DataType_INT16},
|
|
||||||
{element::Type_t::i32, TensorProto_DataType::TensorProto_DataType_INT32},
|
|
||||||
{element::Type_t::i64, TensorProto_DataType::TensorProto_DataType_INT64},
|
|
||||||
{element::Type_t::u8, TensorProto_DataType::TensorProto_DataType_UINT8},
|
|
||||||
{element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
|
|
||||||
{element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
|
|
||||||
{element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
|
|
||||||
};
|
|
||||||
|
|
||||||
ValueInfoProto* find_graph_input(GraphProto& graph, const std::string& name)
|
ValueInfoProto* find_graph_input(GraphProto& graph, const std::string& name)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < graph.input_size(); ++i)
|
for (int i = 0; i < graph.input_size(); ++i)
|
||||||
@ -80,16 +65,17 @@ namespace
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto* tensor_type = type_proto->mutable_tensor_type();
|
auto* tensor_type = type_proto->mutable_tensor_type();
|
||||||
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
|
|
||||||
|
if (onnx_common::is_supported_ng_type(elem_type))
|
||||||
|
{
|
||||||
|
tensor_type->set_elem_type(onnx_common::ng_to_onnx_data_type(elem_type));
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
throw ngraph_error("The input type for input '" + onnx_input.name() +
|
throw ngraph_error("The input type for input '" + onnx_input.name() +
|
||||||
"' cannot be set to: " + element::Type(elem_type).get_type_name() +
|
"' cannot be set to: " + element::Type(elem_type).get_type_name() +
|
||||||
". This type is not allowed in ONNX.");
|
". This type is not allowed in ONNX.");
|
||||||
}
|
}
|
||||||
else
|
|
||||||
{
|
|
||||||
tensor_type->set_elem_type(NG_2_ONNX_TYPES.at(elem_type));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_dim_to_onnx_shape(const Dimension& dim, ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
|
void add_dim_to_onnx_shape(const Dimension& dim, ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
|
||||||
@ -160,7 +146,7 @@ namespace
|
|||||||
ValueInfoProto* input)
|
ValueInfoProto* input)
|
||||||
{
|
{
|
||||||
const auto elem_type = values->get_element_type();
|
const auto elem_type = values->get_element_type();
|
||||||
if (NG_2_ONNX_TYPES.count(elem_type) == 0)
|
if (!onnx_common::is_supported_ng_type(elem_type))
|
||||||
{
|
{
|
||||||
throw ngraph_error("Initializer '" + name + "' type cannot be set to: " +
|
throw ngraph_error("Initializer '" + name + "' type cannot be set to: " +
|
||||||
element::Type(elem_type).get_type_name() +
|
element::Type(elem_type).get_type_name() +
|
||||||
@ -170,7 +156,7 @@ namespace
|
|||||||
initializer.Clear();
|
initializer.Clear();
|
||||||
|
|
||||||
initializer.set_name(name);
|
initializer.set_name(name);
|
||||||
initializer.set_data_type(NG_2_ONNX_TYPES.at(values->get_element_type()));
|
initializer.set_data_type(onnx_common::ng_to_onnx_data_type(values->get_element_type()));
|
||||||
|
|
||||||
for (const auto& dim : values->get_shape())
|
for (const auto& dim : values->get_shape())
|
||||||
{
|
{
|
||||||
|
66
ngraph/frontend/onnx_import/src/op/constant_fill.cpp
Normal file
66
ngraph/frontend/onnx_import/src/op/constant_fill.cpp
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <onnx/onnx_pb.h> // onnx types
|
||||||
|
|
||||||
|
#include "default_opset.hpp"
|
||||||
|
#include "exceptions.hpp"
|
||||||
|
#include "ngraph/op/broadcast.hpp"
|
||||||
|
#include "ngraph/op/concat.hpp"
|
||||||
|
#include "ngraph/op/constant.hpp"
|
||||||
|
#include "onnx_common/utils.hpp"
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace onnx_import
|
||||||
|
{
|
||||||
|
namespace op
|
||||||
|
{
|
||||||
|
namespace set_1
|
||||||
|
{
|
||||||
|
OutputVector constant_fill(const Node& node)
|
||||||
|
{
|
||||||
|
Output<ngraph::Node> target_shape;
|
||||||
|
const auto fill_value = node.get_attribute_value<float>("value", 0.f);
|
||||||
|
const auto dtype = node.get_attribute_value<int64_t>(
|
||||||
|
"dtype", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||||
|
const auto ng_type = onnx_common::onnx_to_ng_data_type(
|
||||||
|
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(dtype));
|
||||||
|
const auto const_val_to_fill =
|
||||||
|
default_opset::Constant::create(ng_type, {}, {fill_value});
|
||||||
|
const auto input_as_shape =
|
||||||
|
node.get_attribute_value<int64_t>("input_as_shape", 1);
|
||||||
|
if (input_as_shape == 1) // use the first input as target shape
|
||||||
|
{
|
||||||
|
CHECK_VALID_NODE(
|
||||||
|
node,
|
||||||
|
node.get_ng_inputs().size() > 0,
|
||||||
|
"The input which determines output shape was not provided");
|
||||||
|
target_shape = node.get_ng_inputs().at(0);
|
||||||
|
if (node.has_attribute("extra_shape"))
|
||||||
|
{
|
||||||
|
const auto extra_shape =
|
||||||
|
node.get_attribute_value<std::vector<int64_t>>("extra_shape");
|
||||||
|
const auto extra_shape_const = default_opset::Constant::create(
|
||||||
|
target_shape.get_element_type(), {extra_shape.size()}, extra_shape);
|
||||||
|
target_shape = std::make_shared<default_opset::Concat>(
|
||||||
|
OutputVector{target_shape, extra_shape_const}, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else // use shape attribute as target shape
|
||||||
|
{
|
||||||
|
const auto shape = node.get_attribute_value<std::vector<int64_t>>("shape");
|
||||||
|
target_shape =
|
||||||
|
default_opset::Constant::create(ng_type, {shape.size()}, shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {std::make_shared<default_opset::Broadcast>(const_val_to_fill,
|
||||||
|
target_shape)};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace set_1
|
||||||
|
} // namespace op
|
||||||
|
} // namespace onnx_import
|
||||||
|
|
||||||
|
} // namespace ngraph
|
28
ngraph/frontend/onnx_import/src/op/constant_fill.hpp
Normal file
28
ngraph/frontend/onnx_import/src/op/constant_fill.hpp
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "ngraph/node.hpp"
|
||||||
|
#include "onnx_import/core/node.hpp"
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace onnx_import
|
||||||
|
{
|
||||||
|
namespace op
|
||||||
|
{
|
||||||
|
namespace set_1
|
||||||
|
{
|
||||||
|
// ConstantFill is a deprecated experimental operator removed in ONNX 1.4
|
||||||
|
OutputVector constant_fill(const Node& node);
|
||||||
|
} // namespace set_1
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
|
||||||
|
} // namespace onnx_import
|
||||||
|
|
||||||
|
} // namespace ngraph
|
@ -29,6 +29,7 @@
|
|||||||
#include "op/clip.hpp"
|
#include "op/clip.hpp"
|
||||||
#include "op/concat.hpp"
|
#include "op/concat.hpp"
|
||||||
#include "op/constant.hpp"
|
#include "op/constant.hpp"
|
||||||
|
#include "op/constant_fill.hpp"
|
||||||
#include "op/constant_of_shape.hpp"
|
#include "op/constant_of_shape.hpp"
|
||||||
#include "op/conv.hpp"
|
#include "op/conv.hpp"
|
||||||
// #include "op/conv_integer.hpp"
|
// #include "op/conv_integer.hpp"
|
||||||
@ -327,6 +328,7 @@ namespace ngraph
|
|||||||
REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose);
|
REGISTER_OPERATOR("ConvTranspose", 1, conv_transpose);
|
||||||
REGISTER_OPERATOR("Cos", 1, cos);
|
REGISTER_OPERATOR("Cos", 1, cos);
|
||||||
REGISTER_OPERATOR("Cosh", 1, cosh);
|
REGISTER_OPERATOR("Cosh", 1, cosh);
|
||||||
|
REGISTER_OPERATOR("ConstantFill", 1, constant_fill);
|
||||||
REGISTER_OPERATOR("CumSum", 1, cum_sum);
|
REGISTER_OPERATOR("CumSum", 1, cum_sum);
|
||||||
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
|
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
|
||||||
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
|
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
|
||||||
|
59
ngraph/test/models/onnx/constant_fill_extra_shape.prototxt
Normal file
59
ngraph/test/models/onnx/constant_fill_extra_shape.prototxt
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "backend-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "target_shape"
|
||||||
|
output: "output"
|
||||||
|
op_type: "ConstantFill"
|
||||||
|
attribute {
|
||||||
|
name: "input_as_shape"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
i: 3
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "extra_shape"
|
||||||
|
ints: 2
|
||||||
|
ints: 1
|
||||||
|
type: INTS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "target_shape"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 2
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 3
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 3
|
||||||
|
int64_data: 1
|
||||||
|
int64_data: 2
|
||||||
|
name: "target_shape"
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "output"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "backend-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "target_shape"
|
||||||
|
output: "output"
|
||||||
|
op_type: "ConstantFill"
|
||||||
|
attribute {
|
||||||
|
name: "input_as_shape"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "target_shape"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 3
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 1
|
||||||
|
int64_data: 2
|
||||||
|
int64_data: 3
|
||||||
|
name: "target_shape"
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "output"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "backend-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "target_shape"
|
||||||
|
output: "output"
|
||||||
|
op_type: "ConstantFill"
|
||||||
|
attribute {
|
||||||
|
name: "input_as_shape"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "dtype"
|
||||||
|
i: 2
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
i: 3
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "target_shape"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 2
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 3
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 3
|
||||||
|
int64_data: 1
|
||||||
|
int64_data: 2
|
||||||
|
name: "target_shape"
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "output"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 9
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,44 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "backend-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "target_shape"
|
||||||
|
output: "output"
|
||||||
|
op_type: "ConstantFill"
|
||||||
|
attribute {
|
||||||
|
name: "input_as_shape"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "dtype"
|
||||||
|
i: 6
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
i: 5
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "shape"
|
||||||
|
ints: 2
|
||||||
|
ints: 3
|
||||||
|
ints: 4
|
||||||
|
type: INTS
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "output"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 6
|
||||||
|
shape {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -4258,3 +4258,43 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss)
|
|||||||
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
|
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
|
||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_default_value)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_default_value.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 2, 3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_input_as_shape_u8_type)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_input_as_shape_u8_type.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_expected_output<uint8_t>(Shape{3, 1, 2}, {3, 3, 3, 3, 3, 3});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_extra_shape)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_extra_shape.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_expected_output<float>(Shape{3, 1, 2, 2, 1}, std::vector<float>(12, 3.0f));
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_constant_fill_shape_attribute)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/constant_fill_shape_attribute.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_expected_output<int32_t>(Shape{2, 3, 4}, std::vector<int32_t>(24, 5));
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user