[ONNX FE] Support for ONNX ATen (embedding bag) (#8802)

This commit is contained in:
Katarzyna Mitrus 2021-11-26 08:30:45 +01:00 committed by GitHub
parent 60fb05bb6a
commit bf504c00d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1070 additions and 0 deletions

View File

@ -0,0 +1,93 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "op/aten.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "onnx_import/core/node.hpp"
#include "onnx_import/core/null_node.hpp"
#include "openvino/opsets/opset8.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector aten(const Node& node) {
OutputVector inputs{node.get_ng_inputs()};
const auto operator_name = node.get_attribute_value<std::string>("operator", "");
CHECK_VALID_NODE(node,
operator_name == "embedding_bag",
"Only `embedding_bag` is supported as ATen `operator` attribute. Got: ",
operator_name);
const auto mode = node.get_attribute_value<int64_t>("mode");
CHECK_VALID_NODE(node,
mode == 0,
"Unsupported mode, only `0` (sum) is supported as ATen embedding_bag `mode` attribute. Got: ",
mode);
CHECK_VALID_NODE(node, inputs.size() >= 2, "Minimum 2 inputs are required. Got: ", inputs.size());
const bool is_packed_two_inputs =
inputs.size() == 2 || (inputs.size() == 3 && ngraph::op::is_null(inputs[2])) ||
(inputs.size() == 4 && ngraph::op::is_null(inputs[2]) && ngraph::op::is_null(inputs[3]));
const bool is_packed_three_inputs =
inputs.size() == 4 && ngraph::op::is_null(inputs[2]) && !ngraph::op::is_null(inputs[3]);
const bool is_offsets_three_inputs = inputs.size() == 3 && !ngraph::op::is_null(inputs[2]);
Output<ov::Node> embedding_bag;
if (is_packed_two_inputs) {
embedding_bag = std::make_shared<default_opset::EmbeddingBagPackedSum>(inputs[0], inputs[1]);
} else if (is_packed_three_inputs) {
embedding_bag = std::make_shared<default_opset::EmbeddingBagPackedSum>(inputs[0], inputs[1], inputs[3]);
} else if (is_offsets_three_inputs) {
embedding_bag = std::make_shared<default_opset::EmbeddingBagOffsetsSum>(inputs[0], inputs[1], inputs[2]);
} else if (inputs.size() >= 4) {
// Need to expand embedding table with zeros (default values for empty bags)
const auto& emb_tbl_in = inputs[0];
const auto& indices_in = inputs[1];
const auto& offsets_in = inputs[2];
const auto& per_sample_weights_in = inputs[3];
const auto data_type = emb_tbl_in.get_element_type();
const auto ind_type = indices_in.get_element_type();
const auto zero_const = std::make_shared<default_opset::Constant>(ind_type, Shape{}, 0);
// Shape aligned node, filled with zeros
const auto zero_of_data_type_const = std::make_shared<default_opset::Constant>(data_type, Shape{1}, 0);
const auto weights_shape_node = std::make_shared<default_opset::ShapeOf>(emb_tbl_in, ind_type);
const auto weights_last_dim_idx = std::make_shared<default_opset::Constant>(element::i32, Shape{1}, -1);
const auto weights_last_dim =
std::make_shared<ov::opset8::Gather>(weights_shape_node, weights_last_dim_idx, zero_const);
const auto zero_col_node =
std::make_shared<default_opset::Broadcast>(zero_of_data_type_const, weights_last_dim);
const auto default_embeddings_node = std::make_shared<default_opset::Unsqueeze>(zero_col_node, zero_const);
// Expanded embedding table weights
const auto weights_concat =
std::make_shared<default_opset::Concat>(OutputVector{emb_tbl_in, default_embeddings_node}, 0);
// Index in embedding table to fill empty bags
const auto weights_first_dim = std::make_shared<default_opset::Squeeze>(
std::make_shared<default_opset::Gather>(weights_shape_node, zero_const, zero_const));
embedding_bag = std::make_shared<default_opset::EmbeddingBagOffsetsSum>(weights_concat,
indices_in,
offsets_in,
weights_first_dim, // default index
per_sample_weights_in);
} else {
OPENVINO_UNREACHABLE("Unsupported inputs configuration for ATen `embedding_bag` operation.");
}
// Enable import onnx Node with duplicated outputs
return OutputVector(node.get_outputs_size(), embedding_bag);
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -0,0 +1,21 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "default_opset.hpp"
#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector aten(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@ -24,6 +24,7 @@
#include "op/asinh.hpp"
#include "op/atan.hpp"
#include "op/atanh.hpp"
#include "op/aten.hpp"
#include "op/average_pool.hpp"
#include "op/batch_norm.hpp"
#include "op/bitshift.hpp"
@ -293,6 +294,7 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("Asin", 1, asin);
REGISTER_OPERATOR("Asinh", 1, asinh);
REGISTER_OPERATOR("Atan", 1, atan);
REGISTER_OPERATOR("ATen", 1, aten);
REGISTER_OPERATOR("Atanh", 1, atanh);
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);

View File

@ -0,0 +1,96 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: "offsets"
output: "result_0"
output: "result_1"
output: "result_2"
output: "result_3"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "offsets"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,93 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: "offsets"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "offsets"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,107 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: "offsets"
input: "per_sample_weights"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "offsets"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "per_sample_weights"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,82 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,83 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: ""
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,100 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: ""
input: "per_sample_weights"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "per_sample_weights"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,84 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: ""
input: ""
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 0
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,93 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: "offsets"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 1
type: INT
}
attribute {
name: "operator"
s: "embedding_bag"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "offsets"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -0,0 +1,93 @@
ir_version: 3
producer_name: "onnx_import_test"
graph {
node {
input: "emb_tbl"
input: "indices"
input: "offsets"
output: "result_0"
op_type: "ATen"
attribute {
name: "mode"
i: 1
type: INT
}
attribute {
name: "operator"
s: "test_unsupported_operator"
type: STRING
}
attribute {
name: "scale_grad_by_freq"
i: 0
type: INT
}
attribute {
name: "sparse"
i: 1
type: INT
}
}
name: "test_aten_model"
input {
name: "emb_tbl"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "indices"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 4
}
}
}
}
}
input {
name: "offsets"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "result_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -4419,3 +4419,126 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_random_normal_like) {
test_case.add_expected_output<float>(Shape{2, 2}, {13.459274, 41.75028, -19.311913, 131.79282});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_packed_sum_2in) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_packed_2in.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{3, 2}, {0, 2, 1, 2, 3, 4}); // indices
test_case.add_expected_output<float>(Shape{3, 2}, {-2.1, -2.4, -2., -2.2, -0.19999999, 0.8});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_packed_sum_3in_offsets_none) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_packed_3in_offset_none.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{3, 2}, {0, 2, 1, 2, 3, 4}); // indices
test_case.add_expected_output<float>(Shape{3, 2}, {-2.1, -2.4, -2., -2.2, -0.19999999, 0.8});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_packed_sum_4in_per_sample_weights) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_packed_4in_per_sample_weights.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{3, 2}, {0, 2, 1, 2, 3, 4}); // indices
test_case.add_input<float>(Shape{3, 2}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5}); // per_sample_weights
test_case.add_expected_output<float>(Shape{3, 2}, {-1.05, -1.2, -1., -1.1, -0.09999999, 0.4});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_packed_sum_4in_two_none) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_packed_4in_two_none.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{3, 2}, {0, 2, 1, 2, 3, 4}); // indices
test_case.add_expected_output<float>(Shape{3, 2}, {-2.1, -2.4, -2., -2.2, -0.19999999, 0.8});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_offsets_sum_3in) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_offset_3in.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{4}, {0, 2, 3, 4}); // indices
test_case.add_input<int32_t>(Shape{3}, {0, 2, 2}); // offsets
test_case.add_expected_output<float>(Shape{3, 2}, {-2.1, -2.4, 0, 0, -0.2, 0.8});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_offsets_sum_4in) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_offset_4in.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{4}, {0, 2, 3, 4}); // indices
test_case.add_input<int32_t>(Shape{3}, {0, 2, 2}); // offsets
test_case.add_input<float>(Shape{4}, {0.5, 0.5, 0.5, 0.5}); // per_sample_weights
test_case.add_expected_output<float>(Shape{3, 2}, {-1.05, -1.2, 0., 0., -0.09999999, 0.4});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_embedding_bag_many_node_outputs) {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/aten_embedding_sum_many_outputs.onnx"));
// 4 outputs in onnx Node (1 connected and 3 not connected)
EXPECT_EQ(function->outputs().size(), 1);
EXPECT_EQ(function->get_results().size(), 1);
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{5, 2}, {-0.2, -0.6, -0.1, -0.4, -1.9, -1.8, -1., 1.5, 0.8, -0.7});
test_case.add_input<int32_t>(Shape{4}, {0, 2, 3, 4}); // indices
test_case.add_input<int32_t>(Shape{3}, {0, 2, 2}); // offsets
test_case.add_expected_output<float>(Shape{3, 2}, {-2.1, -2.4, 0, 0, -0.2, 0.8});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_unsupported_embedding_mode) {
try {
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/aten_unsupported_embedding_mode.onnx"));
FAIL() << "Expected exception was not thrown.";
} catch (const ngraph::ngraph_error& e) {
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Unsupported mode, only `0` (sum) is supported as ATen embedding_bag `mode` attribute. Got: 1"));
} catch (...) {
FAIL() << "Other exception than expected was thrown.";
}
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_aten_unsupported_operator) {
try {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/aten_unsupported_operator.onnx"));
FAIL() << "Expected exception was not thrown.";
} catch (const ngraph::ngraph_error& e) {
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Only `embedding_bag` is supported as ATen `operator` attribute. Got: test_unsupported_operator"));
} catch (...) {
FAIL() << "Other exception than expected was thrown.";
}
}