Paddle FasterRCNN Ops Conversion: roi_align, strided_slice, where (#10893)
* Paddle FasterRCNN Ops Conversion: roi_align, strided_slice, where * add check for 'aligned' feature of 'roi_align' op; use common function for idx_node in 'striede_slice' op * Apply suggestions from code review * use common funciton for stride_slice and slice, OP_CHECK for 'where' op conversion * Apply suggestions from code review
This commit is contained in:
parent
4057e408d8
commit
070f27a089
@ -256,6 +256,8 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("rnn_lstm_layer_2_forward"),
|
||||
std::string("rnn_lstm_layer_1_forward_seq_len_4"),
|
||||
std::string("rnn_lstm_layer_2_bidirectional_seq_len_4"),
|
||||
std::string("roi_align_test"),
|
||||
std::string("roi_align_test2"),
|
||||
std::string("scale_bias_after_float32"),
|
||||
std::string("scale_bias_after_int32"),
|
||||
std::string("scale_bias_after_int64"),
|
||||
@ -290,6 +292,15 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("stack_test_int32"),
|
||||
std::string("stack_test_neg_axis"),
|
||||
std::string("stack_test_none_axis"),
|
||||
std::string("strided_slice_input1_1"),
|
||||
std::string("strided_slice_input1_2"),
|
||||
std::string("strided_slice_input1_3"),
|
||||
std::string("strided_slice_input1_4"),
|
||||
std::string("strided_slice_input2_1"),
|
||||
std::string("strided_slice_input2_2"),
|
||||
std::string("strided_slice_input2_3"),
|
||||
std::string("strided_slice_input3_1"),
|
||||
std::string("strided_slice_input3_2"),
|
||||
std::string("tanh"),
|
||||
std::string("trilinear_downsample_false_0"),
|
||||
std::string("trilinear_downsample_false_1"),
|
||||
@ -300,6 +311,9 @@ static const std::vector<std::string> models{std::string("argmax"),
|
||||
std::string("trilinear_upsample_scales2"),
|
||||
std::string("trilinear_upsample_true_0"),
|
||||
std::string("unsqueeze"),
|
||||
std::string("where_1"),
|
||||
std::string("where_2"),
|
||||
std::string("where_3"),
|
||||
// Temporily disable them until root caused to secure CI stable.
|
||||
// CVS-66703 to track this.
|
||||
// std::string("yolo_box_clip_box"),
|
||||
|
@ -0,0 +1,114 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
# roi_align paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import paddle
|
||||
import sys
|
||||
|
||||
|
||||
def make_rois(batch_size, width, height, pooled_width, pooled_height, spatial_scale, roi_per_batch):
|
||||
rois = []
|
||||
rois_num = []
|
||||
for bno in range(batch_size):
|
||||
for i in range(roi_per_batch):
|
||||
x1 = np.random.randint(
|
||||
0, width // spatial_scale - pooled_width)
|
||||
y1 = np.random.randint(
|
||||
0, height // spatial_scale - pooled_height)
|
||||
|
||||
x2 = np.random.randint(x1 + pooled_width,
|
||||
width // spatial_scale)
|
||||
y2 = np.random.randint(
|
||||
y1 + pooled_height, height // spatial_scale)
|
||||
|
||||
roi = [x1, y1, x2, y2]
|
||||
rois.append(roi)
|
||||
rois_num.append(len(rois))
|
||||
rois = np.array(rois).astype("float32")
|
||||
rois_num = np.array(rois_num).astype("int32")
|
||||
|
||||
return rois, rois_num
|
||||
|
||||
|
||||
def roi_align(name: str, x_data, rois_data, rois_num_data, pooled_height, pooled_width, spatial_scale, sampling_ratio):
|
||||
paddle.enable_static()
|
||||
|
||||
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
|
||||
x = paddle.static.data(
|
||||
name='x', shape=x_data.shape, dtype=x_data.dtype)
|
||||
rois = paddle.static.data(
|
||||
name='rois', shape=rois_data.shape, dtype=rois_data.dtype)
|
||||
rois_num = paddle.static.data(
|
||||
name='rois_num', shape=rois_num_data.shape, dtype=rois_num_data.dtype)
|
||||
# TODO: 'aligned' attribute is not supported by Paddle 2.1
|
||||
out = paddle.fluid.layers.roi_align(input=x,
|
||||
rois=rois,
|
||||
pooled_height=pooled_height,
|
||||
pooled_width=pooled_width,
|
||||
spatial_scale=spatial_scale,
|
||||
sampling_ratio=sampling_ratio,
|
||||
rois_num=rois_num)
|
||||
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': x_data, 'rois': rois_data, 'rois_num': rois_num_data},
|
||||
fetch_list=[out])
|
||||
|
||||
saveModel(name, exe, feedkeys=['x', 'rois', 'rois_num'], fetchlist=[out], inputs=[
|
||||
x_data, rois_data, rois_num_data], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
return outs[0]
|
||||
|
||||
|
||||
def main():
|
||||
batch_size = 1
|
||||
channels = 3
|
||||
height = 8
|
||||
width = 6
|
||||
|
||||
x_dim = (batch_size, channels, height, width)
|
||||
x = np.random.random(x_dim).astype('float32')
|
||||
|
||||
spatial_scale = 1.0 / 2.0
|
||||
pooled_height = 2
|
||||
pooled_width = 2
|
||||
sampling_ratio = -1
|
||||
|
||||
roi_per_batch = 1
|
||||
rois, rois_num = make_rois(batch_size, width, height, pooled_width,
|
||||
pooled_height, spatial_scale, roi_per_batch)
|
||||
|
||||
roi_align("roi_align_test", x, rois, rois_num, pooled_height,
|
||||
pooled_width, spatial_scale, sampling_ratio)
|
||||
|
||||
batch_size = 1
|
||||
channels = 3
|
||||
height = 8
|
||||
width = 6
|
||||
|
||||
x_dim = (batch_size, channels, height, width)
|
||||
x = np.random.random(x_dim).astype('float32')
|
||||
|
||||
spatial_scale = 1.0 / 2.0
|
||||
pooled_height = 2
|
||||
pooled_width = 2
|
||||
sampling_ratio = 2
|
||||
|
||||
roi_per_batch = 2
|
||||
rois, rois_num = make_rois(batch_size, width, height, pooled_width,
|
||||
pooled_height, spatial_scale, roi_per_batch)
|
||||
|
||||
roi_align("roi_align_test2", x, rois, rois_num, pooled_height,
|
||||
pooled_width, spatial_scale, sampling_ratio)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,134 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
# strided_slice paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import sys
|
||||
|
||||
|
||||
def strided_slice(name: str, input_data, attrs: dict):
|
||||
import paddle
|
||||
paddle.enable_static()
|
||||
|
||||
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
|
||||
Input = paddle.static.data(
|
||||
name='x', shape=input_data.shape, dtype=input_data.dtype)
|
||||
|
||||
out = paddle.fluid.layers.strided_slice(Input, axes=attrs['axes'],
|
||||
starts=attrs['starts'],
|
||||
ends=attrs['ends'],
|
||||
strides=attrs['strides'])
|
||||
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': input_data},
|
||||
fetch_list=[out])
|
||||
|
||||
# Save inputs in order of ngraph function, to facilite Fuzzy test,
|
||||
# which accepts inputs and outputs in this order as well.
|
||||
saveModel(name, exe, feedkeys=['x'], fetchlist=[out],
|
||||
inputs=[input_data], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
return outs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
strided_slice_input1_1 = {
|
||||
'name': "strided_slice_input1_1",
|
||||
'axes': np.array([0]).astype('int32').tolist(),
|
||||
'starts': np.array([-4]).astype('int32').tolist(),
|
||||
'ends': np.array([-3]).astype('int32').tolist(),
|
||||
'strides': np.array([1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input1_2 = {
|
||||
'name': "strided_slice_input1_2",
|
||||
'axes': np.array([0]).astype('int32').tolist(),
|
||||
'starts': np.array([3]).astype('int32').tolist(),
|
||||
'ends': np.array([8]).astype('int32').tolist(),
|
||||
'strides': np.array([1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input1_3 = {
|
||||
'name': "strided_slice_input1_3",
|
||||
'axes': np.array([0]).astype('int32').tolist(),
|
||||
'starts': np.array([5]).astype('int32').tolist(),
|
||||
'ends': np.array([0]).astype('int32').tolist(),
|
||||
'strides': np.array([-1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input1_4 = {
|
||||
'name': "strided_slice_input1_4",
|
||||
'axes': np.array([0]).astype('int32').tolist(),
|
||||
'starts': np.array([-1]).astype('int32').tolist(),
|
||||
'ends': np.array([-3]).astype('int32').tolist(),
|
||||
'strides': np.array([-1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input2_1 = {
|
||||
'name': "strided_slice_input2_1",
|
||||
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
|
||||
'starts': np.array([1, 0, 0]).astype('int32').tolist(),
|
||||
'ends': np.array([2, 1, 3]).astype('int32').tolist(),
|
||||
'strides': np.array([1, 1, 1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input2_2 = {
|
||||
'name': "strided_slice_input2_2",
|
||||
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
|
||||
'starts': np.array([1, -1, 0]).astype('int32').tolist(),
|
||||
'ends': np.array([2, -3, 3]).astype('int32').tolist(),
|
||||
'strides': np.array([1, -1, 1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input2_3 = {
|
||||
'name': "strided_slice_input2_3",
|
||||
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
|
||||
'starts': np.array([1, 0, 0]).astype('int32').tolist(),
|
||||
'ends': np.array([2, 2, 3]).astype('int32').tolist(),
|
||||
'strides': np.array([1, 1, 1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input3_1 = {
|
||||
'name': "strided_slice_input3_1",
|
||||
'axes': np.array([1]).astype('int32').tolist(),
|
||||
'starts': np.array([1]).astype('int32').tolist(),
|
||||
'ends': np.array([2]).astype('int32').tolist(),
|
||||
'strides': np.array([1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input3_2 = {
|
||||
'name': "strided_slice_input3_2",
|
||||
'axes': np.array([1]).astype('int32').tolist(),
|
||||
'starts': np.array([-1]).astype('int32').tolist(),
|
||||
'ends': np.array([-2]).astype('int32').tolist(),
|
||||
'strides': np.array([-1]).astype('int32').tolist()
|
||||
}
|
||||
|
||||
strided_slice_input1_list = [strided_slice_input1_1,
|
||||
strided_slice_input1_2, strided_slice_input1_3, strided_slice_input1_4]
|
||||
|
||||
strided_slice_input2_list = [strided_slice_input2_1,
|
||||
strided_slice_input2_2, strided_slice_input2_3]
|
||||
|
||||
strided_slice_input3_list = [
|
||||
strided_slice_input3_1, strided_slice_input3_2]
|
||||
|
||||
input1 = np.random.rand(100).astype('float32')
|
||||
for item in strided_slice_input1_list:
|
||||
pred_paddle = strided_slice(item['name'], input1, item)
|
||||
|
||||
input2 = np.random.rand(5, 5, 5).astype('int32')
|
||||
for item in strided_slice_input2_list:
|
||||
pred_paddle = strided_slice(item['name'], input2, item)
|
||||
|
||||
input3 = np.random.rand(1, 100, 1).astype('float32')
|
||||
for item in strided_slice_input3_list:
|
||||
pred_paddle = strided_slice(item['name'], input3, item)
|
@ -0,0 +1,69 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
#
|
||||
# where paddle model generator
|
||||
#
|
||||
import numpy as np
|
||||
from save_model import saveModel
|
||||
import sys
|
||||
|
||||
|
||||
def where(name, test_x, test_y, test_cond):
|
||||
import paddle
|
||||
paddle.enable_static()
|
||||
main_program = paddle.static.Program()
|
||||
startup_program = paddle.static.Program()
|
||||
with paddle.static.program_guard(main_program, startup_program):
|
||||
X_Node = paddle.static.data(
|
||||
name='x', shape=test_x.shape, dtype=test_x.dtype)
|
||||
Y_Node = paddle.static.data(
|
||||
name='y', shape=test_y.shape, dtype=test_y.dtype)
|
||||
Cond_Node = paddle.static.data(
|
||||
name='cond', shape=test_cond.shape, dtype=test_cond.dtype)
|
||||
|
||||
Cond_Node_bl = paddle.fluid.layers.cast(Cond_Node, "bool")
|
||||
|
||||
out = paddle.where(Cond_Node_bl, X_Node, Y_Node)
|
||||
cpu = paddle.static.cpu_places(1)
|
||||
exe = paddle.static.Executor(cpu[0])
|
||||
# startup program will call initializer to initialize the parameters.
|
||||
exe.run(paddle.static.default_startup_program())
|
||||
|
||||
outs = exe.run(
|
||||
feed={'x': test_x, 'y': test_y, 'cond': test_cond},
|
||||
fetch_list=[out]
|
||||
)
|
||||
|
||||
saveModel(name, exe, feedkeys=['x', 'y', 'cond'], fetchlist=[out], inputs=[
|
||||
test_x, test_y, test_cond], outputs=[outs[0]], target_dir=sys.argv[1])
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
test_cases = [
|
||||
{
|
||||
"name": "where_1",
|
||||
"x": np.random.uniform(-3, 5, (100)).astype("float32"),
|
||||
"y": np.random.uniform(-3, 5, (100)).astype("float32"),
|
||||
"cond": np.zeros((100)).astype("int32")
|
||||
},
|
||||
{
|
||||
"name": "where_2",
|
||||
"x": np.random.uniform(-5, 5, (60, 2)).astype("int32"),
|
||||
"y": np.random.uniform(-5, 5, (60, 2)).astype("int32"),
|
||||
"cond": np.ones((60, 2)).astype("int32")
|
||||
},
|
||||
{
|
||||
"name": "where_3",
|
||||
"x": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"),
|
||||
"y": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"),
|
||||
"cond": np.array(np.random.randint(2, size=(20, 2, 4)), dtype="int32")
|
||||
}
|
||||
]
|
||||
for test in test_cases:
|
||||
where(test['name'], test['x'], test['y'], test['cond'])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
52
src/frontends/paddle/src/op/roi_align.cpp
Normal file
52
src/frontends/paddle/src/op/roi_align.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs roi_align(const NodeContext& node) {
|
||||
const auto data_node = node.get_input("X");
|
||||
const auto roi_node = node.get_input("ROIs");
|
||||
// TODO: support 'aligned' feature #82319
|
||||
const auto aligned = node.get_attribute("aligned", false);
|
||||
PADDLE_OP_CHECK(node, !aligned, "OpenVINO not support 'aligned' feature!");
|
||||
|
||||
// TODO: support multiple batches #83232
|
||||
if (data_node.get_partial_shape().rank().is_static() && data_node.get_partial_shape()[0].is_static())
|
||||
PADDLE_OP_CHECK(node, data_node.get_partial_shape()[0] == 1, "roi_align currenty only support batch_size = 1!");
|
||||
|
||||
const auto roi_node_shape = std::make_shared<default_opset::ShapeOf>(roi_node, element::i32);
|
||||
const auto start = default_opset::Constant::create(element::i64, {1}, {0});
|
||||
const auto stop = default_opset::Constant::create(element::i64, {1}, {1});
|
||||
const auto step = default_opset::Constant::create(element::i64, {1}, {1});
|
||||
const auto roisNum = std::make_shared<default_opset::Slice>(roi_node_shape, start, stop, step);
|
||||
|
||||
const auto zero_const = std::make_shared<default_opset::Constant>(element::i32, Shape{1}, 0);
|
||||
const auto fake_roisNum_node = std::make_shared<default_opset::Broadcast>(zero_const, roisNum);
|
||||
|
||||
const auto pooled_h = node.get_attribute<int>("pooled_height", 1);
|
||||
const auto pooled_w = node.get_attribute<int>("pooled_width", 1);
|
||||
const auto spatial_scale = node.get_attribute<float>("spatial_scale", 1.0);
|
||||
auto sampling_ratio = node.get_attribute<int>("sampling_ratio", -1);
|
||||
sampling_ratio = (sampling_ratio <= 0) ? 0 : sampling_ratio;
|
||||
|
||||
// Paddle only use 'avg' interpolation mode
|
||||
return node.default_single_output_mapping({std::make_shared<default_opset::ROIAlign>(data_node,
|
||||
roi_node,
|
||||
fake_roisNum_node,
|
||||
pooled_h,
|
||||
pooled_w,
|
||||
sampling_ratio,
|
||||
spatial_scale,
|
||||
"avg")},
|
||||
{"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -1,103 +1,14 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
#include "slice_ops.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
using namespace default_opset;
|
||||
NamedOutputs slice(const NodeContext& node) {
|
||||
auto data = node.get_input("Input");
|
||||
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
|
||||
Output<Node> start_idx_node, end_idx_node;
|
||||
if (node.has_input("StartsTensor")) {
|
||||
start_idx_node = node.get_input("StartsTensor");
|
||||
} else if (node.has_input("StartsTensorList")) {
|
||||
auto inputs = node.get_ng_inputs("StartsTensorList");
|
||||
start_idx_node = std::make_shared<Concat>(inputs, 0);
|
||||
} else {
|
||||
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
|
||||
start_idx_node = Constant::create(element::i32, {starts.size()}, starts);
|
||||
}
|
||||
|
||||
if (node.has_input("EndsTensor")) {
|
||||
end_idx_node = node.get_input("EndsTensor");
|
||||
} else if (node.has_input("EndsTensorList")) {
|
||||
auto inputs = node.get_ng_inputs("EndsTensorList");
|
||||
end_idx_node = std::make_shared<Concat>(inputs, 0);
|
||||
} else {
|
||||
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
|
||||
end_idx_node = Constant::create(element::i32, {ends.size()}, ends);
|
||||
}
|
||||
|
||||
// The following process is:
|
||||
// Given:
|
||||
// data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4]
|
||||
// axes = [0]
|
||||
// starts = [1]
|
||||
// ends = [2]
|
||||
// Our process is:
|
||||
// 1. Get 'axes': [0, 1], 'starts', 'ends'
|
||||
// 2. Get data shape: [2,4] and dims: 2
|
||||
// 3. Create two tensor t1 and t2, shape is the dims from step2: 2. t1: [0, 0], t2: [INT_MAX, INT_MAX]
|
||||
// 4. Use 'ScatterNDUpdate' to update some elements in t1, the updated indexes are coming from 'axes', the contents
|
||||
// are coming from 'starts', t1: [1, 0]; apply the similar process to t2
|
||||
// 5. Call 'StrideSlice' with t1 and t2
|
||||
// Why using ScatterNDUpdate is that 'axes' may be discontinuous.
|
||||
|
||||
// the shape of input, such as [2, 4]
|
||||
auto shape_node = std::make_shared<ShapeOf>(data, element::Type_t::i32);
|
||||
// the input dim, such as [2]
|
||||
auto shape_shape_node = std::make_shared<ShapeOf>(shape_node, element::i32);
|
||||
auto const_0_node = Constant::create(element::i32, {}, {0});
|
||||
auto const_max_node = Constant::create(element::i32, {}, {INT_MAX});
|
||||
// t1: [0, 0]
|
||||
auto start_node = std::make_shared<Broadcast>(const_0_node, shape_shape_node);
|
||||
// t2: [INT_MAX, INT_MAX]
|
||||
auto end_node = std::make_shared<Broadcast>(const_max_node, shape_shape_node);
|
||||
auto axes_node = Constant::create(element::i32, {axes.size(), 1}, axes);
|
||||
// update t1
|
||||
auto fixed_start_node = std::make_shared<ScatterNDUpdate>(start_node, axes_node, start_idx_node);
|
||||
// update t2
|
||||
auto fixed_end_node = std::make_shared<ScatterNDUpdate>(end_node, axes_node, end_idx_node);
|
||||
|
||||
auto stride_slice_node = std::make_shared<StridedSlice>(data,
|
||||
fixed_start_node,
|
||||
fixed_end_node,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
|
||||
auto decrease_axis = node.get_attribute<std::vector<int32_t>>("decrease_axis");
|
||||
|
||||
if (decrease_axis.size() > 0) {
|
||||
// according to paddle slice_op, when all axes are decreased, output shape is [1], instead of scalar.
|
||||
// Ref: paddle/fluid/operators/slice_op.h
|
||||
PartialShape input_shape = data.get_partial_shape();
|
||||
PADDLE_OP_CHECK(node,
|
||||
input_shape.rank().is_static(),
|
||||
"input rank of slice must be static when decrease_axis is set.");
|
||||
|
||||
auto squeeze_index_node = Constant::create(element::i32, {decrease_axis.size()}, decrease_axis);
|
||||
auto decreased_node = std::make_shared<Squeeze>(stride_slice_node, squeeze_index_node);
|
||||
|
||||
auto input_rank = input_shape.rank().get_length();
|
||||
if (input_rank == decrease_axis.size()) {
|
||||
auto restore_node = std::make_shared<Reshape>(decreased_node,
|
||||
std::make_shared<Constant>(element::i64, Shape{1}, 1),
|
||||
false); // restore to shape (1,)
|
||||
return node.default_single_output_mapping({restore_node}, {"Out"});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({decreased_node}, {"Out"});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({stride_slice_node}, {"Out"});
|
||||
return slice_op(node, false);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
|
123
src/frontends/paddle/src/op/slice_ops.hpp
Normal file
123
src/frontends/paddle/src/op/slice_ops.hpp
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <limits.h>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
namespace {
|
||||
Output<Node> idx_node(const std::string& tensor_alias,
|
||||
const std::string& list_alias,
|
||||
const std::string& attr_alias,
|
||||
const NodeContext& node) {
|
||||
if (node.has_input(tensor_alias)) {
|
||||
return std::make_shared<default_opset::Convert>(node.get_input(tensor_alias), element::i32);
|
||||
} else if (node.has_input(list_alias)) {
|
||||
auto inputs = node.get_ng_inputs(list_alias);
|
||||
return std::make_shared<default_opset::Convert>(std::make_shared<default_opset::Concat>(inputs, 0),
|
||||
element::i32);
|
||||
} else {
|
||||
auto values = node.get_attribute<std::vector<int32_t>>(attr_alias);
|
||||
return default_opset::Constant::create(element::i32, {values.size()}, values);
|
||||
}
|
||||
}
|
||||
NamedOutputs slice_op(const NodeContext& node, const bool& stride_input) {
|
||||
const auto data = node.get_input("Input");
|
||||
const auto axes = node.get_attribute<std::vector<int32_t>>("axes");
|
||||
|
||||
Output<Node> start_idx_node = idx_node("StartsTensor", "StartsTensorList", "starts", node);
|
||||
Output<Node> end_idx_node = idx_node("EndsTensor", "EndsTensorList", "ends", node);
|
||||
Output<Node> strides_idx_node;
|
||||
if (stride_input)
|
||||
strides_idx_node = idx_node("StridesTensor", "StridesTensorList", "strides", node);
|
||||
|
||||
// The following process is:
|
||||
// Given:
|
||||
// data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4]
|
||||
// axes = [0]
|
||||
// starts = [1]
|
||||
// ends = [2]
|
||||
// Our process is:
|
||||
// 1. Get 'axes': [0, 1], 'starts', 'ends'
|
||||
// 2. Get data shape: [2,4] and dims: 2
|
||||
// 3. Create two tensor t1 and t2, shape is the dims from step2: 2. t1: [0, 0], t2: [INT_MAX, INT_MAX]
|
||||
// 4. Use 'ScatterNDUpdate' to update some elements in t1, the updated indexes are coming from 'axes', the contents
|
||||
// are coming from 'starts', t1: [1, 0]; apply the similar process to t2
|
||||
// 5. Call 'StrideSlice' with t1 and t2
|
||||
// Why using ScatterNDUpdate is that 'axes' may be discontinuous.
|
||||
|
||||
// the shape of input, such as [2, 4]
|
||||
const auto shape_node = std::make_shared<default_opset::ShapeOf>(data, element::Type_t::i32);
|
||||
// the input dim, such as [2]
|
||||
const auto rank_node = std::make_shared<default_opset::ShapeOf>(shape_node, element::i32);
|
||||
const auto const_0_node = default_opset::Constant::create(element::i32, {}, {0});
|
||||
const auto const_max_node = default_opset::Constant::create(element::i32, {}, {INT_MAX});
|
||||
const auto const_1_node = default_opset::Constant::create(element::i32, {}, {1});
|
||||
// t1: [0, 0]
|
||||
const auto start_node = std::make_shared<default_opset::Broadcast>(const_0_node, rank_node);
|
||||
// t2: [INT_MAX, INT_MAX]
|
||||
const auto end_node = std::make_shared<default_opset::Broadcast>(const_max_node, rank_node);
|
||||
const auto strides_node = std::make_shared<default_opset::Broadcast>(const_1_node, rank_node);
|
||||
const auto axes_node = default_opset::Constant::create(element::i32, {axes.size(), 1}, axes);
|
||||
// update t1
|
||||
const auto fixed_start_node =
|
||||
std::make_shared<default_opset::ScatterNDUpdate>(start_node, axes_node, start_idx_node);
|
||||
// update t2
|
||||
const auto fixed_end_node = std::make_shared<default_opset::ScatterNDUpdate>(end_node, axes_node, end_idx_node);
|
||||
std::shared_ptr<Node> stride_slice_node;
|
||||
if (stride_input) {
|
||||
const auto fixed_strides_node =
|
||||
std::make_shared<default_opset::ScatterNDUpdate>(strides_node, axes_node, strides_idx_node);
|
||||
|
||||
stride_slice_node = std::make_shared<default_opset::StridedSlice>(data,
|
||||
fixed_start_node,
|
||||
fixed_end_node,
|
||||
fixed_strides_node,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
} else {
|
||||
stride_slice_node = std::make_shared<default_opset::StridedSlice>(data,
|
||||
fixed_start_node,
|
||||
fixed_end_node,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
}
|
||||
|
||||
const auto decrease_axis = node.get_attribute<std::vector<int32_t>>("decrease_axis");
|
||||
|
||||
if (decrease_axis.size() > 0) {
|
||||
// according to paddle slice_op, when all axes are decreased, output shape is [1], instead of scalar.
|
||||
// Ref: paddle/fluid/operators/slice_op.h
|
||||
PartialShape input_shape = data.get_partial_shape();
|
||||
PADDLE_OP_CHECK(node,
|
||||
input_shape.rank().is_static(),
|
||||
"input rank of slice must be static when decrease_axis is set.");
|
||||
|
||||
const auto squeeze_index_node =
|
||||
default_opset::Constant::create(element::i32, {decrease_axis.size()}, decrease_axis);
|
||||
const auto decreased_node = std::make_shared<default_opset::Squeeze>(stride_slice_node, squeeze_index_node);
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
if (input_rank == decrease_axis.size()) {
|
||||
auto restore_node = std::make_shared<default_opset::Reshape>(
|
||||
decreased_node,
|
||||
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1),
|
||||
false); // restore to shape (1,)
|
||||
return node.default_single_output_mapping({restore_node}, {"Out"});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({decreased_node}, {"Out"});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({stride_slice_node}, {"Out"});
|
||||
}
|
||||
} // namespace
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
16
src/frontends/paddle/src/op/strided_slice.cpp
Normal file
16
src/frontends/paddle/src/op/strided_slice.cpp
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "slice_ops.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs strided_slice(const NodeContext& node) {
|
||||
return slice_op(node, true);
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
28
src/frontends/paddle/src/op/where.cpp
Normal file
28
src/frontends/paddle/src/op/where.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "openvino/frontend/paddle/node_context.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace paddle {
|
||||
namespace op {
|
||||
NamedOutputs where(const NodeContext& node) {
|
||||
const auto condition_node = node.get_input("Condition");
|
||||
const auto x_node = node.get_input("X");
|
||||
const auto y_node = node.get_input("Y");
|
||||
// TODO: support 'shape x != shape y' #83233
|
||||
const auto x_shape = x_node.get_partial_shape();
|
||||
const auto y_shape = y_node.get_partial_shape();
|
||||
PADDLE_OP_CHECK(node, x_shape.compatible(y_shape), "shape x should be compatible to shape y!");
|
||||
|
||||
return node.default_single_output_mapping(
|
||||
{std::make_shared<default_opset::Select>(condition_node, x_node, y_node, ov::op::AutoBroadcastType::PDPD)},
|
||||
{"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace paddle
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -71,6 +71,7 @@ OP_CONVERTER(relu);
|
||||
OP_CONVERTER(relu6);
|
||||
OP_CONVERTER(reshape2);
|
||||
OP_CONVERTER(rnn);
|
||||
OP_CONVERTER(roi_align);
|
||||
OP_CONVERTER(scale);
|
||||
OP_CONVERTER(shape);
|
||||
OP_CONVERTER(slice);
|
||||
@ -80,10 +81,12 @@ OP_CONVERTER(sigmoid);
|
||||
OP_CONVERTER(split);
|
||||
OP_CONVERTER(squeeze);
|
||||
OP_CONVERTER(stack);
|
||||
OP_CONVERTER(strided_slice);
|
||||
OP_CONVERTER(tanh);
|
||||
OP_CONVERTER(transpose2);
|
||||
OP_CONVERTER(trilinear_interp_v2);
|
||||
OP_CONVERTER(unsqueeze);
|
||||
OP_CONVERTER(where);
|
||||
OP_CONVERTER(yolo_box);
|
||||
} // namespace op
|
||||
std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
@ -157,6 +160,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"relu6", op::relu6},
|
||||
{"reshape2", op::reshape2},
|
||||
{"rnn", op::rnn},
|
||||
{"roi_align", op::roi_align},
|
||||
{"scale", op::scale},
|
||||
{"shape", op::shape},
|
||||
{"slice", op::slice},
|
||||
@ -166,11 +170,13 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"split", op::split},
|
||||
{"squeeze2", op::squeeze},
|
||||
{"stack", op::stack},
|
||||
{"strided_slice", op::strided_slice},
|
||||
{"sync_batch_norm", op::batch_norm},
|
||||
{"tanh", op::tanh},
|
||||
{"transpose2", op::transpose2},
|
||||
{"trilinear_interp_v2", op::trilinear_interp_v2},
|
||||
{"unsqueeze2", op::unsqueeze},
|
||||
{"where", op::where},
|
||||
{"yolo_box", op::yolo_box}};
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user