diff --git a/src/core/tests/frontend/paddle/op_fuzzy.cpp b/src/core/tests/frontend/paddle/op_fuzzy.cpp index b7637fa67b8..b5009de3f6a 100644 --- a/src/core/tests/frontend/paddle/op_fuzzy.cpp +++ b/src/core/tests/frontend/paddle/op_fuzzy.cpp @@ -256,6 +256,8 @@ static const std::vector models{std::string("argmax"), std::string("rnn_lstm_layer_2_forward"), std::string("rnn_lstm_layer_1_forward_seq_len_4"), std::string("rnn_lstm_layer_2_bidirectional_seq_len_4"), + std::string("roi_align_test"), + std::string("roi_align_test2"), std::string("scale_bias_after_float32"), std::string("scale_bias_after_int32"), std::string("scale_bias_after_int64"), @@ -290,6 +292,15 @@ static const std::vector models{std::string("argmax"), std::string("stack_test_int32"), std::string("stack_test_neg_axis"), std::string("stack_test_none_axis"), + std::string("strided_slice_input1_1"), + std::string("strided_slice_input1_2"), + std::string("strided_slice_input1_3"), + std::string("strided_slice_input1_4"), + std::string("strided_slice_input2_1"), + std::string("strided_slice_input2_2"), + std::string("strided_slice_input2_3"), + std::string("strided_slice_input3_1"), + std::string("strided_slice_input3_2"), std::string("tanh"), std::string("trilinear_downsample_false_0"), std::string("trilinear_downsample_false_1"), @@ -300,6 +311,9 @@ static const std::vector models{std::string("argmax"), std::string("trilinear_upsample_scales2"), std::string("trilinear_upsample_true_0"), std::string("unsqueeze"), + std::string("where_1"), + std::string("where_2"), + std::string("where_3"), // Temporily disable them until root caused to secure CI stable. // CVS-66703 to track this. // std::string("yolo_box_clip_box"), diff --git a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_roi_align.py b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_roi_align.py new file mode 100644 index 00000000000..2475a5ed666 --- /dev/null +++ b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_roi_align.py @@ -0,0 +1,114 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# +# roi_align paddle model generator +# +import numpy as np +from save_model import saveModel +import paddle +import sys + + +def make_rois(batch_size, width, height, pooled_width, pooled_height, spatial_scale, roi_per_batch): + rois = [] + rois_num = [] + for bno in range(batch_size): + for i in range(roi_per_batch): + x1 = np.random.randint( + 0, width // spatial_scale - pooled_width) + y1 = np.random.randint( + 0, height // spatial_scale - pooled_height) + + x2 = np.random.randint(x1 + pooled_width, + width // spatial_scale) + y2 = np.random.randint( + y1 + pooled_height, height // spatial_scale) + + roi = [x1, y1, x2, y2] + rois.append(roi) + rois_num.append(len(rois)) + rois = np.array(rois).astype("float32") + rois_num = np.array(rois_num).astype("int32") + + return rois, rois_num + + +def roi_align(name: str, x_data, rois_data, rois_num_data, pooled_height, pooled_width, spatial_scale, sampling_ratio): + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): + x = paddle.static.data( + name='x', shape=x_data.shape, dtype=x_data.dtype) + rois = paddle.static.data( + name='rois', shape=rois_data.shape, dtype=rois_data.dtype) + rois_num = paddle.static.data( + name='rois_num', shape=rois_num_data.shape, dtype=rois_num_data.dtype) + # TODO: 'aligned' attribute is not supported by Paddle 2.1 + out = paddle.fluid.layers.roi_align(input=x, + rois=rois, + pooled_height=pooled_height, + pooled_width=pooled_width, + spatial_scale=spatial_scale, + sampling_ratio=sampling_ratio, + rois_num=rois_num) + + cpu = paddle.static.cpu_places(1) + exe = paddle.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(paddle.static.default_startup_program()) + + outs = exe.run( + feed={'x': x_data, 'rois': rois_data, 'rois_num': rois_num_data}, + fetch_list=[out]) + + saveModel(name, exe, feedkeys=['x', 'rois', 'rois_num'], fetchlist=[out], inputs=[ + x_data, rois_data, rois_num_data], outputs=[outs[0]], target_dir=sys.argv[1]) + + return outs[0] + + +def main(): + batch_size = 1 + channels = 3 + height = 8 + width = 6 + + x_dim = (batch_size, channels, height, width) + x = np.random.random(x_dim).astype('float32') + + spatial_scale = 1.0 / 2.0 + pooled_height = 2 + pooled_width = 2 + sampling_ratio = -1 + + roi_per_batch = 1 + rois, rois_num = make_rois(batch_size, width, height, pooled_width, + pooled_height, spatial_scale, roi_per_batch) + + roi_align("roi_align_test", x, rois, rois_num, pooled_height, + pooled_width, spatial_scale, sampling_ratio) + + batch_size = 1 + channels = 3 + height = 8 + width = 6 + + x_dim = (batch_size, channels, height, width) + x = np.random.random(x_dim).astype('float32') + + spatial_scale = 1.0 / 2.0 + pooled_height = 2 + pooled_width = 2 + sampling_ratio = 2 + + roi_per_batch = 2 + rois, rois_num = make_rois(batch_size, width, height, pooled_width, + pooled_height, spatial_scale, roi_per_batch) + + roi_align("roi_align_test2", x, rois, rois_num, pooled_height, + pooled_width, spatial_scale, sampling_ratio) + + +if __name__ == "__main__": + main() diff --git a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_strided_slice.py b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_strided_slice.py new file mode 100644 index 00000000000..2c1b29ee2d8 --- /dev/null +++ b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_strided_slice.py @@ -0,0 +1,134 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# +# strided_slice paddle model generator +# +import numpy as np +from save_model import saveModel +import sys + + +def strided_slice(name: str, input_data, attrs: dict): + import paddle + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): + Input = paddle.static.data( + name='x', shape=input_data.shape, dtype=input_data.dtype) + + out = paddle.fluid.layers.strided_slice(Input, axes=attrs['axes'], + starts=attrs['starts'], + ends=attrs['ends'], + strides=attrs['strides']) + + cpu = paddle.static.cpu_places(1) + exe = paddle.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(paddle.static.default_startup_program()) + + outs = exe.run( + feed={'x': input_data}, + fetch_list=[out]) + + # Save inputs in order of ngraph function, to facilite Fuzzy test, + # which accepts inputs and outputs in this order as well. + saveModel(name, exe, feedkeys=['x'], fetchlist=[out], + inputs=[input_data], outputs=[outs[0]], target_dir=sys.argv[1]) + return outs + + +if __name__ == "__main__": + + strided_slice_input1_1 = { + 'name': "strided_slice_input1_1", + 'axes': np.array([0]).astype('int32').tolist(), + 'starts': np.array([-4]).astype('int32').tolist(), + 'ends': np.array([-3]).astype('int32').tolist(), + 'strides': np.array([1]).astype('int32').tolist() + } + + strided_slice_input1_2 = { + 'name': "strided_slice_input1_2", + 'axes': np.array([0]).astype('int32').tolist(), + 'starts': np.array([3]).astype('int32').tolist(), + 'ends': np.array([8]).astype('int32').tolist(), + 'strides': np.array([1]).astype('int32').tolist() + } + + strided_slice_input1_3 = { + 'name': "strided_slice_input1_3", + 'axes': np.array([0]).astype('int32').tolist(), + 'starts': np.array([5]).astype('int32').tolist(), + 'ends': np.array([0]).astype('int32').tolist(), + 'strides': np.array([-1]).astype('int32').tolist() + } + + strided_slice_input1_4 = { + 'name': "strided_slice_input1_4", + 'axes': np.array([0]).astype('int32').tolist(), + 'starts': np.array([-1]).astype('int32').tolist(), + 'ends': np.array([-3]).astype('int32').tolist(), + 'strides': np.array([-1]).astype('int32').tolist() + } + + strided_slice_input2_1 = { + 'name': "strided_slice_input2_1", + 'axes': np.array([0, 1, 2]).astype('int32').tolist(), + 'starts': np.array([1, 0, 0]).astype('int32').tolist(), + 'ends': np.array([2, 1, 3]).astype('int32').tolist(), + 'strides': np.array([1, 1, 1]).astype('int32').tolist() + } + + strided_slice_input2_2 = { + 'name': "strided_slice_input2_2", + 'axes': np.array([0, 1, 2]).astype('int32').tolist(), + 'starts': np.array([1, -1, 0]).astype('int32').tolist(), + 'ends': np.array([2, -3, 3]).astype('int32').tolist(), + 'strides': np.array([1, -1, 1]).astype('int32').tolist() + } + + strided_slice_input2_3 = { + 'name': "strided_slice_input2_3", + 'axes': np.array([0, 1, 2]).astype('int32').tolist(), + 'starts': np.array([1, 0, 0]).astype('int32').tolist(), + 'ends': np.array([2, 2, 3]).astype('int32').tolist(), + 'strides': np.array([1, 1, 1]).astype('int32').tolist() + } + + strided_slice_input3_1 = { + 'name': "strided_slice_input3_1", + 'axes': np.array([1]).astype('int32').tolist(), + 'starts': np.array([1]).astype('int32').tolist(), + 'ends': np.array([2]).astype('int32').tolist(), + 'strides': np.array([1]).astype('int32').tolist() + } + + strided_slice_input3_2 = { + 'name': "strided_slice_input3_2", + 'axes': np.array([1]).astype('int32').tolist(), + 'starts': np.array([-1]).astype('int32').tolist(), + 'ends': np.array([-2]).astype('int32').tolist(), + 'strides': np.array([-1]).astype('int32').tolist() + } + + strided_slice_input1_list = [strided_slice_input1_1, + strided_slice_input1_2, strided_slice_input1_3, strided_slice_input1_4] + + strided_slice_input2_list = [strided_slice_input2_1, + strided_slice_input2_2, strided_slice_input2_3] + + strided_slice_input3_list = [ + strided_slice_input3_1, strided_slice_input3_2] + + input1 = np.random.rand(100).astype('float32') + for item in strided_slice_input1_list: + pred_paddle = strided_slice(item['name'], input1, item) + + input2 = np.random.rand(5, 5, 5).astype('int32') + for item in strided_slice_input2_list: + pred_paddle = strided_slice(item['name'], input2, item) + + input3 = np.random.rand(1, 100, 1).astype('float32') + for item in strided_slice_input3_list: + pred_paddle = strided_slice(item['name'], input3, item) diff --git a/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where.py b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where.py new file mode 100644 index 00000000000..400d7732879 --- /dev/null +++ b/src/core/tests/frontend/paddle/test_models/gen_scripts/generate_where.py @@ -0,0 +1,69 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# +# where paddle model generator +# +import numpy as np +from save_model import saveModel +import sys + + +def where(name, test_x, test_y, test_cond): + import paddle + paddle.enable_static() + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + X_Node = paddle.static.data( + name='x', shape=test_x.shape, dtype=test_x.dtype) + Y_Node = paddle.static.data( + name='y', shape=test_y.shape, dtype=test_y.dtype) + Cond_Node = paddle.static.data( + name='cond', shape=test_cond.shape, dtype=test_cond.dtype) + + Cond_Node_bl = paddle.fluid.layers.cast(Cond_Node, "bool") + + out = paddle.where(Cond_Node_bl, X_Node, Y_Node) + cpu = paddle.static.cpu_places(1) + exe = paddle.static.Executor(cpu[0]) + # startup program will call initializer to initialize the parameters. + exe.run(paddle.static.default_startup_program()) + + outs = exe.run( + feed={'x': test_x, 'y': test_y, 'cond': test_cond}, + fetch_list=[out] + ) + + saveModel(name, exe, feedkeys=['x', 'y', 'cond'], fetchlist=[out], inputs=[ + test_x, test_y, test_cond], outputs=[outs[0]], target_dir=sys.argv[1]) + + +def main(): + + test_cases = [ + { + "name": "where_1", + "x": np.random.uniform(-3, 5, (100)).astype("float32"), + "y": np.random.uniform(-3, 5, (100)).astype("float32"), + "cond": np.zeros((100)).astype("int32") + }, + { + "name": "where_2", + "x": np.random.uniform(-5, 5, (60, 2)).astype("int32"), + "y": np.random.uniform(-5, 5, (60, 2)).astype("int32"), + "cond": np.ones((60, 2)).astype("int32") + }, + { + "name": "where_3", + "x": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"), + "y": np.random.uniform(-3, 5, (20, 2, 4)).astype("float32"), + "cond": np.array(np.random.randint(2, size=(20, 2, 4)), dtype="int32") + } + ] + for test in test_cases: + where(test['name'], test['x'], test['y'], test['cond']) + + +if __name__ == "__main__": + main() diff --git a/src/frontends/paddle/src/op/roi_align.cpp b/src/frontends/paddle/src/op/roi_align.cpp new file mode 100644 index 00000000000..51347182050 --- /dev/null +++ b/src/frontends/paddle/src/op/roi_align.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "default_opset.hpp" +#include "openvino/frontend/paddle/node_context.hpp" + +namespace ov { +namespace frontend { +namespace paddle { +namespace op { +NamedOutputs roi_align(const NodeContext& node) { + const auto data_node = node.get_input("X"); + const auto roi_node = node.get_input("ROIs"); + // TODO: support 'aligned' feature #82319 + const auto aligned = node.get_attribute("aligned", false); + PADDLE_OP_CHECK(node, !aligned, "OpenVINO not support 'aligned' feature!"); + + // TODO: support multiple batches #83232 + if (data_node.get_partial_shape().rank().is_static() && data_node.get_partial_shape()[0].is_static()) + PADDLE_OP_CHECK(node, data_node.get_partial_shape()[0] == 1, "roi_align currenty only support batch_size = 1!"); + + const auto roi_node_shape = std::make_shared(roi_node, element::i32); + const auto start = default_opset::Constant::create(element::i64, {1}, {0}); + const auto stop = default_opset::Constant::create(element::i64, {1}, {1}); + const auto step = default_opset::Constant::create(element::i64, {1}, {1}); + const auto roisNum = std::make_shared(roi_node_shape, start, stop, step); + + const auto zero_const = std::make_shared(element::i32, Shape{1}, 0); + const auto fake_roisNum_node = std::make_shared(zero_const, roisNum); + + const auto pooled_h = node.get_attribute("pooled_height", 1); + const auto pooled_w = node.get_attribute("pooled_width", 1); + const auto spatial_scale = node.get_attribute("spatial_scale", 1.0); + auto sampling_ratio = node.get_attribute("sampling_ratio", -1); + sampling_ratio = (sampling_ratio <= 0) ? 0 : sampling_ratio; + + // Paddle only use 'avg' interpolation mode + return node.default_single_output_mapping({std::make_shared(data_node, + roi_node, + fake_roisNum_node, + pooled_h, + pooled_w, + sampling_ratio, + spatial_scale, + "avg")}, + {"Out"}); +} +} // namespace op +} // namespace paddle +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/paddle/src/op/slice.cpp b/src/frontends/paddle/src/op/slice.cpp index 382cea40317..fdbb317ad01 100644 --- a/src/frontends/paddle/src/op/slice.cpp +++ b/src/frontends/paddle/src/op/slice.cpp @@ -1,103 +1,14 @@ // Copyright (C) 2018-2022 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // - -#include - -#include "default_opset.hpp" -#include "openvino/frontend/paddle/node_context.hpp" +#include "slice_ops.hpp" namespace ov { namespace frontend { namespace paddle { namespace op { -using namespace default_opset; NamedOutputs slice(const NodeContext& node) { - auto data = node.get_input("Input"); - auto axes = node.get_attribute>("axes"); - Output start_idx_node, end_idx_node; - if (node.has_input("StartsTensor")) { - start_idx_node = node.get_input("StartsTensor"); - } else if (node.has_input("StartsTensorList")) { - auto inputs = node.get_ng_inputs("StartsTensorList"); - start_idx_node = std::make_shared(inputs, 0); - } else { - auto starts = node.get_attribute>("starts"); - start_idx_node = Constant::create(element::i32, {starts.size()}, starts); - } - - if (node.has_input("EndsTensor")) { - end_idx_node = node.get_input("EndsTensor"); - } else if (node.has_input("EndsTensorList")) { - auto inputs = node.get_ng_inputs("EndsTensorList"); - end_idx_node = std::make_shared(inputs, 0); - } else { - auto ends = node.get_attribute>("ends"); - end_idx_node = Constant::create(element::i32, {ends.size()}, ends); - } - - // The following process is: - // Given: - // data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4] - // axes = [0] - // starts = [1] - // ends = [2] - // Our process is: - // 1. Get 'axes': [0, 1], 'starts', 'ends' - // 2. Get data shape: [2,4] and dims: 2 - // 3. Create two tensor t1 and t2, shape is the dims from step2: 2. t1: [0, 0], t2: [INT_MAX, INT_MAX] - // 4. Use 'ScatterNDUpdate' to update some elements in t1, the updated indexes are coming from 'axes', the contents - // are coming from 'starts', t1: [1, 0]; apply the similar process to t2 - // 5. Call 'StrideSlice' with t1 and t2 - // Why using ScatterNDUpdate is that 'axes' may be discontinuous. - - // the shape of input, such as [2, 4] - auto shape_node = std::make_shared(data, element::Type_t::i32); - // the input dim, such as [2] - auto shape_shape_node = std::make_shared(shape_node, element::i32); - auto const_0_node = Constant::create(element::i32, {}, {0}); - auto const_max_node = Constant::create(element::i32, {}, {INT_MAX}); - // t1: [0, 0] - auto start_node = std::make_shared(const_0_node, shape_shape_node); - // t2: [INT_MAX, INT_MAX] - auto end_node = std::make_shared(const_max_node, shape_shape_node); - auto axes_node = Constant::create(element::i32, {axes.size(), 1}, axes); - // update t1 - auto fixed_start_node = std::make_shared(start_node, axes_node, start_idx_node); - // update t2 - auto fixed_end_node = std::make_shared(end_node, axes_node, end_idx_node); - - auto stride_slice_node = std::make_shared(data, - fixed_start_node, - fixed_end_node, - std::vector{0}, - std::vector{0}); - - auto decrease_axis = node.get_attribute>("decrease_axis"); - - if (decrease_axis.size() > 0) { - // according to paddle slice_op, when all axes are decreased, output shape is [1], instead of scalar. - // Ref: paddle/fluid/operators/slice_op.h - PartialShape input_shape = data.get_partial_shape(); - PADDLE_OP_CHECK(node, - input_shape.rank().is_static(), - "input rank of slice must be static when decrease_axis is set."); - - auto squeeze_index_node = Constant::create(element::i32, {decrease_axis.size()}, decrease_axis); - auto decreased_node = std::make_shared(stride_slice_node, squeeze_index_node); - - auto input_rank = input_shape.rank().get_length(); - if (input_rank == decrease_axis.size()) { - auto restore_node = std::make_shared(decreased_node, - std::make_shared(element::i64, Shape{1}, 1), - false); // restore to shape (1,) - return node.default_single_output_mapping({restore_node}, {"Out"}); - } - - return node.default_single_output_mapping({decreased_node}, {"Out"}); - } - - return node.default_single_output_mapping({stride_slice_node}, {"Out"}); + return slice_op(node, false); } } // namespace op } // namespace paddle diff --git a/src/frontends/paddle/src/op/slice_ops.hpp b/src/frontends/paddle/src/op/slice_ops.hpp new file mode 100644 index 00000000000..8896097ac98 --- /dev/null +++ b/src/frontends/paddle/src/op/slice_ops.hpp @@ -0,0 +1,123 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include "default_opset.hpp" +#include "openvino/frontend/paddle/node_context.hpp" + +namespace ov { +namespace frontend { +namespace paddle { +namespace op { +namespace { +Output idx_node(const std::string& tensor_alias, + const std::string& list_alias, + const std::string& attr_alias, + const NodeContext& node) { + if (node.has_input(tensor_alias)) { + return std::make_shared(node.get_input(tensor_alias), element::i32); + } else if (node.has_input(list_alias)) { + auto inputs = node.get_ng_inputs(list_alias); + return std::make_shared(std::make_shared(inputs, 0), + element::i32); + } else { + auto values = node.get_attribute>(attr_alias); + return default_opset::Constant::create(element::i32, {values.size()}, values); + } +} +NamedOutputs slice_op(const NodeContext& node, const bool& stride_input) { + const auto data = node.get_input("Input"); + const auto axes = node.get_attribute>("axes"); + + Output start_idx_node = idx_node("StartsTensor", "StartsTensorList", "starts", node); + Output end_idx_node = idx_node("EndsTensor", "EndsTensorList", "ends", node); + Output strides_idx_node; + if (stride_input) + strides_idx_node = idx_node("StridesTensor", "StridesTensorList", "strides", node); + + // The following process is: + // Given: + // data = [ [1, 2, 3, 4], [5, 6, 7, 8], ] // shape is: [2, 4] + // axes = [0] + // starts = [1] + // ends = [2] + // Our process is: + // 1. Get 'axes': [0, 1], 'starts', 'ends' + // 2. Get data shape: [2,4] and dims: 2 + // 3. Create two tensor t1 and t2, shape is the dims from step2: 2. t1: [0, 0], t2: [INT_MAX, INT_MAX] + // 4. Use 'ScatterNDUpdate' to update some elements in t1, the updated indexes are coming from 'axes', the contents + // are coming from 'starts', t1: [1, 0]; apply the similar process to t2 + // 5. Call 'StrideSlice' with t1 and t2 + // Why using ScatterNDUpdate is that 'axes' may be discontinuous. + + // the shape of input, such as [2, 4] + const auto shape_node = std::make_shared(data, element::Type_t::i32); + // the input dim, such as [2] + const auto rank_node = std::make_shared(shape_node, element::i32); + const auto const_0_node = default_opset::Constant::create(element::i32, {}, {0}); + const auto const_max_node = default_opset::Constant::create(element::i32, {}, {INT_MAX}); + const auto const_1_node = default_opset::Constant::create(element::i32, {}, {1}); + // t1: [0, 0] + const auto start_node = std::make_shared(const_0_node, rank_node); + // t2: [INT_MAX, INT_MAX] + const auto end_node = std::make_shared(const_max_node, rank_node); + const auto strides_node = std::make_shared(const_1_node, rank_node); + const auto axes_node = default_opset::Constant::create(element::i32, {axes.size(), 1}, axes); + // update t1 + const auto fixed_start_node = + std::make_shared(start_node, axes_node, start_idx_node); + // update t2 + const auto fixed_end_node = std::make_shared(end_node, axes_node, end_idx_node); + std::shared_ptr stride_slice_node; + if (stride_input) { + const auto fixed_strides_node = + std::make_shared(strides_node, axes_node, strides_idx_node); + + stride_slice_node = std::make_shared(data, + fixed_start_node, + fixed_end_node, + fixed_strides_node, + std::vector{0}, + std::vector{0}); + } else { + stride_slice_node = std::make_shared(data, + fixed_start_node, + fixed_end_node, + std::vector{0}, + std::vector{0}); + } + + const auto decrease_axis = node.get_attribute>("decrease_axis"); + + if (decrease_axis.size() > 0) { + // according to paddle slice_op, when all axes are decreased, output shape is [1], instead of scalar. + // Ref: paddle/fluid/operators/slice_op.h + PartialShape input_shape = data.get_partial_shape(); + PADDLE_OP_CHECK(node, + input_shape.rank().is_static(), + "input rank of slice must be static when decrease_axis is set."); + + const auto squeeze_index_node = + default_opset::Constant::create(element::i32, {decrease_axis.size()}, decrease_axis); + const auto decreased_node = std::make_shared(stride_slice_node, squeeze_index_node); + + const auto input_rank = input_shape.rank().get_length(); + if (input_rank == decrease_axis.size()) { + auto restore_node = std::make_shared( + decreased_node, + std::make_shared(element::i64, Shape{1}, 1), + false); // restore to shape (1,) + return node.default_single_output_mapping({restore_node}, {"Out"}); + } + + return node.default_single_output_mapping({decreased_node}, {"Out"}); + } + + return node.default_single_output_mapping({stride_slice_node}, {"Out"}); +} +} // namespace +} // namespace op +} // namespace paddle +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/paddle/src/op/strided_slice.cpp b/src/frontends/paddle/src/op/strided_slice.cpp new file mode 100644 index 00000000000..01e31db9ef2 --- /dev/null +++ b/src/frontends/paddle/src/op/strided_slice.cpp @@ -0,0 +1,16 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "slice_ops.hpp" + +namespace ov { +namespace frontend { +namespace paddle { +namespace op { +NamedOutputs strided_slice(const NodeContext& node) { + return slice_op(node, true); +} +} // namespace op +} // namespace paddle +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/paddle/src/op/where.cpp b/src/frontends/paddle/src/op/where.cpp new file mode 100644 index 00000000000..56327493fac --- /dev/null +++ b/src/frontends/paddle/src/op/where.cpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "default_opset.hpp" +#include "openvino/frontend/paddle/node_context.hpp" + +namespace ov { +namespace frontend { +namespace paddle { +namespace op { +NamedOutputs where(const NodeContext& node) { + const auto condition_node = node.get_input("Condition"); + const auto x_node = node.get_input("X"); + const auto y_node = node.get_input("Y"); + // TODO: support 'shape x != shape y' #83233 + const auto x_shape = x_node.get_partial_shape(); + const auto y_shape = y_node.get_partial_shape(); + PADDLE_OP_CHECK(node, x_shape.compatible(y_shape), "shape x should be compatible to shape y!"); + + return node.default_single_output_mapping( + {std::make_shared(condition_node, x_node, y_node, ov::op::AutoBroadcastType::PDPD)}, + {"Out"}); +} +} // namespace op +} // namespace paddle +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/paddle/src/op_table.cpp b/src/frontends/paddle/src/op_table.cpp index 040c476ebb9..797e45b6824 100644 --- a/src/frontends/paddle/src/op_table.cpp +++ b/src/frontends/paddle/src/op_table.cpp @@ -71,6 +71,7 @@ OP_CONVERTER(relu); OP_CONVERTER(relu6); OP_CONVERTER(reshape2); OP_CONVERTER(rnn); +OP_CONVERTER(roi_align); OP_CONVERTER(scale); OP_CONVERTER(shape); OP_CONVERTER(slice); @@ -80,10 +81,12 @@ OP_CONVERTER(sigmoid); OP_CONVERTER(split); OP_CONVERTER(squeeze); OP_CONVERTER(stack); +OP_CONVERTER(strided_slice); OP_CONVERTER(tanh); OP_CONVERTER(transpose2); OP_CONVERTER(trilinear_interp_v2); OP_CONVERTER(unsqueeze); +OP_CONVERTER(where); OP_CONVERTER(yolo_box); } // namespace op std::map get_supported_ops() { @@ -157,6 +160,7 @@ std::map get_supported_ops() { {"relu6", op::relu6}, {"reshape2", op::reshape2}, {"rnn", op::rnn}, + {"roi_align", op::roi_align}, {"scale", op::scale}, {"shape", op::shape}, {"slice", op::slice}, @@ -166,11 +170,13 @@ std::map get_supported_ops() { {"split", op::split}, {"squeeze2", op::squeeze}, {"stack", op::stack}, + {"strided_slice", op::strided_slice}, {"sync_batch_norm", op::batch_norm}, {"tanh", op::tanh}, {"transpose2", op::transpose2}, {"trilinear_interp_v2", op::trilinear_interp_v2}, {"unsqueeze2", op::unsqueeze}, + {"where", op::where}, {"yolo_box", op::yolo_box}}; };