[FrontEnd] Enable pdpd ops layer_norm, gelu, fill_any_like, cumsum, tanh, matmul_v2 for BERT (#7511)

* enable layer_norm, gelu, fill_any_like, cumsum, tanh, matmul_v2

* change default namespace

* apply review comments
Luo Cheng 2021-09-21 20:06:42 +08:00 committed by GitHub
parent 5ad2400468
commit f4fa513325
15 changed files with 519 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/opsets/opset8.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
namespace default_opset = ngraph::opset8;
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph

View File

@@ -0,0 +1,37 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs cumsum(const NodeContext& node) {
    const auto x = node.get_ng_input("X");
    const auto axis = node.get_attribute<int32_t>("axis", -1);
    const auto flatten = node.get_attribute<bool>("flatten", false);
    const auto reverse = node.get_attribute<bool>("reverse", false);
    const auto exclusive = node.get_attribute<bool>("exclusive", false);

    std::shared_ptr<ngraph::Node> input = x.get_node_shared_ptr();
    if (flatten) {
        // flatten the input to a 1-D tensor before accumulating
        input = std::make_shared<default_opset::Reshape>(x,
                                                         default_opset::Constant::create(element::i64, {1}, {-1}),
                                                         false);
    }
    const auto axis_node = default_opset::Constant::create(element::i64, {}, {axis});
    return node.default_single_output_mapping(
        {std::make_shared<default_opset::CumSum>(input, axis_node, exclusive, reverse)},
        {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
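
For intuition, the attribute semantics this converter maps onto opset8 CumSum can be sketched in NumPy. This is an illustrative reference only, not part of the change; the helper name cumsum_ref is hypothetical.

import numpy as np

def cumsum_ref(x, axis=-1, flatten=False, exclusive=False, reverse=False):
    # illustrative NumPy reference for the Paddle cumsum attributes
    if flatten:
        x, axis = x.reshape(-1), 0     # mirrors the Reshape-to-1-D branch above
    if reverse:
        x = np.flip(x, axis)           # accumulate from the far end
    out = np.cumsum(x, axis=axis)
    if exclusive:
        out = out - x                  # element i excludes x[i] itself
    if reverse:
        out = np.flip(out, axis)
    return out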

View File

@@ -0,0 +1,37 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs fill_any_like(const NodeContext& node) {
    const auto x = node.get_ng_input("X");
    auto dtype = node.get_attribute<ngraph::element::Type>("dtype", element::undefined);
    const auto value = node.get_attribute<float>("value");
    if (dtype == element::undefined) {
        // when dtype is not set, fall back to the input element type
        dtype = x.get_element_type();
    }

    const auto supported_type = {element::i32, element::i64, element::f16, element::f32, element::f64};
    const bool valid_type =
        std::any_of(supported_type.begin(), supported_type.end(), [dtype](const element::Type& type) {
            return dtype == type;
        });
    PDPD_ASSERT(valid_type, "fill_any_like only supports i32, i64, f16, f32, f64");

    const auto value_node = default_opset::Constant::create(dtype, {1}, {value});
    const auto shape_node = std::make_shared<default_opset::ShapeOf>(x);
    return node.default_single_output_mapping({std::make_shared<default_opset::Broadcast>(value_node, shape_node)},
                                              {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
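
In NumPy terms the conversion amounts to broadcasting a constant to ShapeOf(x); a minimal sketch (fill_any_like_ref is a hypothetical name, and the dtype fallback mirrors the element::undefined branch above):

import numpy as np

def fill_any_like_ref(x, value, dtype=None):
    # dtype falls back to the input's element type when not set,
    # matching the element::undefined branch in the converter
    return np.full(x.shape, value, dtype=dtype or x.dtype)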

View File

@@ -0,0 +1,23 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs gelu(const NodeContext& node) {
    const auto data = node.get_ng_input("X");
    const auto approximate = node.get_attribute<bool>("approximate", false);
    const auto mode = approximate ? ngraph::op::GeluApproximationMode::TANH : ngraph::op::GeluApproximationMode::ERF;

    return node.default_single_output_mapping({std::make_shared<default_opset::Gelu>(data, mode)}, {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
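
The two modes correspond to the standard GELU formulas. A NumPy sketch of what the ERF and TANH modes compute (illustrative, not part of the change):

import math
import numpy as np

def gelu_erf(x):
    # exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # tanh approximation, selected when the Paddle attribute approximate=True
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3))))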

View File

@@ -0,0 +1,59 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs layer_norm(const NodeContext& node) {
    using namespace default_opset;
    const auto data = node.get_ng_input("X");
    const auto epsilon = node.get_attribute<float>("epsilon", 1e-05f);
    const auto begin_norm_axis = node.get_attribute<int32_t>("begin_norm_axis", 1);
    // The limitation from:
    // https://github.com/PaddlePaddle/Paddle/blob/cec36ea6ff16fda90c1a004c6e043cd9b2096a2a/paddle/fluid/operators/layer_norm_op.cc#L176
    PDPD_ASSERT(begin_norm_axis > 0, "begin_norm_axis should be greater than 0");

    // shape of the input
    const auto shape_of_node = std::make_shared<ShapeOf>(data);
    // rank of the input, reduced to a scalar
    const auto dims_node = std::make_shared<ReduceMin>(std::make_shared<ShapeOf>(shape_of_node),
                                                       Constant::create(element::i64, {1}, {0}),
                                                       false);
    // axes to normalize over: [begin_norm_axis, rank)
    const auto axis = std::make_shared<Range>(Constant::create(element::i64, {}, {begin_norm_axis}),
                                              dims_node,
                                              Constant::create(element::i64, {}, {1}),
                                              element::i64);
    // 'Scale' and 'Bias' arrive flattened, so recover their real shape: shape_of_node[begin_norm_axis:]
    const auto scale_bias_shape = std::make_shared<StridedSlice>(shape_of_node,
                                                                 Constant::create(element::i64, {1}, {begin_norm_axis}),
                                                                 Constant::create(element::i64, {1}, {0}),
                                                                 std::vector<int64_t>{0},
                                                                 std::vector<int64_t>{1});
    const auto mvn = std::make_shared<MVN>(data, axis, true, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
    std::shared_ptr<ngraph::Node> result = mvn;
    if (node.has_ng_input("Scale")) {
        const auto s = node.get_ng_input("Scale");
        const auto reshaped_s = std::make_shared<Reshape>(s, scale_bias_shape, false);
        result = std::make_shared<Multiply>(mvn, reshaped_s);
    }
    if (node.has_ng_input("Bias")) {
        const auto b = node.get_ng_input("Bias");
        const auto reshaped_b = std::make_shared<Reshape>(b, scale_bias_shape, false);
        result = std::make_shared<Add>(result, reshaped_b);
    }
    return node.default_single_output_mapping({result}, {"Y"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
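
The decomposition above (MVN over the trailing axes, then a reshaped Scale multiply and a Bias add) matches the following NumPy reference — a sketch for checking the logic, with the hypothetical name layer_norm_ref:

import numpy as np

def layer_norm_ref(x, begin_norm_axis=1, scale=None, bias=None, epsilon=1e-5):
    # normalize over the trailing axes [begin_norm_axis, ndim), like the Range node
    axes = tuple(range(begin_norm_axis, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    out = (x - mean) / np.sqrt(var + epsilon)   # MVN with INSIDE_SQRT epsilon mode
    tail = x.shape[begin_norm_axis:]            # target shape for the flat Scale/Bias
    if scale is not None:
        out = out * scale.reshape(tail)
    if bias is not None:
        out = out + bias.reshape(tail)
    return out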

View File

@@ -0,0 +1,24 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs matmul_v2(const NodeContext& node) {
    const auto x = node.get_ng_input("X");
    const auto y = node.get_ng_input("Y");
    const auto transpose_a = node.get_attribute<bool>("trans_x", false);
    const auto transpose_b = node.get_attribute<bool>("trans_y", false);
    const auto mm = std::make_shared<default_opset::MatMul>(x, y, transpose_a, transpose_b);

    return node.default_single_output_mapping({mm}, {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph
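
matmul_v2 maps one-to-one onto MatMul. For inputs of rank >= 2 the transpose flags behave as in this NumPy sketch (illustrative only; it ignores the 1-D special cases Paddle also accepts, and matmul_v2_ref is a hypothetical name):

import numpy as np

def matmul_v2_ref(x, y, trans_x=False, trans_y=False):
    # swap the last two axes when the corresponding flag is set
    x = np.swapaxes(x, -1, -2) if trans_x else x
    y = np.swapaxes(y, -1, -2) if trans_y else y
    return np.matmul(x, y)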

View File

@@ -0,0 +1,22 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <node_context.hpp>
#include "default_opset.hpp"
namespace ngraph {
namespace frontend {
namespace pdpd {
namespace op {
NamedOutputs tanh(const NodeContext& node) {
    const auto x = node.get_ng_input("X");
    return node.default_single_output_mapping({std::make_shared<default_opset::Tanh>(x)}, {"Out"});
}
} // namespace op
} // namespace pdpd
} // namespace frontend
} // namespace ngraph

View File

@@ -17,6 +17,7 @@ OP_CONVERTER(clip);
OP_CONVERTER(concat);
OP_CONVERTER(conv2d);
OP_CONVERTER(conv2d_transpose);
OP_CONVERTER(cumsum);
OP_CONVERTER(deformable_conv);
OP_CONVERTER(dropout);
OP_CONVERTER(elementwise_add);
@@ -29,15 +30,19 @@ OP_CONVERTER(elementwise_mul);
OP_CONVERTER(elementwise_pow);
OP_CONVERTER(elementwise_sub);
OP_CONVERTER(expand_v2);
OP_CONVERTER(fill_any_like);
OP_CONVERTER(fill_constant_batch_size_like);
OP_CONVERTER(fill_constant);
OP_CONVERTER(flatten_contiguous_range);
OP_CONVERTER(gelu);
OP_CONVERTER(hard_sigmoid);
OP_CONVERTER(hard_swish);
OP_CONVERTER(layer_norm);
OP_CONVERTER(leaky_relu);
OP_CONVERTER(log);
OP_CONVERTER(logical_not);
OP_CONVERTER(matmul);
OP_CONVERTER(matmul_v2);
OP_CONVERTER(mul);
OP_CONVERTER(matrix_nms);
OP_CONVERTER(multiclass_nms);
@@ -57,6 +62,7 @@ OP_CONVERTER(softmax);
OP_CONVERTER(sigmoid);
OP_CONVERTER(split);
OP_CONVERTER(squeeze);
OP_CONVERTER(tanh);
OP_CONVERTER(transpose2);
OP_CONVERTER(unsqueeze);
OP_CONVERTER(yolo_box);
@@ -80,6 +86,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"concat", op::concat},
{"conv2d", op::conv2d},
{"conv2d_transpose", op::conv2d_transpose},
{"cumsum", op::cumsum},
{"deformable_conv", op::deformable_conv},
{"deformable_conv_v1", op::deformable_conv},
{"depthwise_conv2d", op::conv2d},
@@ -94,16 +101,20 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"elementwise_sub", op::elementwise_sub},
{"equal", op::elementwise_equal},
{"expand_v2", op::expand_v2},
{"fill_any_like", op::fill_any_like},
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
{"fill_constant", op::fill_constant},
{"flatten_contiguous_range", op::flatten_contiguous_range},
{"gelu", op::gelu},
{"greater_equal", op::elementwise_greater_equal},
{"hard_sigmoid", op::hard_sigmoid},
{"hard_swish", op::hard_swish},
{"layer_norm", op::layer_norm},
{"leaky_relu", op::leaky_relu},
{"log", op::log},
{"logical_not", op::logical_not},
{"matmul", op::matmul},
{"matmul_v2", op::matmul_v2},
{"max_pool2d_with_index", op::pool2d},
{"mul", op::mul},
{"matrix_nms", op::matrix_nms},
@@ -126,6 +137,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"split", op::split},
{"squeeze2", op::squeeze},
{"sync_batch_norm", op::batch_norm},
{"tanh", op::tanh},
{"transpose2", op::transpose2},
{"unsqueeze2", op::unsqueeze},
{"yolo_box", op::yolo_box}};

View File

@@ -63,6 +63,11 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("conv2d_transpose_strides_padding"),
std::string("conv2d_transpose_VALID_padding"),
std::string("conv2d_VALID_padding"),
std::string("cumsum"),
std::string("cumsum_i32"),
std::string("cumsum_i64"),
std::string("cumsum_f32"),
std::string("cumsum_f64"),
std::string("depthwise_conv2d_convolution"),
std::string("depthwise_conv2d_transpose_convolution"),
std::string("dropout"),
@@ -78,6 +83,12 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("expand_v2"),
std::string("expand_v2_tensor"),
std::string("expand_v2_tensor_list"),
std::string("fill_any_like"),
std::string("fill_any_like_f16"),
std::string("fill_any_like_f32"),
std::string("fill_any_like_f64"),
std::string("fill_any_like_i32"),
std::string("fill_any_like_i64"),
std::string("fill_constant"),
std::string("fill_constant_batch_size_like"),
std::string("fill_constant_int32"),
@@ -86,18 +97,31 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("fill_constant_shape_tensor"),
std::string("fill_constant_shape_tensor_list"),
std::string("flatten_contiguous_range_test1"),
std::string("gelu_erf"),
std::string("gelu_tanh"),
// greater_equal_big_int64 is disabled (fails during CPU inference),
std::string("greater_equal_float32"),
std::string("greater_equal_int32"),
std::string("greater_equal_int64"),
std::string("hard_sigmoid"),
std::string("hard_swish"),
std::string("layer_norm"),
std::string("layer_norm_noall"),
std::string("layer_norm_noscale"),
std::string("layer_norm_noshift"),
std::string("leaky_relu"),
std::string("log"),
std::string("logical_not"),
std::string("matmul_xt"),
std::string("matmul_xt_yt"),
std::string("matmul_yt"),
std::string("matmul_v2_1dx1d"),
std::string("matmul_v2_1dx2d"),
std::string("matmul_v2_2dx1d"),
std::string("matmul_v2_ndxmd"),
std::string("matmul_v2_xt"),
std::string("matmul_v2_xt_yt"),
std::string("matmul_v2_yt"),
std::string("maxAdaptivePool2D_test1"),
std::string("maxPool_test1"),
std::string("maxPool_test10"),
@@ -163,6 +187,7 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("split_test_list_tensor"),
std::string("squeeze"),
std::string("squeeze_null_axes"),
std::string("tanh"),
std::string("unsqueeze"),
std::string("yolo_box_clip_box"),
std::string("yolo_box_default"),

View File

@@ -0,0 +1,43 @@
#
# cumsum paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
data_type = 'float32'
def cumsum(name: str, x, axis, dtype=None):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        data = pdpd.static.data(name='x', shape=x.shape, dtype=data_type)
        out = pdpd.cumsum(data, axis, dtype=dtype)
        out = pdpd.cast(out, np.float32)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():
    x = np.linspace(1, 12, 12, dtype=data_type)
    x = np.reshape(x, (3, 4))

    cumsum("cumsum", x, axis=None)
    cumsum("cumsum_f32", x, axis=-1, dtype='float32')
    cumsum("cumsum_f64", x, axis=0, dtype='float64')
    cumsum("cumsum_i32", x, axis=0, dtype='int32')
    cumsum("cumsum_i64", x, axis=0, dtype='int64')


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,43 @@
#
# fill_any_like paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
data_type = 'float32'
def fill_any_like(name: str, x, value, dtype=None):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        data = pdpd.static.data(name='x', shape=x.shape, dtype=data_type)
        out = pdpd.full_like(data, value, dtype=dtype)
        out = pdpd.cast(out, np.float32)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():
    x = np.random.rand(8, 24, 32).astype(data_type)

    fill_any_like("fill_any_like", x, 1.2)
    fill_any_like("fill_any_like_f16", x, 1.0, dtype='float16')
    fill_any_like("fill_any_like_f32", x, 1.2, dtype='float32')
    fill_any_like("fill_any_like_f64", x, 1.2, dtype='float64')
    fill_any_like("fill_any_like_i32", x, 2, dtype='int32')
    fill_any_like("fill_any_like_i64", x, 10, dtype='int64')


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,38 @@
#
# gelu paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
data_type = 'float32'
def gelu(name: str, x, approximate=False):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        data = pdpd.static.data(name='x', shape=x.shape, dtype=data_type)
        out = pdpd.fluid.layers.gelu(data, approximate=approximate)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():
    x = np.random.rand(8, 24, 32).astype(data_type)

    gelu("gelu_erf", x)
    gelu("gelu_tanh", x, True)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,44 @@
#
# layer_norm paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
data_type = 'float32'
def layer_norm(name: str, x, begin_norm_axis, scale=True, shift=True, param_attr=None, bias_attr=None):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        data = pdpd.static.data(name='x', shape=x.shape, dtype=data_type)
        out = pdpd.static.nn.layer_norm(input=data, scale=scale, shift=shift,
                                        begin_norm_axis=begin_norm_axis, param_attr=param_attr, bias_attr=bias_attr)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():
    x = np.random.rand(8, 24, 32).astype(data_type)
    random_data = np.random.rand(24 * 32).astype(data_type)
    attr = pdpd.ParamAttr(
        initializer=pdpd.fluid.initializer.NumpyArrayInitializer(random_data))

    layer_norm("layer_norm", x, begin_norm_axis=1, param_attr=attr, bias_attr=attr)
    layer_norm("layer_norm_noscale", x, scale=False, begin_norm_axis=2)
    layer_norm("layer_norm_noshift", x, shift=False, begin_norm_axis=1)
    layer_norm("layer_norm_noall", x, scale=False, shift=False, begin_norm_axis=1)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,59 @@
import numpy as np
from save_model import saveModel
import sys

import paddle as pdpd


def matmul(name, x1, x2, x_transpose=False, y_transpose=False):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        node_x1 = pdpd.static.data(name='x1', shape=x1.shape, dtype=x1.dtype)
        node_x2 = pdpd.static.data(name='x2', shape=x2.shape, dtype=x2.dtype)
        result = pdpd.matmul(node_x1, node_x2, x_transpose, y_transpose)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x1': x1, 'x2': x2},
            fetch_list=[result])

        saveModel(name, exe, feedkeys=['x1', 'x2'], fetchlist=[result], inputs=[x1, x2], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


if __name__ == "__main__":
    input_2x5 = np.array([[1, 2, 3, 4, 5],
                          [6, 7, 8, 9, 10]]).astype(np.float32)
    input_5x3 = np.array([[1, 2, 3],
                          [4, 5, 6],
                          [7, 8, 9],
                          [10, 11, 12],
                          [13, 14, 15]]).astype(np.float32)
    input_5x2 = np.array([[1, 2],
                          [4, 5],
                          [7, 8],
                          [10, 11],
                          [13, 14]]).astype(np.float32)
    input_2x3 = np.array([[1, 2, 3],
                          [4, 5, 6]]).astype(np.float32)
    input_1d = np.array([2, 3]).astype(np.float32)
    input_nd = np.random.rand(2, 1, 10, 3).astype(np.float32)
    input_md = np.random.rand(3, 3, 4).astype(np.float32)

    matmul("matmul_v2_1dx1d", input_1d, input_1d)
    matmul("matmul_v2_1dx2d", input_1d, input_2x3)
    matmul("matmul_v2_2dx1d", input_5x2, input_1d)
    matmul("matmul_v2_ndxmd", input_nd, input_md)
    matmul("matmul_v2_xt", input_2x5, input_2x3, x_transpose=True, y_transpose=False)
    matmul("matmul_v2_yt", input_2x3, input_5x3, x_transpose=False, y_transpose=True)
    matmul("matmul_v2_xt_yt", input_2x5, input_5x2, x_transpose=True, y_transpose=True)

View File

@@ -0,0 +1,37 @@
#
# tanh paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
data_type = 'float32'
def tanh(name: str, x):
    pdpd.enable_static()

    with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
        data = pdpd.static.data(name='x', shape=x.shape, dtype=data_type)
        out = pdpd.tanh(data)
        cpu = pdpd.static.cpu_places(1)
        exe = pdpd.static.Executor(cpu[0])
        # startup program will call initializer to initialize the parameters.
        exe.run(pdpd.static.default_startup_program())

        outs = exe.run(
            feed={'x': x},
            fetch_list=[out])

        saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():
    x = np.random.rand(8, 24, 32).astype(data_type)
    tanh("tanh", x)


if __name__ == "__main__":
    main()