Paddle FasterRCNN Ops Conversion: greater_than, less_than, gather, floor (#9657)
* Paddle FasterRCNN Ops Conversion: greater_than, less_than, gather, floor * Apply suggestions from code review * fix 'gather' testcase failure issue on CI * implement 'axis' input for 'Gather' Op conversion with testcase comment;use common function for all elementwise Ops
This commit is contained in:
parent
8f88889876
commit
02c60c76ab
@ -133,12 +133,22 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("fill_constant_shape_tensor"),
|
||||
std::string("fill_constant_shape_tensor_list"),
|
||||
std::string("flatten_contiguous_range_test1"),
|
||||
std::string("floor_float32"),
|
||||
std::string("gather_multi_dimension"),
|
||||
std::string("gather_one_dimension"),
|
||||
std::string("gather_one_dimension2"),
|
||||
// gather_axis_input
|
||||
// (CVS-82724: not support Axis as input),
|
||||
std::string("gelu_erf"),
|
||||
std::string("gelu_tanh"),
|
||||
// greater_equal_big_int64(failure due to CPU inference),
|
||||
std::string("greater_equal_big_int64"),
|
||||
std::string("greater_equal_float32"),
|
||||
std::string("greater_equal_int32"),
|
||||
std::string("greater_equal_int64"),
|
||||
std::string("greater_than_float32"),
|
||||
std::string("greater_than_int32"),
|
||||
std::string("greater_than_int64"),
|
||||
std::string("hard_sigmoid"),
|
||||
std::string("hard_swish"),
|
||||
std::string("layer_norm"),
|
||||
@ -146,6 +156,9 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("layer_norm_noscale"),
|
||||
std::string("layer_norm_noshift"),
|
||||
std::string("leaky_relu"),
|
||||
std::string("less_than_float32"),
|
||||
std::string("less_than_int32"),
|
||||
std::string("less_than_int64"),
|
||||
std::string("linear_downsample_false_0"),
|
||||
std::string("linear_downsample_false_1"),
|
||||
std::string("linear_downsample_true_0"),
|
||||
|
@ -0,0 +1,40 @@
|
||||
#
|
||||
# floor paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle
|
||||
import sys
|
||||
|
||||
|
||||
def floor(name: str, x):
|
||||
paddle.enable_static()
|
||||
|
||||
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
|
||||
data = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
|
||||
out = paddle.floor(data)
|
||||
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[
|
||||
x], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
data_type = 'float32'
|
||||
x = np.array([-0.4, -0.2, 2.1, 0.3]).astype(data_type)
|
||||
|
||||
floor("floor_float32", x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,100 @@
|
||||
#
|
||||
# gather paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle
|
||||
import sys
|
||||
|
||||
|
||||
def gather(name: str, x, y, z):
|
||||
paddle.enable_static()
|
||||
|
||||
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
|
||||
data = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
|
||||
index = paddle.static.data(name='index', shape=y.shape, dtype=y.dtype)
|
||||
if (z == None):
|
||||
out = paddle.gather(data, index)
|
||||
else:
|
||||
axis = paddle.static.data(
|
||||
name='axis', shape=z.shape, dtype=z.dtype)
|
||||
out = paddle.gather(data, index, axis)
|
||||
if x.dtype == "int64":
|
||||
out = paddle.cast(out, "float32")
|
||||
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
if (z == None):
|
||||
outs = exe.run(
|
||||
feed={'x': x, 'index': y},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x', 'index'], fetchlist=[out], inputs=[
|
||||
x, y], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
else:
|
||||
outs = exe.run(
|
||||
feed={'x': x, 'index': y, 'axis': z},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x', 'index', 'axis'], fetchlist=[
|
||||
out], inputs=[x, y, z], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
# For multi-dimension input
|
||||
x_shape = (10, 20)
|
||||
x_type = "float32"
|
||||
index = [1, 3, 5]
|
||||
index_type = "int32"
|
||||
|
||||
xnp = np.random.random(x_shape).astype(x_type)
|
||||
index_np = np.array(index).astype(index_type)
|
||||
axis_np = None
|
||||
|
||||
gather("gather_multi_dimension", xnp, index_np, axis_np)
|
||||
|
||||
# For one_dimension input
|
||||
x_shape = (100)
|
||||
x_type = "int64"
|
||||
index = [1, 3, 5]
|
||||
index_type = "int64"
|
||||
|
||||
xnp = np.random.random(x_shape).astype(x_type)
|
||||
index_np = np.array(index).astype(index_type)
|
||||
axis_np = None
|
||||
|
||||
gather("gather_one_dimension", xnp, index_np, axis_np)
|
||||
|
||||
# For one_dimension input2
|
||||
x_shape = (100)
|
||||
x_type = "int64"
|
||||
index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||
index_type = "int64"
|
||||
|
||||
xnp = np.random.random(x_shape).astype(x_type)
|
||||
index_np = np.array(index).astype(index_type)
|
||||
axis_np = None
|
||||
|
||||
gather("gather_one_dimension2", xnp, index_np, axis_np)
|
||||
|
||||
# For axis as input
|
||||
x_shape = (6, 88, 3)
|
||||
x_type = "float32"
|
||||
index = [1, 3, 5]
|
||||
index_type = "int32"
|
||||
axis = [0]
|
||||
axis_type = "int32"
|
||||
|
||||
xnp = np.random.random(x_shape).astype(x_type)
|
||||
axis_np = np.array(axis).astype(axis_type)
|
||||
index_np = np.array(index).astype(index_type)
|
||||
|
||||
gather("gather_axis_input", xnp, index_np, axis_np)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,59 @@
|
||||
#
|
||||
# greater_than paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
|
||||
def greater_than(name: str, x, y, data_type, cast_to_fp32=False):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
node_x = pdpd.static.data(
|
||||
name='input_x', shape=x.shape, dtype=data_type)
|
||||
node_y = pdpd.static.data(
|
||||
name='input_y', shape=y.shape, dtype=data_type)
|
||||
out = pdpd.fluid.layers.greater_than(
|
||||
x=node_x, y=node_y, name='greater_than')
|
||||
# FuzzyTest framework doesn't support boolean so cast to fp32/int32
|
||||
|
||||
if cast_to_fp32:
|
||||
data_type = "float32"
|
||||
|
||||
out = pdpd.cast(out, data_type)
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'input_x': x, 'input_y': y},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['input_x', 'input_y'], fetchlist=[out],
|
||||
inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
test_cases = [
|
||||
"float32",
|
||||
"int32",
|
||||
"int64"
|
||||
]
|
||||
|
||||
for test in test_cases:
|
||||
x = np.array([0, 1, 2, 3]).astype(test)
|
||||
y = np.array([1, 0, 2, 4]).astype(test)
|
||||
if ((test == "float64") or (test == "int64")):
|
||||
greater_than("greater_than_" + test, x, y, test, True)
|
||||
else:
|
||||
greater_than("greater_than_" + test, x, y, test, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,58 @@
|
||||
#
|
||||
# less_than paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle as pdpd
|
||||
import sys
|
||||
|
||||
|
||||
def less_than(name: str, x, y, data_type, cast_to_fp32=False):
|
||||
pdpd.enable_static()
|
||||
|
||||
with pdpd.static.program_guard(pdpd.static.Program(), pdpd.static.Program()):
|
||||
node_x = pdpd.static.data(
|
||||
name='input_x', shape=x.shape, dtype=data_type)
|
||||
node_y = pdpd.static.data(
|
||||
name='input_y', shape=y.shape, dtype=data_type)
|
||||
out = pdpd.fluid.layers.less_than(x=node_x, y=node_y, name='less_than')
|
||||
# FuzzyTest framework doesn't support boolean so cast to fp32/int32
|
||||
|
||||
if cast_to_fp32:
|
||||
data_type = "float32"
|
||||
|
||||
out = pdpd.cast(out, data_type)
|
||||
cpu = pdpd.static.cpu_places(1)
|
||||
exe = pdpd.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(pdpd.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'input_x': x, 'input_y': y},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['input_x', 'input_y'], fetchlist=[out],
|
||||
inputs=[x, y], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
test_cases = [
|
||||
"float32",
|
||||
"int32",
|
||||
"int64"
|
||||
]
|
||||
|
||||
for test in test_cases:
|
||||
x = np.array([0, 1, 2, 3]).astype(test)
|
||||
y = np.array([1, 0, 2, 4]).astype(test)
|
||||
if ((test == "float64") or (test == "int64")):
|
||||
less_than("less_than_" + test, x, y, test, True)
|
||||
else:
|
||||
less_than("less_than_" + test, x, y, test, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -2,41 +2,12 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
#include "elementwise_ops.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
template <typename T>
|
||||
NamedOutputs elementwise_ops(const NodeContext& node) {
|
||||
auto x = node.get_input("X");
|
||||
auto y = node.get_input("Y");
|
||||
|
||||
auto axis = node.get_attribute<int>("axis");
|
||||
|
||||
PADDLE_OP_CHECK(node, x.get_partial_shape().rank().is_static(), "elementwise_ops: X rank must be static!");
|
||||
PADDLE_OP_CHECK(node, y.get_partial_shape().rank().is_static(), "elementwise_ops: Y rank must be static!");
|
||||
int64_t x_rank = x.get_partial_shape().rank().get_length();
|
||||
int64_t y_rank = y.get_partial_shape().rank().get_length();
|
||||
|
||||
if ((axis == -1) || (axis == x_rank - 1) || (x_rank == y_rank)) {
|
||||
return node.default_single_output_mapping({std::make_shared<T>(x, y)}, {"Out"});
|
||||
} else {
|
||||
std::vector<int64_t> indices;
|
||||
for (int64_t i = 0; i < axis; i++)
|
||||
indices.push_back(i);
|
||||
for (int64_t i = y_rank + axis; i < x_rank; i++)
|
||||
indices.push_back(i);
|
||||
|
||||
auto indices_node = default_opset::Constant::create(ov::element::i64, ov::Shape{indices.size()}, indices);
|
||||
auto y_node = std::make_shared<default_opset::Unsqueeze>(y, indices_node);
|
||||
return node.default_single_output_mapping({std::make_shared<T>(x, y_node)}, {"Out"});
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
NamedOutputs elementwise_add(const NodeContext& node_context) {
|
||||
|
45
src/frontends/paddle/src/op/elementwise_ops.hpp
Normal file
45
src/frontends/paddle/src/op/elementwise_ops.hpp
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
|
||||
template <typename T>
|
||||
NamedOutputs elementwise_ops(const NodeContext& node) {
|
||||
auto x = node.get_input("X");
|
||||
auto y = node.get_input("Y");
|
||||
|
||||
auto axis = node.get_attribute<int>("axis");
|
||||
|
||||
PADDLE_OP_CHECK(node, x.get_partial_shape().rank().is_static(), "elementwise_ops: X rank must be static!");
|
||||
PADDLE_OP_CHECK(node, y.get_partial_shape().rank().is_static(), "elementwise_ops: Y rank must be static!");
|
||||
int64_t x_rank = x.get_partial_shape().rank().get_length();
|
||||
int64_t y_rank = y.get_partial_shape().rank().get_length();
|
||||
|
||||
if ((axis == -1) || (axis == x_rank - 1) || (x_rank == y_rank)) {
|
||||
return node.default_single_output_mapping({std::make_shared<T>(x, y)}, {"Out"});
|
||||
} else {
|
||||
std::vector<int64_t> indices;
|
||||
for (int64_t i = 0; i < axis; i++)
|
||||
indices.push_back(i);
|
||||
for (int64_t i = y_rank + axis; i < x_rank; i++)
|
||||
indices.push_back(i);
|
||||
|
||||
auto indices_node = default_opset::Constant::create(ov::element::i64, ov::Shape{indices.size()}, indices);
|
||||
auto y_node = std::make_shared<default_opset::Unsqueeze>(y, indices_node);
|
||||
return node.default_single_output_mapping({std::make_shared<T>(x, y_node)}, {"Out"});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
19
src/frontends/paddle/src/op/floor.cpp
Normal file
19
src/frontends/paddle/src/op/floor.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs floor(const NodeContext& node) {
|
||||
const auto data_node = node.get_input("X");
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::Floor>(data_node)}, {"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
31
src/frontends/paddle/src/op/gather.cpp
Normal file
31
src/frontends/paddle/src/op/gather.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs gather(const NodeContext& node) {
|
||||
const auto data_node = node.get_input("X");
|
||||
const auto index_node = node.get_input("Index");
|
||||
Output<Node> axis_node;
|
||||
|
||||
if (node.has_input("Axis")) {
|
||||
axis_node = node.get_input("Axis");
|
||||
} else {
|
||||
const auto axis_value = node.get_attribute<int>("axis", 0);
|
||||
axis_node = default_opset::Constant::create(element::i32, Shape{}, {axis_value});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<default_opset::Gather>(data_node, index_node, axis_node)},
|
||||
{"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
17
src/frontends/paddle/src/op/greater_than.cpp
Normal file
17
src/frontends/paddle/src/op/greater_than.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "elementwise_ops.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs greater_than(const NodeContext& node) {
|
||||
return elementwise_ops<default_opset::Greater>(node);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
17
src/frontends/paddle/src/op/less_than.cpp
Normal file
17
src/frontends/paddle/src/op/less_than.cpp
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "elementwise_ops.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs less_than(const NodeContext& node) {
|
||||
return elementwise_ops<default_opset::Less>(node);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -37,11 +37,15 @@ OP_CONVERTER(fill_any_like);
|
||||
OP_CONVERTER(fill_constant_batch_size_like);
|
||||
OP_CONVERTER(fill_constant);
|
||||
OP_CONVERTER(flatten_contiguous_range);
|
||||
OP_CONVERTER(floor);
|
||||
OP_CONVERTER(gather);
|
||||
OP_CONVERTER(gelu);
|
||||
OP_CONVERTER(greater_than);
|
||||
OP_CONVERTER(hard_sigmoid);
|
||||
OP_CONVERTER(hard_swish);
|
||||
OP_CONVERTER(layer_norm);
|
||||
OP_CONVERTER(leaky_relu);
|
||||
OP_CONVERTER(less_than);
|
||||
OP_CONVERTER(linear_interp_v2);
|
||||
OP_CONVERTER(log);
|
||||
OP_CONVERTER(logical_and);
|
||||
@ -115,12 +119,16 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
|
||||
{"fill_constant", op::fill_constant},
|
||||
{"flatten_contiguous_range", op::flatten_contiguous_range},
|
||||
{"floor", op::floor},
|
||||
{"gather", op::gather},
|
||||
{"gelu", op::gelu},
|
||||
{"greater_equal", op::elementwise_greater_equal},
|
||||
{"greater_than", op::greater_than},
|
||||
{"hard_sigmoid", op::hard_sigmoid},
|
||||
{"hard_swish", op::hard_swish},
|
||||
{"layer_norm", op::layer_norm},
|
||||
{"leaky_relu", op::leaky_relu},
|
||||
{"less_than", op::less_than},
|
||||
{"linear_interp_v2", op::linear_interp_v2},
|
||||
{"log", op::log},
|
||||
{"logical_and", op::logical_and},
|
||||
|
Loading…
Reference in New Issue
Block a user