[FrontEnd]enable pdpd ops layer_norm, gelu, fill_any_like, cumsum, tanh, matmul_v2 for BERT (#7511)
* enable layer_norm, gelu, fill_any_like, cumsum, tanh, matmul_v2 * change default namespace * apply review comments
This commit is contained in:
parent
5ad2400468
commit
f4fa513325
16
ngraph/frontend/paddlepaddle/src/default_opset.hpp
Normal file
16
ngraph/frontend/paddlepaddle/src/default_opset.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
namespace default_opset = ngraph::opset8;
|
||||
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
37
ngraph/frontend/paddlepaddle/src/op/cumsum.cpp
Normal file
37
ngraph/frontend/paddlepaddle/src/op/cumsum.cpp
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs cumsum(const NodeContext& node) {
|
||||
const auto x = node.get_ng_input("X");
|
||||
const auto axis = node.get_attribute<int32_t>("axis", -1);
|
||||
const auto flatten = node.get_attribute<bool>("flatten", false);
|
||||
const auto reverse = node.get_attribute<bool>("reverse", false);
|
||||
const auto exclusive = node.get_attribute<bool>("exclusive", false);
|
||||
|
||||
std::shared_ptr<ngraph::Node> input = x.get_node_shared_ptr();
|
||||
if (flatten) {
|
||||
// convert to 1-d tensor
|
||||
input = std::make_shared<default_opset::Reshape>(x,
|
||||
default_opset::Constant::create(element::i64, {1}, {-1}),
|
||||
false);
|
||||
}
|
||||
|
||||
const auto axis_node = default_opset::Constant::create(element::i64, {}, {axis});
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<default_opset::CumSum>(input, axis_node, exclusive, reverse)},
|
||||
{"Out"});
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
37
ngraph/frontend/paddlepaddle/src/op/fill_any_like.cpp
Normal file
37
ngraph/frontend/paddlepaddle/src/op/fill_any_like.cpp
Normal file
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs fill_any_like(const NodeContext& node) {
|
||||
const auto x = node.get_ng_input("X");
|
||||
auto dtype = node.get_attribute<ngraph::element::Type>("dtype", element::undefined);
|
||||
const auto value = node.get_attribute<float>("value");
|
||||
if (dtype == element::undefined) {
|
||||
// when type does not define, use the input type
|
||||
dtype = x.get_element_type();
|
||||
}
|
||||
const auto supported_type = {element::i32, element::i64, element::f16, element::f32, element::f64};
|
||||
const bool valid_type =
|
||||
std::any_of(supported_type.begin(), supported_type.end(), [dtype](const element::Type& type) {
|
||||
return dtype == type;
|
||||
});
|
||||
PDPD_ASSERT(valid_type, "fill_any_like only supports i32, i64, f16, f32, f64");
|
||||
const auto value_node = default_opset::Constant::create(dtype, {1}, {value});
|
||||
const auto shape_node = std::make_shared<default_opset::ShapeOf>(x);
|
||||
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::Broadcast>(value_node, shape_node)},
|
||||
{"Out"});
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
23
ngraph/frontend/paddlepaddle/src/op/gelu.cpp
Normal file
23
ngraph/frontend/paddlepaddle/src/op/gelu.cpp
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs gelu(const NodeContext& node) {
|
||||
const auto data = node.get_ng_input("X");
|
||||
const auto approximate = node.get_attribute<bool>("approximate", false);
|
||||
const auto mode = approximate ? ngraph::op::GeluApproximationMode::TANH : ngraph::op::GeluApproximationMode::ERF;
|
||||
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::Gelu>(data, mode)}, {"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
59
ngraph/frontend/paddlepaddle/src/op/layer_norm.cpp
Normal file
59
ngraph/frontend/paddlepaddle/src/op/layer_norm.cpp
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs layer_norm(const NodeContext& node) {
|
||||
using namespace default_opset;
|
||||
const auto data = node.get_ng_input("X");
|
||||
const auto epsilon = node.get_attribute<float>("epsilon", 1e-05);
|
||||
const auto begin_norm_axis = node.get_attribute<int32_t>("begin_norm_axis", 1);
|
||||
// The limitation from:
|
||||
// https://github.com/PaddlePaddle/Paddle/blob/cec36ea6ff16fda90c1a004c6e043cd9b2096a2a/paddle/fluid/operators/layer_norm_op.cc#L176
|
||||
PDPD_ASSERT(begin_norm_axis > 0, "begin_norm_axis should be greater than 0");
|
||||
|
||||
// shape of input
|
||||
const auto shape_of_node = std::make_shared<ShapeOf>(data);
|
||||
// dims of input, reduce to scalar
|
||||
const auto dims_node = std::make_shared<ReduceMin>(std::make_shared<ShapeOf>(shape_of_node),
|
||||
Constant::create(element::i64, {1}, {0}),
|
||||
false);
|
||||
// get axis list to do the computation: [begin_norm_axis: dims)
|
||||
const auto axis = std::make_shared<Range>(Constant::create(element::i64, {}, {begin_norm_axis}),
|
||||
dims_node,
|
||||
Constant::create(element::i64, {}, {1}),
|
||||
element::i64);
|
||||
// 'Scale' and 'Bias' are in plain, shoule get the real shape. The shape: shape_of_node[begin_norm_axis:-1]
|
||||
const auto scale_bias_shape = std::make_shared<StridedSlice>(shape_of_node,
|
||||
Constant::create(element::i64, {1}, {begin_norm_axis}),
|
||||
Constant::create(element::i64, {1}, {0}),
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{1});
|
||||
|
||||
const auto mvn = std::make_shared<MVN>(data, axis, true, epsilon, ngraph::op::MVNEpsMode::INSIDE_SQRT);
|
||||
std::shared_ptr<ngraph::Node> result = mvn;
|
||||
if (node.has_ng_input("Scale")) {
|
||||
const auto s = node.get_ng_input("Scale");
|
||||
const auto reshaped_s = std::make_shared<Reshape>(s, scale_bias_shape, false);
|
||||
result = std::make_shared<Multiply>(mvn, reshaped_s);
|
||||
}
|
||||
|
||||
if (node.has_ng_input("Bias")) {
|
||||
const auto b = node.get_ng_input("Bias");
|
||||
const auto reshaped_b = std::make_shared<Reshape>(b, scale_bias_shape, false);
|
||||
result = std::make_shared<Add>(result, reshaped_b);
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({result}, {"Y"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
24
ngraph/frontend/paddlepaddle/src/op/matmul_v2.cpp
Normal file
24
ngraph/frontend/paddlepaddle/src/op/matmul_v2.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs matmul_v2(const NodeContext& node) {
|
||||
const auto x = node.get_ng_input("X");
|
||||
const auto y = node.get_ng_input("Y");
|
||||
const auto transpose_a = node.get_attribute<bool>("trans_x", false);
|
||||
const auto transpose_b = node.get_attribute<bool>("trans_y", false);
|
||||
const auto mm = std::make_shared<default_opset::MatMul>(x, y, transpose_a, transpose_b);
|
||||
return node.default_single_output_mapping({mm}, {"Out"});
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
22
ngraph/frontend/paddlepaddle/src/op/tanh.cpp
Normal file
22
ngraph/frontend/paddlepaddle/src/op/tanh.cpp
Normal file
@ -0,0 +1,22 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
NamedOutputs tanh(const NodeContext& node) {
|
||||
const auto x = node.get_ng_input("X");
|
||||
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::Tanh>(x)}, {"Out"});
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
} // namespace frontend
|
||||
} // namespace ngraph
|
@ -17,6 +17,7 @@ OP_CONVERTER(clip);
|
||||
OP_CONVERTER(concat);
|
||||
OP_CONVERTER(conv2d);
|
||||
OP_CONVERTER(conv2d_transpose);
|
||||
OP_CONVERTER(cumsum);
|
||||
OP_CONVERTER(deformable_conv);
|
||||
OP_CONVERTER(dropout);
|
||||
OP_CONVERTER(elementwise_add);
|
||||
@ -29,15 +30,19 @@ OP_CONVERTER(elementwise_mul);
|
||||
OP_CONVERTER(elementwise_pow);
|
||||
OP_CONVERTER(elementwise_sub);
|
||||
OP_CONVERTER(expand_v2);
|
||||
OP_CONVERTER(fill_any_like);
|
||||
OP_CONVERTER(fill_constant_batch_size_like);
|
||||
OP_CONVERTER(fill_constant);
|
||||
OP_CONVERTER(flatten_contiguous_range);
|
||||
OP_CONVERTER(gelu);
|
||||
OP_CONVERTER(hard_sigmoid);
|
||||
OP_CONVERTER(hard_swish);
|
||||
OP_CONVERTER(layer_norm);
|
||||
OP_CONVERTER(leaky_relu);
|
||||
OP_CONVERTER(log);
|
||||
OP_CONVERTER(logical_not);
|
||||
OP_CONVERTER(matmul);
|
||||
OP_CONVERTER(matmul_v2);
|
||||
OP_CONVERTER(mul);
|
||||
OP_CONVERTER(matrix_nms);
|
||||
OP_CONVERTER(multiclass_nms);
|
||||
@ -57,6 +62,7 @@ OP_CONVERTER(softmax);
|
||||
OP_CONVERTER(sigmoid);
|
||||
OP_CONVERTER(split);
|
||||
OP_CONVERTER(squeeze);
|
||||
OP_CONVERTER(tanh);
|
||||
OP_CONVERTER(transpose2);
|
||||
OP_CONVERTER(unsqueeze);
|
||||
OP_CONVERTER(yolo_box);
|
||||
@ -80,6 +86,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"concat", op::concat},
|
||||
{"conv2d", op::conv2d},
|
||||
{"conv2d_transpose", op::conv2d_transpose},
|
||||
{"cumsum", op::cumsum},
|
||||
{"deformable_conv", op::deformable_conv},
|
||||
{"deformable_conv_v1", op::deformable_conv},
|
||||
{"depthwise_conv2d", op::conv2d},
|
||||
@ -94,16 +101,20 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"elementwise_sub", op::elementwise_sub},
|
||||
{"equal", op::elementwise_equal},
|
||||
{"expand_v2", op::expand_v2},
|
||||
{"fill_any_like", op::fill_any_like},
|
||||
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
|
||||
{"fill_constant", op::fill_constant},
|
||||
{"flatten_contiguous_range", op::flatten_contiguous_range},
|
||||
{"gelu", op::gelu},
|
||||
{"greater_equal", op::elementwise_greater_equal},
|
||||
{"hard_sigmoid", op::hard_sigmoid},
|
||||
{"hard_swish", op::hard_swish},
|
||||
{"layer_norm", op::layer_norm},
|
||||
{"leaky_relu", op::leaky_relu},
|
||||
{"log", op::log},
|
||||
{"logical_not", op::logical_not},
|
||||
{"matmul", op::matmul},
|
||||
{"matmul_v2", op::matmul_v2},
|
||||
{"max_pool2d_with_index", op::pool2d},
|
||||
{"mul", op::mul},
|
||||
{"matrix_nms", op::matrix_nms},
|
||||
@ -126,6 +137,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"split", op::split},
|
||||
{"squeeze2", op::squeeze},
|
||||
{"sync_batch_norm", op::batch_norm},
|
||||
{"tanh", op::tanh},
|
||||
{"transpose2", op::transpose2},
|
||||
{"unsqueeze2", op::unsqueeze},
|
||||
{"yolo_box", op::yolo_box}};
|
||||
|
@ -63,6 +63,11 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("conv2d_transpose_strides_padding"),
|
||||
std::string("conv2d_transpose_VALID_padding"),
|
||||
std::string("conv2d_VALID_padding"),
|
||||
std::string("cumsum"),
|
||||
std::string("cumsum_i32"),
|
||||
std::string("cumsum_i64"),
|
||||
std::string("cumsum_f32"),
|
||||
std::string("cumsum_f64"),
|
||||
std::string("depthwise_conv2d_convolution"),
|
||||
std::string("depthwise_conv2d_transpose_convolution"),
|
||||
std::string("dropout"),
|
||||
@ -78,6 +83,12 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("expand_v2"),
|
||||
std::string("expand_v2_tensor"),
|
||||
std::string("expand_v2_tensor_list"),
|
||||
std::string("fill_any_like"),
|
||||
std::string("fill_any_like_f16"),
|
||||
std::string("fill_any_like_f32"),
|
||||
std::string("fill_any_like_f64"),
|
||||
std::string("fill_any_like_i32"),
|
||||
std::string("fill_any_like_i64"),
|
||||
std::string("fill_constant"),
|
||||
std::string("fill_constant_batch_size_like"),
|
||||
std::string("fill_constant_int32"),
|
||||
@ -86,18 +97,31 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("fill_constant_shape_tensor"),
|
||||
std::string("fill_constant_shape_tensor_list"),
|
||||
std::string("flatten_contiguous_range_test1"),
|
||||
std::string("gelu_erf"),
|
||||
std::string("gelu_tanh"),
|
||||
// greater_equal_big_int64(failure due to CPU inference),
|
||||
std::string("greater_equal_float32"),
|
||||
std::string("greater_equal_int32"),
|
||||
std::string("greater_equal_int64"),
|
||||
std::string("hard_sigmoid"),
|
||||
std::string("hard_swish"),
|
||||
std::string("layer_norm"),
|
||||
std::string("layer_norm_noall"),
|
||||
std::string("layer_norm_noscale"),
|
||||
std::string("layer_norm_noshift"),
|
||||
std::string("leaky_relu"),
|
||||
std::string("log"),
|
||||
std::string("logical_not"),
|
||||
std::string("matmul_xt"),
|
||||
std::string("matmul_xt_yt"),
|
||||
std::string("matmul_yt"),
|
||||
std::string("matmul_v2_1dx1d"),
|
||||
std::string("matmul_v2_1dx2d"),
|
||||
std::string("matmul_v2_2dx1d"),
|
||||
std::string("matmul_v2_ndxmd"),
|
||||
std::string("matmul_v2_xt"),
|
||||
std::string("matmul_v2_xt_yt"),
|
||||
std::string("matmul_v2_yt"),
|
||||
std::string("maxAdaptivePool2D_test1"),
|
||||
std::string("maxPool_test1"),
|
||||
std::string("maxPool_test10"),
|
||||
@ -163,6 +187,7 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("split_test_list_tensor"),
|
||||
std::string("squeeze"),
|
||||
std::string("squeeze_null_axes"),
|
||||
std::string("tanh"),
|
||||
std::string("unsqueeze"),
|
||||
std::string("yolo_box_clip_box"),
|
||||
std::string("yolo_box_default"),
|
||||
|
@ -0,0 +1,43 @@
|
||||
#
|
||||
# cumsum paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def cumsum(name:str, x, axis, dtype=None):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
data = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.cumsum(data, axis, dtype=dtype)
|
||||
out = pdpd.cast(out, np.float32)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
def main():
|
||||
x = np.linspace(1, 12, 12, dtype=data_type)
|
||||
x = np.reshape(x, (3, 4))
|
||||
|
||||
cumsum("cumsum", x, axis=None)
|
||||
cumsum("cumsum_f32", x, axis=-1, dtype='float32')
|
||||
cumsum("cumsum_f64", x, axis=0, dtype='float64')
|
||||
cumsum("cumsum_i32", x, axis=0, dtype='int32')
|
||||
cumsum("cumsum_i64", x, axis=0, dtype='int64')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,43 @@
|
||||
#
|
||||
# fill_any_like paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def fill_any_like(name:str, x, value, dtype=None):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
data = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.full_like(data, value, dtype=dtype)
|
||||
out = pdpd.cast(out, np.float32)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
def main():
|
||||
x = np.random.rand(8, 24, 32).astype(data_type)
|
||||
|
||||
fill_any_like("fill_any_like", x, 1.2)
|
||||
fill_any_like("fill_any_like_f16", x, 1.0, dtype='float16')
|
||||
fill_any_like("fill_any_like_f32", x, 1.2, dtype='float32')
|
||||
fill_any_like("fill_any_like_f64", x, 1.2, dtype='float64')
|
||||
fill_any_like("fill_any_like_i32", x, 2, dtype='int32')
|
||||
fill_any_like("fill_any_like_i64", x, 10, dtype='int64')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,38 @@
|
||||
#
|
||||
# gelu paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def gelu(name:str, x, approximate=False):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
data = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.fluid.layers.gelu(data, approximate=approximate)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
def main():
|
||||
x = np.random.rand(8, 24, 32).astype(data_type)
|
||||
|
||||
gelu("gelu_erf", x)
|
||||
gelu("gelu_tanh", x, True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,44 @@
|
||||
#
|
||||
# layer_norm paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from paddle.fluid import param_attr
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def layer_norm(name:str, x, begin_norm_axis, scale=True, shift=True, param_attr=None, bias_attr=None):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
data = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.static.nn.layer_norm(input=data, scale=scale, shift=shift,\
|
||||
begin_norm_axis=begin_norm_axis, param_attr=param_attr, bias_attr=bias_attr)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
def main():
|
||||
x = np.random.rand(8, 24, 32).astype(data_type)
|
||||
random_data = np.random.rand(24 * 32).astype(data_type)
|
||||
attr = pdpd.ParamAttr(
|
||||
initializer=pdpd.fluid.initializer.NumpyArrayInitializer(random_data))
|
||||
layer_norm("layer_norm", x, begin_norm_axis=1, param_attr=attr, bias_attr=attr)
|
||||
layer_norm("layer_norm_noscale", x, scale=False, begin_norm_axis=2)
|
||||
layer_norm("layer_norm_noshift", x, shift=False, begin_norm_axis=1)
|
||||
layer_norm("layer_norm_noall", x, scale=False, shift=False, begin_norm_axis=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import sys
|
||||
|
||||
def matmul(name, x1, x2, x_transpose=False, y_transpose=False):
|
||||
import paddle as pdpd
|
||||
|
||||
pdpd.enable_static()
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
node_x1 = pdpd.static.data(name='x1', shape=x1.shape, dtype=x1.dtype)
|
||||
node_x2 = pdpd.static.data(name='x2', shape=x2.shape, dtype=x2.dtype)
|
||||
result = pdpd.matmul(node_x1, node_x2, x_transpose, y_transpose)
|
||||
#result = pdpd.static.nn.batch_norm(mul_node, use_global_stats=True)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x1': x1, 'x2': x2},
|
||||
fetch_list=[result])
|
||||
saveModel(name, exe, feedkeys=['x1', 'x2'], fetchlist=[result], inputs=[x1, x2], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_2x5 = np.array([[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10]]).astype(np.float32)
|
||||
|
||||
input_5x3 = np.array([[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9],
|
||||
[10, 11, 12],
|
||||
[13, 14, 15]]).astype(np.float32)
|
||||
|
||||
input_5x2 = np.array([[1, 2],
|
||||
[4, 5],
|
||||
[7, 8],
|
||||
[10, 11],
|
||||
[13, 14]]).astype(np.float32)
|
||||
|
||||
input_2x3 = np.array([[1, 2, 3],
|
||||
[4, 5, 6]]).astype(np.float32)
|
||||
|
||||
input_1d = np.array([2, 3]).astype(np.float32)
|
||||
|
||||
input_nd = np.random.rand(2, 1, 10, 3).astype(np.float32)
|
||||
input_md = np.random.rand(3, 3, 4).astype(np.float32)
|
||||
|
||||
matmul("matmul_v2_1dx1d", input_1d, input_1d)
|
||||
matmul("matmul_v2_1dx2d", input_1d, input_2x3)
|
||||
matmul("matmul_v2_2dx1d", input_5x2, input_1d)
|
||||
matmul("matmul_v2_ndxmd", input_nd, input_md)
|
||||
|
||||
matmul("matmul_v2_xt", input_2x5, input_2x3, x_transpose=True, y_transpose=False)
|
||||
matmul("matmul_v2_yt", input_2x3, input_5x3, x_transpose=False, y_transpose=True)
|
||||
matmul("matmul_v2_xt_yt", input_2x5, input_5x2, x_transpose=True, y_transpose=True)
|
@ -0,0 +1,37 @@
|
||||
#
|
||||
# tanh paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
data_type = 'float32'
|
||||
|
||||
def tanh(name:str, x):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
data = pdpd.static.data(name='x', shape=x.shape, dtype = data_type)
|
||||
out = pdpd.tanh(data)
|
||||
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
def main():
|
||||
x = np.random.rand(8, 24, 32).astype(data_type)
|
||||
|
||||
tanh("tanh", x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user