Paddle FasterRCNN Ops Conversion: roi_align, strided_slice, where (#10893)

* Paddle FasterRCNN Ops Conversion: roi_align, strided_slice, where

* add a check for the 'aligned' attribute of the 'roi_align' op; use a common function for idx_node in the 'strided_slice' op

* Apply suggestions from code review

* use a common function for strided_slice and slice, and PADDLE_OP_CHECK in the 'where' op conversion

* Apply suggestions from code review
Bo Liu 2022-04-01 14:37:28 +08:00 committed by GitHub
parent 4057e408d8
commit 070f27a089
10 changed files with 558 additions and 91 deletions

@@ -256,6 +256,8 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("rnn_lstm_layer_2_forward"),
std::string("rnn_lstm_layer_1_forward_seq_len_4"),
std::string("rnn_lstm_layer_2_bidirectional_seq_len_4"),
std::string("roi_align_test"),
std::string("roi_align_test2"),
std::string("scale_bias_after_float32"),
std::string("scale_bias_after_int32"),
std::string("scale_bias_after_int64"),
@@ -290,6 +292,15 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("stack_test_int32"),
std::string("stack_test_neg_axis"),
std::string("stack_test_none_axis"),
std::string("strided_slice_input1_1"),
std::string("strided_slice_input1_2"),
std::string("strided_slice_input1_3"),
std::string("strided_slice_input1_4"),
std::string("strided_slice_input2_1"),
std::string("strided_slice_input2_2"),
std::string("strided_slice_input2_3"),
std::string("strided_slice_input3_1"),
std::string("strided_slice_input3_2"),
std::string("tanh"),
std::string("trilinear_downsample_false_0"),
std::string("trilinear_downsample_false_1"),
@@ -300,6 +311,9 @@ static const std::vector<std::string> models{std::string("argmax"),
std::string("trilinear_upsample_scales2"),
std::string("trilinear_upsample_true_0"),
std::string("unsqueeze"),
std::string("where_1"),
std::string("where_2"),
std::string("where_3"),
// Temporarily disable them until root-caused, to keep CI stable.
// CVS-66703 to track this.
// std::string("yolo_box_clip_box"),

@@ -0,0 +1,114 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# roi_align paddle model generator
#
import numpy as np
from save_model import saveModel
import paddle
import sys
def make_rois(batch_size, width, height, pooled_width, pooled_height, spatial_scale, roi_per_batch):
rois = []
rois_num = []
for bno in range(batch_size):
for i in range(roi_per_batch):
x1 = np.random.randint(
0, width // spatial_scale - pooled_width)
y1 = np.random.randint(
0, height // spatial_scale - pooled_height)
x2 = np.random.randint(x1 + pooled_width,
width // spatial_scale)
y2 = np.random.randint(
y1 + pooled_height, height // spatial_scale)
roi = [x1, y1, x2, y2]
rois.append(roi)
rois_num.append(len(rois))
rois = np.array(rois).astype("float32")
rois_num = np.array(rois_num).astype("int32")
return rois, rois_num
def roi_align(name: str, x_data, rois_data, rois_num_data, pooled_height, pooled_width, spatial_scale, sampling_ratio):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype)
rois = paddle.static.data(
name='rois', shape=rois_data.shape, dtype=rois_data.dtype)
rois_num = paddle.static.data(
name='rois_num', shape=rois_num_data.shape, dtype=rois_num_data.dtype)
# TODO: 'aligned' attribute is not supported by Paddle 2.1
out = paddle.fluid.layers.roi_align(input=x,
rois=rois,
pooled_height=pooled_height,
pooled_width=pooled_width,
spatial_scale=spatial_scale,
sampling_ratio=sampling_ratio,
rois_num=rois_num)
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# Running the startup program initializes the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': x_data, 'rois': rois_data, 'rois_num': rois_num_data},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x', 'rois', 'rois_num'], fetchlist=[out], inputs=[
x_data, rois_data, rois_num_data], outputs=[outs[0]], target_dir=sys.argv[1])
return outs[0]
def main():
batch_size = 1
channels = 3
height = 8
width = 6
x_dim = (batch_size, channels, height, width)
x = np.random.random(x_dim).astype('float32')
spatial_scale = 1.0 / 2.0
pooled_height = 2
pooled_width = 2
sampling_ratio = -1
roi_per_batch = 1
rois, rois_num = make_rois(batch_size, width, height, pooled_width,
pooled_height, spatial_scale, roi_per_batch)
roi_align("roi_align_test", x, rois, rois_num, pooled_height,
pooled_width, spatial_scale, sampling_ratio)
batch_size = 1
channels = 3
height = 8
width = 6
x_dim = (batch_size, channels, height, width)
x = np.random.random(x_dim).astype('float32')
spatial_scale = 1.0 / 2.0
pooled_height = 2
pooled_width = 2
sampling_ratio = 2
roi_per_batch = 2
rois, rois_num = make_rois(batch_size, width, height, pooled_width,
pooled_height, spatial_scale, roi_per_batch)
roi_align("roi_align_test2", x, rois, rois_num, pooled_height,
pooled_width, spatial_scale, sampling_ratio)
if __name__ == "__main__":
main()
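A minimal sanity-check sketch for the boxes produced by make_rois above (it reuses make_rois and the parameters from main(); the check itself is illustrative and not part of the generator): each ROI must be at least pooled_width x pooled_height and must fit inside the feature map rescaled by spatial_scale.

import numpy as np

# Assumes make_rois from the generator above is in scope.
rois, rois_num = make_rois(batch_size=1, width=6, height=8,
                           pooled_width=2, pooled_height=2,
                           spatial_scale=0.5, roi_per_batch=2)
assert np.all(rois[:, 2] - rois[:, 0] >= 2)   # x2 - x1 >= pooled_width
assert np.all(rois[:, 3] - rois[:, 1] >= 2)   # y2 - y1 >= pooled_height
assert np.all(rois[:, 2] < 6 / 0.5)           # x2 < width / spatial_scale
assert np.all(rois[:, 3] < 8 / 0.5)           # y2 < height / spatial_scale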

@@ -0,0 +1,134 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# strided_slice paddle model generator
#
import numpy as np
from save_model import saveModel
import sys
def strided_slice(name: str, input_data, attrs: dict):
import paddle
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
Input = paddle.static.data(
name='x', shape=input_data.shape, dtype=input_data.dtype)
out = paddle.fluid.layers.strided_slice(Input, axes=attrs['axes'],
starts=attrs['starts'],
ends=attrs['ends'],
strides=attrs['strides'])
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# Running the startup program initializes the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': input_data},
fetch_list=[out])
# Save inputs in the order of the ngraph function, to facilitate the Fuzzy test,
# which accepts inputs and outputs in this order as well.
saveModel(name, exe, feedkeys=['x'], fetchlist=[out],
inputs=[input_data], outputs=[outs[0]], target_dir=sys.argv[1])
return outs
if __name__ == "__main__":
strided_slice_input1_1 = {
'name': "strided_slice_input1_1",
'axes': np.array([0]).astype('int32').tolist(),
'starts': np.array([-4]).astype('int32').tolist(),
'ends': np.array([-3]).astype('int32').tolist(),
'strides': np.array([1]).astype('int32').tolist()
}
strided_slice_input1_2 = {
'name': "strided_slice_input1_2",
'axes': np.array([0]).astype('int32').tolist(),
'starts': np.array([3]).astype('int32').tolist(),
'ends': np.array([8]).astype('int32').tolist(),
'strides': np.array([1]).astype('int32').tolist()
}
strided_slice_input1_3 = {
'name': "strided_slice_input1_3",
'axes': np.array([0]).astype('int32').tolist(),
'starts': np.array([5]).astype('int32').tolist(),
'ends': np.array([0]).astype('int32').tolist(),
'strides': np.array([-1]).astype('int32').tolist()
}
strided_slice_input1_4 = {
'name': "strided_slice_input1_4",
'axes': np.array([0]).astype('int32').tolist(),
'starts': np.array([-1]).astype('int32').tolist(),
'ends': np.array([-3]).astype('int32').tolist(),
'strides': np.array([-1]).astype('int32').tolist()
}
strided_slice_input2_1 = {
'name': "strided_slice_input2_1",
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
'starts': np.array([1, 0, 0]).astype('int32').tolist(),
'ends': np.array([2, 1, 3]).astype('int32').tolist(),
'strides': np.array([1, 1, 1]).astype('int32').tolist()
}
strided_slice_input2_2 = {
'name': "strided_slice_input2_2",
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
'starts': np.array([1, -1, 0]).astype('int32').tolist(),
'ends': np.array([2, -3, 3]).astype('int32').tolist(),
'strides': np.array([1, -1, 1]).astype('int32').tolist()
}
strided_slice_input2_3 = {
'name': "strided_slice_input2_3",
'axes': np.array([0, 1, 2]).astype('int32').tolist(),
'starts': np.array([1, 0, 0]).astype('int32').tolist(),
'ends': np.array([2, 2, 3]).astype('int32').tolist(),
'strides': np.array([1, 1, 1]).astype('int32').tolist()
}
strided_slice_input3_1 = {
'name': "strided_slice_input3_1",
'axes': np.array([1]).astype('int32').tolist(),
'starts': np.array([1]).astype('int32').tolist(),
'ends': np.array([2]).astype('int32').tolist(),
'strides': np.array([1]).astype('int32').tolist()
}
strided_slice_input3_2 = {
'name': "strided_slice_input3_2",
'axes': np.array([1]).astype('int32').tolist(),
'starts': np.array([-1]).astype('int32').tolist(),
'ends': np.array([-2]).astype('int32').tolist(),
'strides': np.array([-1]).astype('int32').tolist()
}
strided_slice_input1_list = [strided_slice_input1_1,
strided_slice_input1_2, strided_slice_input1_3, strided_slice_input1_4]
strided_slice_input2_list = [strided_slice_input2_1,
strided_slice_input2_2, strided_slice_input2_3]
strided_slice_input3_list = [
strided_slice_input3_1, strided_slice_input3_2]
input1 = np.random.rand(100).astype('float32')
for item in strided_slice_input1_list:
pred_paddle = strided_slice(item['name'], input1, item)
input2 = np.random.rand(5, 5, 5).astype('int32')
for item in strided_slice_input2_list:
pred_paddle = strided_slice(item['name'], input2, item)
input3 = np.random.rand(1, 100, 1).astype('float32')
for item in strided_slice_input3_list:
pred_paddle = strided_slice(item['name'], input3, item)
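The expected results for these strided_slice cases follow plain NumPy slicing on the selected axis; a short sketch under that assumption, covering two of the 1-D cases:

import numpy as np

x = np.random.rand(100).astype('float32')
# strided_slice_input1_2: axes=[0], starts=[3], ends=[8], strides=[1]
ref_1_2 = x[3:8:1]                 # 5 elements
# strided_slice_input1_3: axes=[0], starts=[5], ends=[0], strides=[-1]
ref_1_3 = x[5:0:-1]                # elements 5, 4, 3, 2, 1 in reverse order
assert ref_1_2.shape == (5,) and ref_1_3.shape == (5,)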

@@ -0,0 +1,69 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
# where paddle model generator
#
import numpy as np
from save_model import saveModel
import sys
def where(name, test_x, test_y, test_cond):
import paddle
paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
X_Node = paddle.static.data(
name='x', shape=test_x.shape, dtype=test_x.dtype)
Y_Node = paddle.static.data(
name='y', shape=test_y.shape, dtype=test_y.dtype)
Cond_Node = paddle.static.data(
name='cond', shape=test_cond.shape, dtype=test_cond.dtype)
Cond_Node_bl = paddle.fluid.layers.cast(Cond_Node, "bool")
out = paddle.where(Cond_Node_bl, X_Node, Y_Node)
cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])
# Running the startup program initializes the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': test_x, 'y': test_y, 'cond': test_cond},
fetch_list=[out]
)
saveModel(name, exe, feedkeys=['x', 'y', 'cond'], fetchlist=[out], inputs=[
test_x, test_y, test_cond], outputs=[outs[0]], target_dir=sys.argv[1])
def main():
test_cases = [
{
"name": "where_1",
"x": np.random.uniform(-3, 5, (100)).astype("float32"),
"y": np.random.uniform(-3, 5, (100)).astype("float32"),
"cond": np.zeros((100)).astype("int32")
},
{
"name": "where_2",
"x": np.random.uniform(-5, 5, (60, 2)).astype("int32"),
"y": np.random.uniform(-5, 5, (60, 2)).astype("int32"),
"cond": np.ones((60, 2)).astype("int32")
},
{
"name": "where_3",
"x": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"),
"y": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"),
"cond": np.array(np.random.randint(2, size=(20, 2, 4)), dtype="int32")
}
]
for test in test_cases:
where(test['name'], test['x'], test['y'], test['cond'])
if __name__ == "__main__":
main()
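These where_* cases reduce to a ternary select; a minimal NumPy reference sketch mirroring the where_3 case (the generated model casts the int32 condition to bool first, as done above):

import numpy as np

x = np.random.uniform(-3, 5, (20, 2, 4)).astype("float32")
y = np.random.uniform(-3, 5, (20, 2, 4)).astype("float32")
cond = np.random.randint(2, size=(20, 2, 4)).astype("int32")
ref = np.where(cond.astype(bool), x, y)   # expected paddle.where behaviour
assert ref.shape == x.shape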

@@ -0,0 +1,52 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs roi_align(const NodeContext& node) {
const auto data_node = node.get_input("X");
const auto roi_node = node.get_input("ROIs");
// TODO: support 'aligned' feature #82319
const auto aligned = node.get_attribute("aligned", false);
PADDLE_OP_CHECK(node, !aligned, "OpenVINO does not support the 'aligned' feature!");
// TODO: support multiple batches #83232
if (data_node.get_partial_shape().rank().is_static() && data_node.get_partial_shape()[0].is_static())
PADDLE_OP_CHECK(node, data_node.get_partial_shape()[0] == 1, "roi_align currently only supports batch_size = 1!");
const auto roi_node_shape = std::make_shared<default_opset::ShapeOf>(roi_node, element::i32);
const auto start = default_opset::Constant::create(element::i64, {1}, {0});
const auto stop = default_opset::Constant::create(element::i64, {1}, {1});
const auto step = default_opset::Constant::create(element::i64, {1}, {1});
const auto roisNum = std::make_shared<default_opset::Slice>(roi_node_shape, start, stop, step);
const auto zero_const = std::make_shared<default_opset::Constant>(element::i32, Shape{1}, 0);
const auto fake_roisNum_node = std::make_shared<default_opset::Broadcast>(zero_const, roisNum);
const auto pooled_h = node.get_attribute<int>("pooled_height", 1);
const auto pooled_w = node.get_attribute<int>("pooled_width", 1);
const auto spatial_scale = node.get_attribute<float>("spatial_scale", 1.0);
auto sampling_ratio = node.get_attribute<int>("sampling_ratio", -1);
sampling_ratio = (sampling_ratio <= 0) ? 0 : sampling_ratio;
// Paddle only uses the 'avg' interpolation mode
return node.default_single_output_mapping({std::make_shared<default_opset::ROIAlign>(data_node,
roi_node,
fake_roisNum_node,
pooled_h,
pooled_w,
sampling_ratio,
spatial_scale,
"avg")},
{"Out"});
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov
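OpenVINO's ROIAlign expects a per-ROI batch-index tensor where Paddle provides RoisNum, so the converter builds a vector of zeros of length num_rois (valid because batch_size == 1 is enforced above). A NumPy sketch of that construction, with illustrative box values:

import numpy as np

rois = np.array([[0., 1., 4., 5.],
                 [2., 2., 6., 7.]], dtype=np.float32)  # illustrative ROI boxes
num_rois = rois.shape[0]                               # Slice(ShapeOf(rois), 0, 1)
batch_indices = np.zeros((num_rois,), dtype=np.int32)  # Broadcast(0, roisNum)
# batch_indices becomes the third ROIAlign input: every ROI maps to image 0.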

@@ -1,103 +1,14 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <limits.h>
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
#include "slice_ops.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
using namespace default_opset;
NamedOutputs slice(const NodeContext& node) {
auto data = node.get_input("Input");
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
Output<Node> start_idx_node, end_idx_node;
if (node.has_input("StartsTensor")) {
start_idx_node = node.get_input("StartsTensor");
} else if (node.has_input("StartsTensorList")) {
auto inputs = node.get_ng_inputs("StartsTensorList");
start_idx_node = std::make_shared<Concat>(inputs, 0);
} else {
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
start_idx_node = Constant::create(element::i32, {starts.size()}, starts);
}
if (node.has_input("EndsTensor")) {
end_idx_node = node.get_input("EndsTensor");
} else if (node.has_input("EndsTensorList")) {
auto inputs = node.get_ng_inputs("EndsTensorList");
end_idx_node = std::make_shared<Concat>(inputs, 0);
} else {
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
end_idx_node = Constant::create(element::i32, {ends.size()}, ends);
}
// The conversion works as follows.
// Given:
//      data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4]
//      axes = [0]
//      starts = [1]
//      ends = [2]
// Steps:
//  1. Read 'axes', 'starts' and 'ends' (here: [0], [1], [2]).
//  2. Get the data shape ([2, 4]) and its rank (2).
//  3. Create two tensors t1 and t2 of that rank: t1 = [0, 0], t2 = [INT_MAX, INT_MAX].
//  4. Use 'ScatterNDUpdate' to update t1 at the indexes given by 'axes' with the values from 'starts',
//     giving t1 = [1, 0]; apply the same process to t2 with 'ends'.
//  5. Call 'StridedSlice' with t1 and t2.
// ScatterNDUpdate is used because 'axes' may be discontinuous.
// the shape of input, such as [2, 4]
auto shape_node = std::make_shared<ShapeOf>(data, element::Type_t::i32);
// the input dim, such as [2]
auto shape_shape_node = std::make_shared<ShapeOf>(shape_node, element::i32);
auto const_0_node = Constant::create(element::i32, {}, {0});
auto const_max_node = Constant::create(element::i32, {}, {INT_MAX});
// t1: [0, 0]
auto start_node = std::make_shared<Broadcast>(const_0_node, shape_shape_node);
// t2: [INT_MAX, INT_MAX]
auto end_node = std::make_shared<Broadcast>(const_max_node, shape_shape_node);
auto axes_node = Constant::create(element::i32, {axes.size(), 1}, axes);
// update t1
auto fixed_start_node = std::make_shared<ScatterNDUpdate>(start_node, axes_node, start_idx_node);
// update t2
auto fixed_end_node = std::make_shared<ScatterNDUpdate>(end_node, axes_node, end_idx_node);
auto stride_slice_node = std::make_shared<StridedSlice>(data,
fixed_start_node,
fixed_end_node,
std::vector<int64_t>{0},
std::vector<int64_t>{0});
auto decrease_axis = node.get_attribute<std::vector<int32_t>>("decrease_axis");
if (decrease_axis.size() > 0) {
// According to Paddle's slice_op, when all axes are decreased, the output shape is [1] instead of a scalar.
// Ref: paddle/fluid/operators/slice_op.h
PartialShape input_shape = data.get_partial_shape();
PADDLE_OP_CHECK(node,
input_shape.rank().is_static(),
"input rank of slice must be static when decrease_axis is set.");
auto squeeze_index_node = Constant::create(element::i32, {decrease_axis.size()}, decrease_axis);
auto decreased_node = std::make_shared<Squeeze>(stride_slice_node, squeeze_index_node);
auto input_rank = input_shape.rank().get_length();
if (input_rank == decrease_axis.size()) {
auto restore_node = std::make_shared<Reshape>(decreased_node,
std::make_shared<Constant>(element::i64, Shape{1}, 1),
false); // restore to shape (1,)
return node.default_single_output_mapping({restore_node}, {"Out"});
}
return node.default_single_output_mapping({decreased_node}, {"Out"});
}
return node.default_single_output_mapping({stride_slice_node}, {"Out"});
return slice_op(node, false);
}
} // namespace op
} // namespace paddle

@@ -0,0 +1,123 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <limits.h>
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
namespace {
Output<Node> idx_node(const std::string& tensor_alias,
const std::string& list_alias,
const std::string& attr_alias,
const NodeContext& node) {
if (node.has_input(tensor_alias)) {
return std::make_shared<default_opset::Convert>(node.get_input(tensor_alias), element::i32);
} else if (node.has_input(list_alias)) {
auto inputs = node.get_ng_inputs(list_alias);
return std::make_shared<default_opset::Convert>(std::make_shared<default_opset::Concat>(inputs, 0),
element::i32);
} else {
auto values = node.get_attribute<std::vector<int32_t>>(attr_alias);
return default_opset::Constant::create(element::i32, {values.size()}, values);
}
}
NamedOutputs slice_op(const NodeContext& node, const bool& stride_input) {
const auto data = node.get_input("Input");
const auto axes = node.get_attribute<std::vector<int32_t>>("axes");
Output<Node> start_idx_node = idx_node("StartsTensor", "StartsTensorList", "starts", node);
Output<Node> end_idx_node = idx_node("EndsTensor", "EndsTensorList", "ends", node);
Output<Node> strides_idx_node;
if (stride_input)
strides_idx_node = idx_node("StridesTensor", "StridesTensorList", "strides", node);
// The conversion works as follows.
// Given:
//      data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4]
//      axes = [0]
//      starts = [1]
//      ends = [2]
// Steps:
//  1. Read 'axes', 'starts' and 'ends' (here: [0], [1], [2]).
//  2. Get the data shape ([2, 4]) and its rank (2).
//  3. Create two tensors t1 and t2 of that rank: t1 = [0, 0], t2 = [INT_MAX, INT_MAX].
//  4. Use 'ScatterNDUpdate' to update t1 at the indexes given by 'axes' with the values from 'starts',
//     giving t1 = [1, 0]; apply the same process to t2 with 'ends'.
//  5. Call 'StridedSlice' with t1 and t2.
// ScatterNDUpdate is used because 'axes' may be discontinuous.
// the shape of input, such as [2, 4]
const auto shape_node = std::make_shared<default_opset::ShapeOf>(data, element::Type_t::i32);
// the input dim, such as [2]
const auto rank_node = std::make_shared<default_opset::ShapeOf>(shape_node, element::i32);
const auto const_0_node = default_opset::Constant::create(element::i32, {}, {0});
const auto const_max_node = default_opset::Constant::create(element::i32, {}, {INT_MAX});
const auto const_1_node = default_opset::Constant::create(element::i32, {}, {1});
// t1: [0, 0]
const auto start_node = std::make_shared<default_opset::Broadcast>(const_0_node, rank_node);
// t2: [INT_MAX, INT_MAX]
const auto end_node = std::make_shared<default_opset::Broadcast>(const_max_node, rank_node);
const auto strides_node = std::make_shared<default_opset::Broadcast>(const_1_node, rank_node);
const auto axes_node = default_opset::Constant::create(element::i32, {axes.size(), 1}, axes);
// update t1
const auto fixed_start_node =
std::make_shared<default_opset::ScatterNDUpdate>(start_node, axes_node, start_idx_node);
// update t2
const auto fixed_end_node = std::make_shared<default_opset::ScatterNDUpdate>(end_node, axes_node, end_idx_node);
std::shared_ptr<Node> stride_slice_node;
if (stride_input) {
const auto fixed_strides_node =
std::make_shared<default_opset::ScatterNDUpdate>(strides_node, axes_node, strides_idx_node);
stride_slice_node = std::make_shared<default_opset::StridedSlice>(data,
fixed_start_node,
fixed_end_node,
fixed_strides_node,
std::vector<int64_t>{0},
std::vector<int64_t>{0});
} else {
stride_slice_node = std::make_shared<default_opset::StridedSlice>(data,
fixed_start_node,
fixed_end_node,
std::vector<int64_t>{0},
std::vector<int64_t>{0});
}
const auto decrease_axis = node.get_attribute<std::vector<int32_t>>("decrease_axis");
if (decrease_axis.size() > 0) {
// According to Paddle's slice_op, when all axes are decreased, the output shape is [1] instead of a scalar.
// Ref: paddle/fluid/operators/slice_op.h
PartialShape input_shape = data.get_partial_shape();
PADDLE_OP_CHECK(node,
input_shape.rank().is_static(),
"input rank of slice must be static when decrease_axis is set.");
const auto squeeze_index_node =
default_opset::Constant::create(element::i32, {decrease_axis.size()}, decrease_axis);
const auto decreased_node = std::make_shared<default_opset::Squeeze>(stride_slice_node, squeeze_index_node);
const auto input_rank = input_shape.rank().get_length();
if (input_rank == decrease_axis.size()) {
auto restore_node = std::make_shared<default_opset::Reshape>(
decreased_node,
std::make_shared<default_opset::Constant>(element::i64, Shape{1}, 1),
false); // restore to shape (1,)
return node.default_single_output_mapping({restore_node}, {"Out"});
}
return node.default_single_output_mapping({decreased_node}, {"Out"});
}
return node.default_single_output_mapping({stride_slice_node}, {"Out"});
}
} // namespace
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov
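The begin/end/stride tensors built above with Broadcast and ScatterNDUpdate can be summarised in a short NumPy sketch (a simplification: INT_MAX stands in for "slice to the end", and the axes/starts/ends/strides values are illustrative):

import numpy as np

INT_MAX = np.iinfo(np.int32).max
rank = 3                                  # rank of the sliced input
axes = np.array([0, 2])                   # 'axes' attribute, possibly discontinuous
starts = np.array([1, 0], dtype=np.int32)
ends = np.array([2, 3], dtype=np.int32)
strides = np.array([1, 1], dtype=np.int32)

begin = np.zeros(rank, dtype=np.int32)        # Broadcast(0, rank)
end = np.full(rank, INT_MAX, dtype=np.int32)  # Broadcast(INT_MAX, rank)
step = np.ones(rank, dtype=np.int32)          # Broadcast(1, rank), strided_slice only
begin[axes] = starts                          # ScatterNDUpdate(start_node, axes, starts)
end[axes] = ends                              # ScatterNDUpdate(end_node, axes, ends)
step[axes] = strides                          # ScatterNDUpdate(strides_node, axes, strides)
# Result: begin = [1, 0, 0], end = [2, INT_MAX, 3], step = [1, 1, 1]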

@@ -0,0 +1,16 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "slice_ops.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs strided_slice(const NodeContext& node) {
return slice_op(node, true);
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

@@ -0,0 +1,28 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"
namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs where(const NodeContext& node) {
const auto condition_node = node.get_input("Condition");
const auto x_node = node.get_input("X");
const auto y_node = node.get_input("Y");
// TODO: support 'shape x != shape y' #83233
const auto x_shape = x_node.get_partial_shape();
const auto y_shape = y_node.get_partial_shape();
PADDLE_OP_CHECK(node, x_shape.compatible(y_shape), "shape x should be compatible to shape y!");
return node.default_single_output_mapping(
{std::make_shared<default_opset::Select>(condition_node, x_node, y_node, ov::op::AutoBroadcastType::PDPD)},
{"Out"});
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov

@@ -71,6 +71,7 @@ OP_CONVERTER(relu);
OP_CONVERTER(relu6);
OP_CONVERTER(reshape2);
OP_CONVERTER(rnn);
OP_CONVERTER(roi_align);
OP_CONVERTER(scale);
OP_CONVERTER(shape);
OP_CONVERTER(slice);
@@ -80,10 +81,12 @@ OP_CONVERTER(sigmoid);
OP_CONVERTER(split);
OP_CONVERTER(squeeze);
OP_CONVERTER(stack);
OP_CONVERTER(strided_slice);
OP_CONVERTER(tanh);
OP_CONVERTER(transpose2);
OP_CONVERTER(trilinear_interp_v2);
OP_CONVERTER(unsqueeze);
OP_CONVERTER(where);
OP_CONVERTER(yolo_box);
} // namespace op
std::map<std::string, CreatorFunction> get_supported_ops() {
@@ -157,6 +160,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"relu6", op::relu6},
{"reshape2", op::reshape2},
{"rnn", op::rnn},
{"roi_align", op::roi_align},
{"scale", op::scale},
{"shape", op::shape},
{"slice", op::slice},
@@ -166,11 +170,13 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"split", op::split},
{"squeeze2", op::squeeze},
{"stack", op::stack},
{"strided_slice", op::strided_slice},
{"sync_batch_norm", op::batch_norm},
{"tanh", op::tanh},
{"transpose2", op::transpose2},
{"trilinear_interp_v2", op::trilinear_interp_v2},
{"unsqueeze2", op::unsqueeze},
{"where", op::where},
{"yolo_box", op::yolo_box}};
};