frontend 'Slice' op patch for 'decrease_axis' (#7931)
This commit is contained in:
@@ -4,13 +4,15 @@
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <node_context.hpp>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace frontend {
|
||||
namespace pdpd {
|
||||
namespace op {
|
||||
using namespace default_opset;
|
||||
NamedOutputs slice(const NodeContext& node) {
|
||||
auto data = node.get_ng_input("Input");
|
||||
auto axes = node.get_attribute<std::vector<int32_t>>("axes");
|
||||
@@ -19,20 +21,20 @@ NamedOutputs slice(const NodeContext& node) {
|
||||
start_idx_node = node.get_ng_input("StartsTensor");
|
||||
} else if (node.has_ng_input("StartsTensorList")) {
|
||||
auto inputs = node.get_ng_inputs("StartsTensorList");
|
||||
start_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
|
||||
start_idx_node = std::make_shared<Concat>(inputs, 0);
|
||||
} else {
|
||||
auto starts = node.get_attribute<std::vector<int32_t>>("starts");
|
||||
start_idx_node = opset6::Constant::create(element::i32, {starts.size()}, starts);
|
||||
start_idx_node = Constant::create(element::i32, {starts.size()}, starts);
|
||||
}
|
||||
|
||||
if (node.has_ng_input("EndsTensor")) {
|
||||
end_idx_node = node.get_ng_input("EndsTensor");
|
||||
} else if (node.has_ng_input("EndsTensorList")) {
|
||||
auto inputs = node.get_ng_inputs("EndsTensorList");
|
||||
end_idx_node = std::make_shared<ngraph::opset6::Concat>(inputs, 0);
|
||||
end_idx_node = std::make_shared<Concat>(inputs, 0);
|
||||
} else {
|
||||
auto ends = node.get_attribute<std::vector<int32_t>>("ends");
|
||||
end_idx_node = opset6::Constant::create(element::i32, {ends.size()}, ends);
|
||||
end_idx_node = Constant::create(element::i32, {ends.size()}, ends);
|
||||
}
|
||||
|
||||
// The following process is:
|
||||
@@ -51,27 +53,44 @@ NamedOutputs slice(const NodeContext& node) {
|
||||
// Why using ScatterNDUpdate is that 'axes' may be discontinuous.
|
||||
|
||||
// the shape of input, such as [2, 4]
|
||||
auto shape_node = std::make_shared<opset6::ShapeOf>(data, element::Type_t::i32);
|
||||
auto shape_node = std::make_shared<ShapeOf>(data, element::Type_t::i32);
|
||||
// the input dim, such as [2]
|
||||
auto shape_shape_node = std::make_shared<opset6::ShapeOf>(shape_node, element::i32);
|
||||
auto const_0_node = opset6::Constant::create(element::i32, {}, {0});
|
||||
auto const_max_node = opset6::Constant::create(element::i32, {}, {INT_MAX});
|
||||
auto shape_shape_node = std::make_shared<ShapeOf>(shape_node, element::i32);
|
||||
auto const_0_node = Constant::create(element::i32, {}, {0});
|
||||
auto const_max_node = Constant::create(element::i32, {}, {INT_MAX});
|
||||
// t1: [0, 0]
|
||||
auto start_node = std::make_shared<opset6::Broadcast>(const_0_node, shape_shape_node);
|
||||
auto start_node = std::make_shared<Broadcast>(const_0_node, shape_shape_node);
|
||||
// t2: [INT_MAX, INT_MAX]
|
||||
auto end_node = std::make_shared<opset6::Broadcast>(const_max_node, shape_shape_node);
|
||||
auto axes_node = opset6::Constant::create(element::i32, {axes.size(), 1}, axes);
|
||||
auto end_node = std::make_shared<Broadcast>(const_max_node, shape_shape_node);
|
||||
auto axes_node = Constant::create(element::i32, {axes.size(), 1}, axes);
|
||||
// update t1
|
||||
auto fixed_start_node = std::make_shared<opset6::ScatterNDUpdate>(start_node, axes_node, start_idx_node);
|
||||
auto fixed_start_node = std::make_shared<ScatterNDUpdate>(start_node, axes_node, start_idx_node);
|
||||
// update t2
|
||||
auto fixed_end_node = std::make_shared<opset6::ScatterNDUpdate>(end_node, axes_node, end_idx_node);
|
||||
auto fixed_end_node = std::make_shared<ScatterNDUpdate>(end_node, axes_node, end_idx_node);
|
||||
|
||||
return node.default_single_output_mapping({std::make_shared<ngraph::opset6::StridedSlice>(data,
|
||||
fixed_start_node,
|
||||
fixed_end_node,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0})},
|
||||
{"Out"});
|
||||
auto stride_slice_node = std::make_shared<StridedSlice>(data,
|
||||
fixed_start_node,
|
||||
fixed_end_node,
|
||||
std::vector<int64_t>{0},
|
||||
std::vector<int64_t>{0});
|
||||
|
||||
auto decrease_axis = node.get_attribute<std::vector<int32_t>>("decrease_axis");
|
||||
|
||||
if (decrease_axis.size() > 0) {
|
||||
auto stride_slice_output_shape = stride_slice_node->get_output_partial_shape(0);
|
||||
|
||||
for (size_t i = 0; i < decrease_axis.size(); ++i)
|
||||
PDPD_OP_VALIDATION_CHECK(node,
|
||||
stride_slice_output_shape[decrease_axis[i]] == 1,
|
||||
"decrease dim should be 1!");
|
||||
|
||||
auto squeeze_index_node = Constant::create(element::i32, {}, decrease_axis);
|
||||
auto decreased_node = std::make_shared<Squeeze>(stride_slice_node, squeeze_index_node);
|
||||
|
||||
return node.default_single_output_mapping({decreased_node}, {"Out"});
|
||||
}
|
||||
|
||||
return node.default_single_output_mapping({stride_slice_node}, {"Out"});
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace pdpd
|
||||
|
||||
Reference in New Issue
Block a user