Paddle FasterRCNN Ops Conversion: greater_than, less_than, gather, floor (#9657)

* Paddle FasterRCNN Ops Conversion: greater_than, less_than, gather, floor

* Apply suggestions from code review

* fix 'gather' testcase failure issue on CI

* implement 'axis' input for 'Gather' Op conversion with testcase comment;use common function for all elementwise Ops
This commit is contained in:
Bo Liu 2022-03-29 16:20:37 +08:00 committed by GitHub
parent 8f88889876
commit 02c60c76ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 408 additions and 30 deletions

View File

@ -133,12 +133,22 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("fill_constant_shape_tensor"), std::string("fill_constant_shape_tensor"),
std::string("fill_constant_shape_tensor_list"), std::string("fill_constant_shape_tensor_list"),
std::string("flatten_contiguous_range_test1"), std::string("flatten_contiguous_range_test1"),
std::string("floor_float32"),
std::string("gather_multi_dimension"),
std::string("gather_one_dimension"),
std::string("gather_one_dimension2"),
// gather_axis_input
// (CVS-82724: not support Axis as input),
std::string("gelu_erf"), std::string("gelu_erf"),
std::string("gelu_tanh"), std::string("gelu_tanh"),
// greater_equal_big_int64(failure due to CPU inference), // greater_equal_big_int64(failure due to CPU inference),
std::string("greater_equal_big_int64"),
std::string("greater_equal_float32"), std::string("greater_equal_float32"),
std::string("greater_equal_int32"), std::string("greater_equal_int32"),
std::string("greater_equal_int64"), std::string("greater_equal_int64"),
std::string("greater_than_float32"),
std::string("greater_than_int32"),
std::string("greater_than_int64"),
std::string("hard_sigmoid"), std::string("hard_sigmoid"),
std::string("hard_swish"), std::string("hard_swish"),
std::string("layer_norm"), std::string("layer_norm"),
@ -146,6 +156,9 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("layer_norm_noscale"), std::string("layer_norm_noscale"),
std::string("layer_norm_noshift"), std::string("layer_norm_noshift"),
std::string("leaky_relu"), std::string("leaky_relu"),
std::string("less_than_float32"),
std::string("less_than_int32"),
std::string("less_than_int64"),
std::string("linear_downsample_false_0"), std::string("linear_downsample_false_0"),
std::string("linear_downsample_false_1"), std::string("linear_downsample_false_1"),
std::string("linear_downsample_true_0"), std::string("linear_downsample_true_0"),

View File

@ -0,0 +1,40 @@
#
# floor paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle
import sys
def floor(name: str, x):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
data = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
out = paddle.floor(data)
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': x},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[
x], outputs=[outs[0]], target_dir=sys.argv[1])
return outs[0]
def main():
data_type = 'float32'
x = np.array([-0.4, -0.2, 2.1, 0.3]).astype(data_type)
floor("floor_float32", x)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,100 @@
#
# gather paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle
import sys
def gather(name: str, x, y, z):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
data = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
index = paddle.static.data(name='index', shape=y.shape, dtype=y.dtype)
if (z == None):
out = paddle.gather(data, index)
else:
axis = paddle.static.data(
name='axis', shape=z.shape, dtype=z.dtype)
out = paddle.gather(data, index, axis)
if x.dtype == "int64":
out = paddle.cast(out, "float32")
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
if (z == None):
outs = exe.run(
feed={'x': x, 'index': y},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x', 'index'], fetchlist=[out], inputs=[
x, y], outputs=[outs[0]], target_dir=sys.argv[1])
else:
outs = exe.run(
feed={'x': x, 'index': y, 'axis': z},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x', 'index', 'axis'], fetchlist=[
out], inputs=[x, y, z], outputs=[outs[0]], target_dir=sys.argv[1])
return outs[0]
def main():
# For multi-dimension input
x_shape = (10, 20)
x_type = "float32"
index = [1, 3, 5]
index_type = "int32"
xnp = np.random.random(x_shape).astype(x_type)
index_np = np.array(index).astype(index_type)
axis_np = None
gather("gather_multi_dimension", xnp, index_np, axis_np)
# For one_dimension input
x_shape = (100)
x_type = "int64"
index = [1, 3, 5]
index_type = "int64"
xnp = np.random.random(x_shape).astype(x_type)
index_np = np.array(index).astype(index_type)
axis_np = None
gather("gather_one_dimension", xnp, index_np, axis_np)
# For one_dimension input2
x_shape = (100)
x_type = "int64"
index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
index_type = "int64"
xnp = np.random.random(x_shape).astype(x_type)
index_np = np.array(index).astype(index_type)
axis_np = None
gather("gather_one_dimension2", xnp, index_np, axis_np)
# For axis as input
x_shape = (6, 88, 3)
x_type = "float32"
index = [1, 3, 5]
index_type = "int32"
axis = [0]
axis_type = "int32"
xnp = np.random.random(x_shape).astype(x_type)
axis_np = np.array(axis).astype(axis_type)
index_np = np.array(index).astype(index_type)
gather("gather_axis_input", xnp, index_np, axis_np)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,59 @@
#
# greater_than paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
def greater_than(name: str, x, y, data_type, cast_to_fp32=False):
pdpd.enable_static()
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
node_x = pdpd.static.data(
name='input_x', shape=x.shape, dtype=data_type)
node_y = pdpd.static.data(
name='input_y', shape=y.shape, dtype=data_type)
out = pdpd.fluid.layers.greater_than(
x=node_x, y=node_y, name='greater_than')
# FuzzyTest framework doesn't support boolean so cast to fp32/int32
if cast_to_fp32:
data_type = "float32"
out = pdpd.cast(out, data_type)
cpu = pdpd.static.cpu_places(1)
exe = pdpd.static.Executor(cpu[0])
# startup program will call initializer to initialize the parameters.
exe.run(pdpd.static.default_startup_program())
outs = exe.run(
feed={'input_x': x, 'input_y': y},
fetch_list=[out])
saveModel(name, exe, feedkeys=['input_x', 'input_y'], fetchlist=[out],
inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])
return outs[0]
def main():
test_cases = [
"float32",
"int32",
"int64"
]
for test in test_cases:
x = np.array([0, 1, 2, 3]).astype(test)
y = np.array([1, 0, 2, 4]).astype(test)
if ((test == "float64") or (test == "int64")):
greater_than("greater_than_" + test, x, y, test, True)
else:
greater_than("greater_than_" + test, x, y, test, False)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,58 @@
#
# less_than paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle as pdpd
import sys
def less_than(name: str, x, y, data_type, cast_to_fp32=False):
pdpd.enable_static()
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
node_x = pdpd.static.data(
name='input_x', shape=x.shape, dtype=data_type)
node_y = pdpd.static.data(
name='input_y', shape=y.shape, dtype=data_type)
out = pdpd.fluid.layers.less_than(x=node_x, y=node_y, name='less_than')
# FuzzyTest framework doesn't support boolean so cast to fp32/int32
if cast_to_fp32:
data_type = "float32"
out = pdpd.cast(out, data_type)
cpu = pdpd.static.cpu_places(1)
exe = pdpd.static.Executor(cpu[0])
# startup program will call initializer to initialize the parameters.
exe.run(pdpd.static.default_startup_program())
outs = exe.run(
feed={'input_x': x, 'input_y': y},
fetch_list=[out])
saveModel(name, exe, feedkeys=['input_x', 'input_y'], fetchlist=[out],
inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])
return outs[0]
def main():
test_cases = [
"float32",
"int32",
"int64"
]
for test in test_cases:
x = np.array([0, 1, 2, 3]).astype(test)
y = np.array([1, 0, 2, 4]).astype(test)
if ((test == "float64") or (test == "int64")):
less_than("less_than_" + test, x, y, test, True)
else:
less_than("less_than_" + test, x, y, test, False)
if __name__ == "__main__":
main()

View File

@ -2,41 +2,12 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include <map> #include "elementwise_ops.hpp"
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov { namespace ov {
namespace frontend { namespace frontend {
namespace paddle { namespace paddle {
namespace op { namespace op {
template <typename T>
NamedOutputs elementwise_ops(const NodeContext& node) {
auto x = node.get_input("X");
auto y = node.get_input("Y");
auto axis = node.get_attribute<int>("axis");
PADDLE_OP_CHECK(node, x.get_partial_shape().rank().is_static(), "elementwise_ops: X rank must be static!");
PADDLE_OP_CHECK(node, y.get_partial_shape().rank().is_static(), "elementwise_ops: Y rank must be static!");
int64_t x_rank = x.get_partial_shape().rank().get_length();
int64_t y_rank = y.get_partial_shape().rank().get_length();
if ((axis == -1) || (axis == x_rank - 1) || (x_rank == y_rank)) {
return node.default_single_output_mapping({std::make_shared<T>(x, y)}, {"Out"});
} else {
std::vector<int64_t> indices;
for (int64_t i = 0; i < axis; i++)
indices.push_back(i);
for (int64_t i = y_rank + axis; i < x_rank; i++)
indices.push_back(i);
auto indices_node = default_opset::Constant::create(ov::element::i64, ov::Shape{indices.size()}, indices);
auto y_node = std::make_shared<default_opset::Unsqueeze>(y, indices_node);
return node.default_single_output_mapping({std::make_shared<T>(x, y_node)}, {"Out"});
}
}
// //
NamedOutputs elementwise_add(const NodeContext& node_context) { NamedOutputs elementwise_add(const NodeContext& node_context) {

View File

@ -0,0 +1,45 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
template <typename T>
NamedOutputs elementwise_ops(const NodeContext& node) {
auto x = node.get_input("X");
auto y = node.get_input("Y");
auto axis = node.get_attribute<int>("axis");
PADDLE_OP_CHECK(node, x.get_partial_shape().rank().is_static(), "elementwise_ops: X rank must be static!");
PADDLE_OP_CHECK(node, y.get_partial_shape().rank().is_static(), "elementwise_ops: Y rank must be static!");
int64_t x_rank = x.get_partial_shape().rank().get_length();
int64_t y_rank = y.get_partial_shape().rank().get_length();
if ((axis == -1) || (axis == x_rank - 1) || (x_rank == y_rank)) {
return node.default_single_output_mapping({std::make_shared<T>(x, y)}, {"Out"});
} else {
std::vector<int64_t> indices;
for (int64_t i = 0; i < axis; i++)
indices.push_back(i);
for (int64_t i = y_rank + axis; i < x_rank; i++)
indices.push_back(i);
auto indices_node = default_opset::Constant::create(ov::element::i64, ov::Shape{indices.size()}, indices);
auto y_node = std::make_shared<default_opset::Unsqueeze>(y, indices_node);
return node.default_single_output_mapping({std::make_shared<T>(x, y_node)}, {"Out"});
}
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,19 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs floor(const NodeContext& node) {
const auto data_node = node.get_input("X");
return node.default_single_output_mapping({std::make_shared<default_opset::Floor>(data_node)}, {"Out"});
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,31 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs gather(const NodeContext& node) {
const auto data_node = node.get_input("X");
const auto index_node = node.get_input("Index");
Output<Node> axis_node;
if (node.has_input("Axis")) {
axis_node = node.get_input("Axis");
} else {
const auto axis_value = node.get_attribute<int>("axis", 0);
axis_node = default_opset::Constant::create(element::i32, Shape{}, {axis_value});
}
return node.default_single_output_mapping(
{std::make_shared<default_opset::Gather>(data_node, index_node, axis_node)},
{"Out"});
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,17 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "elementwise_ops.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs greater_than(const NodeContext& node) {
return elementwise_ops<default_opset::Greater>(node);
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,17 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "elementwise_ops.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs less_than(const NodeContext& node) {
return elementwise_ops<default_opset::Less>(node);
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

View File

@ -37,11 +37,15 @@ OP_CONVERTER(fill_any_like);
OP_CONVERTER(fill_constant_batch_size_like); OP_CONVERTER(fill_constant_batch_size_like);
OP_CONVERTER(fill_constant); OP_CONVERTER(fill_constant);
OP_CONVERTER(flatten_contiguous_range); OP_CONVERTER(flatten_contiguous_range);
OP_CONVERTER(floor);
OP_CONVERTER(gather);
OP_CONVERTER(gelu); OP_CONVERTER(gelu);
OP_CONVERTER(greater_than);
OP_CONVERTER(hard_sigmoid); OP_CONVERTER(hard_sigmoid);
OP_CONVERTER(hard_swish); OP_CONVERTER(hard_swish);
OP_CONVERTER(layer_norm); OP_CONVERTER(layer_norm);
OP_CONVERTER(leaky_relu); OP_CONVERTER(leaky_relu);
OP_CONVERTER(less_than);
OP_CONVERTER(linear_interp_v2); OP_CONVERTER(linear_interp_v2);
OP_CONVERTER(log); OP_CONVERTER(log);
OP_CONVERTER(logical_and); OP_CONVERTER(logical_and);
@ -115,12 +119,16 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like}, {"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
{"fill_constant", op::fill_constant}, {"fill_constant", op::fill_constant},
{"flatten_contiguous_range", op::flatten_contiguous_range}, {"flatten_contiguous_range", op::flatten_contiguous_range},
{"floor", op::floor},
{"gather", op::gather},
{"gelu", op::gelu}, {"gelu", op::gelu},
{"greater_equal", op::elementwise_greater_equal}, {"greater_equal", op::elementwise_greater_equal},
{"greater_than", op::greater_than},
{"hard_sigmoid", op::hard_sigmoid}, {"hard_sigmoid", op::hard_sigmoid},
{"hard_swish", op::hard_swish}, {"hard_swish", op::hard_swish},
{"layer_norm", op::layer_norm}, {"layer_norm", op::layer_norm},
{"leaky_relu", op::leaky_relu}, {"leaky_relu", op::leaky_relu},
{"less_than", op::less_than},
{"linear_interp_v2", op::linear_interp_v2}, {"linear_interp_v2", op::linear_interp_v2},
{"log", op::log}, {"log", op::log},
{"logical_and", op::logical_and}, {"logical_and", op::logical_and},