Cecilia/bert/op convert (#7462)

* [paddlepaddle] embedding test case generator.

* [FrontEnd][PaddlePaddle] add op convert embedding for lookup_table_v2.

* code clean.
cecilia peng 2021-09-22 12:40:11 +08:00 committed by GitHub
parent 1aa6db4aaf
commit 8ce8697830
4 changed files with 234 additions and 0 deletions


@@ -0,0 +1,55 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset8.hpp>
#include <node_context.hpp>
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
using namespace opset8;
using namespace element;
NamedOutputs embedding(const NodeContext& node) {
auto data_ids = node.get_ng_input("Ids");
auto data_w = node.get_ng_input("W");
auto padding_idx = node.get_attribute<int64_t>("padding_idx");
const auto const_axis0 = Constant::create<int32_t>(i64, {1}, {0});
std::shared_ptr<Node> node_embedding;
if (padding_idx < 0) // no mask
{
node_embedding = std::make_shared<Gather>(data_w, data_ids, const_axis0);
} else { // mask embedding
auto node_shape_of_w = std::make_shared<ShapeOf>(data_w);
auto node_vocab_size = std::make_shared<Gather>(node_shape_of_w,
Constant::create<int64_t>(i64, {1}, {0}),
const_axis0); // vocab_size
auto node_stop = std::make_shared<Squeeze>(node_vocab_size);
auto node_range = std::make_shared<Range>(Constant::create<int64_t>(i64, {}, {0}),
node_stop,
Constant::create<int64_t>(i64, {}, {1}),
i64);
auto node_equal = std::make_shared<Equal>(node_range, Constant::create(i64, {1}, {padding_idx}));
auto node_mask = std::make_shared<Unsqueeze>(node_equal, Constant::create<int64_t>(i64, {1}, {1}));
data_w = std::make_shared<Select>(node_mask,
Constant::create<float>(f32, {1}, {0}),
data_w,
ov::op::AutoBroadcastType::NUMPY); // masked W
node_embedding = std::make_shared<Gather>(data_w, data_ids, const_axis0);
}
return node.default_single_output_mapping({node_embedding}, {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
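
For reference, the masking branch above decomposes the non-negative padding_idx case into ShapeOf(W) → Gather(vocab_size) → Range → Equal → Unsqueeze → Select before the final Gather over the (masked) weight table. A minimal NumPy sketch of the same semantics, for illustration only (the helper name embedding_reference is not part of this patch):

import numpy as np

def embedding_reference(ids, w, padding_idx=None):
    # Mirror the converter: zero the padding row of W, then gather rows along axis 0.
    w = np.asarray(w, dtype=np.float32)
    if padding_idx is not None and padding_idx >= 0:           # padding_idx < 0 means "no mask"
        vocab_size = w.shape[0]                                # ShapeOf(W) + Gather
        is_padding_row = np.arange(vocab_size) == padding_idx  # Range + Equal
        w = np.where(is_padding_row[:, None], 0.0, w)          # Unsqueeze + Select (masked W)
    return w[ids]                                              # Gather(W, Ids, axis 0)

Any id equal to padding_idx therefore maps to an all-zero row in the output, which is the lookup_table_v2 behaviour the test cases below exercise.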


@@ -29,6 +29,7 @@ OP_CONVERTER(elementwise_min);
OP_CONVERTER(elementwise_mul);
OP_CONVERTER(elementwise_pow);
OP_CONVERTER(elementwise_sub);
OP_CONVERTER(embedding);
OP_CONVERTER(expand_v2);
OP_CONVERTER(fill_any_like);
OP_CONVERTER(fill_constant_batch_size_like);
@@ -113,6 +114,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"leaky_relu", op::leaky_relu},
{"log", op::log},
{"logical_not", op::logical_not},
{"lookup_table_v2", op::embedding},
{"matmul", op::matmul},
{"matmul_v2", op::matmul_v2},
{"max_pool2d_with_index", op::pool2d},


@@ -79,6 +79,13 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("elementwise_mul1"),
std::string("elementwise_pow1"),
std::string("elementwise_sub1"),
std::string("embedding_0"),
std::string("embedding_sparse"),
std::string("embedding_none_weight"),
std::string("embedding_paddings"),
std::string("embedding_paddings_neg1"),
std::string("embedding_tensorIds"),
std::string("embedding_tensorIds_paddings"),
std::string("equal"),
std::string("expand_v2"),
std::string("expand_v2_tensor"),


@@ -0,0 +1,170 @@
#
# paddle model generator
# for lookup_table_v2
# https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Embedding_cn.html#embedding
# equivalent to a "gather" along axis 0
#
import numpy as np
import sys
from save_model import saveModel
def ngraph_embedding(ids, vocab_embeddings, vocab_size, embedding_dim, padding_idx, sparse):
"""
    Reference implementation: decompose embedding into nGraph ops and run it on CPU via the Inference Engine.
"""
import ngraph as ng
from ngraph import opset8 as opset
from openvino.inference_engine import IECore
if vocab_embeddings is None:
#
vocab_embeddings = np.zeros((vocab_size, embedding_dim)).astype("float32")
node_ids = ng.parameter(shape=ids.shape, name='ids', dtype=ids.dtype)
node_w = ng.parameter(shape=vocab_embeddings.shape, name='w', dtype=vocab_embeddings.dtype)
if padding_idx == -1:
padding_idx += vocab_size
if padding_idx is not None:
'''
mask W
'''
masked_embeddings = np.ones(vocab_embeddings.shape, dtype='int64')
masked_embeddings[padding_idx,:] = 0 # mask
node_mask = ng.constant(masked_embeddings, name='mask', dtype=vocab_embeddings.dtype)
node_masked_w = ng.multiply(node_w, node_mask)
node_axis = ng.constant([0], name='const0', dtype=np.int64)
    node_gather = opset.gather(data=node_masked_w if padding_idx is not None else node_w, indices=node_ids, axis=node_axis, batch_dims=0)
graph = ng.result(node_gather, name='y')
parameters = [node_ids, node_w]
inputs_dict = {'ids': ids, "w": vocab_embeddings}
#
function = ng.Function(graph, parameters, "embedding")
ie_network = ng.function_to_cnn(function)
ie = IECore()
executable_network = ie.load_network(ie_network, 'CPU')
output = executable_network.infer(inputs_dict)
return output
def embedding(name : str, ids, vocab_size, embedding_dim, padding_idx=None, sparse=False, vocab_embeddings=None, compare=False):
"""
padding_idx (int|long|None)
"""
import paddle
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
node_ids = paddle.static.data(name = 'Ids', shape = ids.shape, dtype = ids.dtype)
pretrained_attr = paddle.ParamAttr(name='W',
initializer=paddle.nn.initializer.Assign(vocab_embeddings),
trainable=False) if vocab_embeddings is not None else None
node_embedding = paddle.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=padding_idx, sparse=sparse, weight_attr=pretrained_attr, name=name)
node_out = node_embedding(node_ids)
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
input_dict = {'Ids': ids}
output_vars_list = [node_out]
infer_results = exe.run(
feed=input_dict,
fetch_list=output_vars_list )
saveModel(name, exe, feedkeys=list(input_dict.keys()), fetchlist=output_vars_list, inputs=list(input_dict.values()), outputs=infer_results, target_dir=sys.argv[1])
#
outputs = dict()
for i in range(len(infer_results)):
outputs[output_vars_list[i].name] = infer_results[i]
#
if compare:
ng_result = ngraph_embedding(ids, vocab_embeddings, vocab_size, embedding_dim, padding_idx, sparse)
ng_result = list(ng_result.values())[0]
pdpd_result = list(outputs.values())[0]
match = np.all(np.isclose(
pdpd_result, ng_result, rtol=1e-4, atol=1e-5))
prefix_color = '\n\033[92m' if match else '\n\033[91m'
print(prefix_color +
'TestCase {} Result {} '.format(name, match) + '\033[0m\n')
if not match:
np.set_printoptions(precision=2)
np.set_printoptions(suppress=True)
print(prefix_color +
'pdpd_result: {}'.format(pdpd_result) + '\033[0m\n')
print(prefix_color +
'ng_result: {}'.format(ng_result) + '\033[0m\n')
raise ValueError(name + ': OV result does not match PDPD!')
return outputs
if __name__ == "__main__":
import paddle.compat as cpt
vocab_size = 17
embedding_dim = 31
table = np.random.random((vocab_size, embedding_dim)).astype("float32")
#
ids = np.random.randint(0, vocab_size, 4).astype("int32")
embedding("embedding_0", ids, vocab_size, embedding_dim, vocab_embeddings=table, compare=False)
#
ids = np.random.randint(0, vocab_size, 4).astype("int32")
embedding("embedding_sparse", ids, vocab_size, embedding_dim, sparse=True, vocab_embeddings=table, compare=False)
    # the comparison fails for this case, so compare=False
ids = np.random.randint(0, vocab_size, 4).astype("int32")
embedding("embedding_none_weight", ids, vocab_size, embedding_dim, compare=False)
#
ids = np.random.randint(0, vocab_size, 4).astype("int32")
ids = np.squeeze(ids)
padding_idx = np.random.choice(ids, 1)[0]
# print('padding_idx {}, ids {}'.format(padding_idx, ids))
outputs = embedding("embedding_paddings", ids, vocab_size, embedding_dim, padding_idx=int(padding_idx), vocab_embeddings=table, compare=False)
# print('outputs {}'.format(outputs))
# corner case
ids = np.random.randint(0, vocab_size, 4).astype("int32")
    pick = np.random.choice(4, 1)[0]  # pick one index at random and set it to the max id, vocab_size - 1
ids[pick] = vocab_size-1
padding_idx = -1
# print('padding_idx {}, ids {}'.format(padding_idx, ids))
outputs = embedding("embedding_paddings_neg1", ids, vocab_size, embedding_dim, padding_idx=int(padding_idx), vocab_embeddings=table, compare=False)
# print('outputs {}'.format(outputs))
#
ids = np.random.randint(low=0, high=vocab_size, size=(2, 4, 5)).astype("int32")
embedding("embedding_tensorIds", ids, vocab_size, embedding_dim, vocab_embeddings=table, compare=False)
#
ids = np.random.randint(low=0, high=vocab_size, size=(2, 4, 5)).astype("int32")
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
# print('padding_idx {}'.format(padding_idx))
outputs = embedding("embedding_tensorIds_paddings", ids, vocab_size, embedding_dim, padding_idx=cpt.long_type(padding_idx), vocab_embeddings=table, compare=False)
# print('outputs {}'.format(outputs))
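
When run standalone, the generator writes the exported test models to the directory given as its first command-line argument (saveModel receives target_dir=sys.argv[1]). For context, lookup_table_v2 is the static-graph op behind paddle.nn.Embedding; the padding_idx behaviour the test cases exercise can be sketched in dynamic-graph mode as follows (a minimal sketch assuming a Paddle 2.x installation; names and values are illustrative, not taken from the generator above):

import numpy as np
import paddle

vocab_size, embedding_dim, padding_idx = 17, 31, 3
table = np.random.random((vocab_size, embedding_dim)).astype("float32")
ids = paddle.to_tensor(np.array([1, 3, 5, 16], dtype="int64"))

emb = paddle.nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx,
                          weight_attr=paddle.ParamAttr(
                              initializer=paddle.nn.initializer.Assign(table)))
out = emb(ids).numpy()            # shape (4, 31)
assert np.allclose(out[1], 0.0)   # the row looked up with id == padding_idx comes out as zeros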