[GPU] Support dynamic shape for one-hot and reshape operations (#13516)

commit a7a14a89c8
parent 154850e8ca
Author: Kelvin Choi
Date: 2022-10-27 09:12:17 +09:00 (committed by GitHub)
5 changed files with 24 additions and 9 deletions


@@ -104,6 +104,9 @@ std::string one_hot_inst::to_string(one_hot_node const& node) {
 one_hot_inst::typed_primitive_inst(network& network, one_hot_node const& node) : parent(network, node) {
     auto input_layout = node.input().get_output_layout();
+    if (input_layout.is_dynamic())
+        return;
     const auto& input_sizes = input_layout.get_tensor();
     const auto& output_sizes = argument.shape;
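The early return added above skips the constructor's shape validation when the input layout is dynamic: the checks that follow compare concrete tensor dimensions, which do not exist yet, so validation is effectively deferred until shapes are resolved at execution time. A minimal standalone sketch of the pattern (the Layout type and validate_one_hot name are illustrative, not the plugin's API):

    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Mock layout: -1 marks a dynamic dimension (stand-in for cldnn::layout).
    struct Layout {
        std::vector<int64_t> dims;
        bool is_dynamic() const {
            for (auto d : dims)
                if (d < 0) return true;
            return false;
        }
    };

    // Static checks run only when the input shape is fully known.
    void validate_one_hot(const Layout& input, const Layout& output, size_t one_hot_axis) {
        if (input.is_dynamic())
            return;  // same early-out as the constructor above
        // One-hot inserts one axis, so the output rank must be input rank + 1.
        if (output.dims.size() != input.dims.size() + 1 || one_hot_axis >= output.dims.size())
            throw std::runtime_error("one_hot: output shape inconsistent with input");
    }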


@@ -200,7 +200,7 @@ std::string reorder_inst::to_string(reorder_node const& node) {
 }
 
 reorder_inst::typed_primitive_inst(network& network, reorder_node const& node)
-    : parent(network, node, !node.can_be_optimized() && !node.is_dynamic()) {
+    : parent(network, node, (!node.can_be_optimized() && node.get_output_layout().is_static()) ? true : false) {
     if (node.can_be_optimized())
         reuse_input();
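The boolean forwarded to the parent constructor here plausibly controls whether output memory is allocated eagerly; with a dynamic output layout the buffer size is unknown at build time, so allocation must wait. Note that the `(cond) ? true : false` form reduces to the bare condition:

    // Assumed semantics of the constructor flag: "allocate output memory now".
    // Identical to (!can_be_optimized && output_is_static) ? true : false.
    bool allocate_output_eagerly(bool can_be_optimized, bool output_is_static) {
        return !can_be_optimized && output_is_static;
    }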


@@ -60,7 +60,14 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
     // we return output_partial_shape taken from the original model instead of something like PartialShape::dynamic(rank)
     // as ngraph may refine output shape using interval arithmetic
     if ((memory_deps.empty() && prim->output_pattern.empty()) || input_layout.is_dynamic()) {
-        return { layout{prim->output_partial_shape, input_layout.data_type, format::adjust_to_rank(input_layout.format, prim->output_partial_shape.size())} };
+        if (prim->output_partial_shape.size() > 0) {
+            auto fm = format::adjust_to_rank(input_layout.format, prim->output_partial_shape.size());
+            return { layout{prim->output_partial_shape, input_layout.data_type, fm} };
+        } else if (prim->output_shape != tensor()) {
+            return { layout{input_layout.data_type, input_layout.format, prim->output_shape} };
+        } else {
+            OPENVINO_ASSERT("There are no output pattern, predefined output partial shape, and output shape!");
+        }
     }
 
     ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>()
@@ -78,16 +85,19 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
     case reshape::reshape_mode::base: {
         ov::op::v1::Reshape op;
         op.set_special_zero(prim->special_zero);
+        op.set_friendly_name(prim->id.c_str());
         shape_infer(&op, input_shapes, output_shapes, const_data);
         break;
     }
     case reshape::reshape_mode::squeeze: {
         ov::op::v0::Squeeze op;
+        op.set_friendly_name(prim->id.c_str());
         shape_infer(&op, input_shapes, output_shapes, const_data);
         break;
     }
     case reshape::reshape_mode::unsqueeze: {
         ov::op::v0::Unsqueeze op;
+        op.set_friendly_name(prim->id.c_str());
         shape_infer(&op, input_shapes, output_shapes, const_data);
         break;
     }
@@ -114,7 +124,7 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
         run_shape_infer(prim->mode);
     }
 
-    return { layout{output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
+    return { layout {output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
 }
 
 template std::vector<layout> reshape_inst::calc_output_layouts<ov::PartialShape>(reshape_node const& node, const kernel_impl_params& impl_param);
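When shape inference cannot run (no memory dependencies and no output pattern, or a dynamic input), the new branch falls back in order: the partial shape recorded from the original model, then the legacy static output tensor, and otherwise an assertion. Below is a standalone model of that chain with simplified stand-in types rather than the cldnn API; since OPENVINO_ASSERT above is given only the message string as its condition argument, the sketch expresses the apparent intent with an unconditional assert:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Simplified stand-ins for ov::PartialShape and cldnn::tensor.
    struct PartialShape {
        std::vector<int64_t> dims;
        size_t size() const { return dims.size(); }
    };
    struct Tensor {
        std::vector<int64_t> dims;
        bool operator!=(const Tensor& other) const { return dims != other.dims; }
    };

    // Fallback order used when shape inference cannot run yet:
    // 1) partial shape recorded from the original model,
    // 2) the legacy static output tensor,
    // 3) nothing available -- a "should never happen" branch.
    PartialShape fallback_output_shape(const PartialShape& output_partial_shape,
                                       const Tensor& output_shape) {
        if (output_partial_shape.size() > 0)
            return output_partial_shape;
        if (output_shape != Tensor{})
            return PartialShape{output_shape.dims};
        assert(false && "no output pattern, partial shape, or static shape available");
        return {};
    }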


@@ -52,10 +52,7 @@ static void CreateOneHotOp(Program& p, const std::shared_ptr<ngraph::op::v1::One
     int64_t depth = depth_value_node->cast_vector<int64_t>()[0];
 
     auto out_pshape = op->get_output_partial_shape(0);
-    if (out_pshape.is_dynamic()) {
-        IE_THROW() << "OneHot doesn't support dynamic shapes yet";
-    }
-    auto out_tensor = tensor_from_dims(out_pshape.to_shape());
+    cldnn::tensor out_tensor = out_pshape.is_static() ? tensor_from_dims(out_pshape.to_shape()) : cldnn::tensor{};
 
     auto oneHotPrim = cldnn::one_hot(layerName,
                                      inputPrimitives[0],
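Instead of throwing for dynamic shapes, the converter now builds a concrete tensor only when the output shape is static and otherwise passes a default-constructed (empty) tensor, deferring the shape to runtime inference. A standalone sketch of that conversion, using mock types in place of ov::PartialShape, cldnn::tensor, and tensor_from_dims:

    #include <cstdint>
    #include <vector>

    struct Shape { std::vector<int64_t> dims; };   // static shape
    struct PShape {                                // -1 == dynamic dimension
        std::vector<int64_t> dims;
        bool is_static() const {
            for (auto d : dims)
                if (d < 0) return false;
            return true;
        }
        Shape to_shape() const { return {dims}; }
    };
    struct TensorDims { std::vector<int64_t> v; }; // stand-in for cldnn::tensor

    TensorDims tensor_from_dims(const Shape& s) { return {s.dims}; }

    // Static shape -> concrete tensor; dynamic shape -> empty tensor, meaning
    // "defer to runtime shape inference" instead of rejecting the op.
    TensorDims out_tensor_for(const PShape& p) {
        return p.is_static() ? tensor_from_dims(p.to_shape()) : TensorDims{};
    }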


@@ -227,9 +227,14 @@ static void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v
     // Reshape in case of deleted axes
     if (!shrink_axis_mask.empty()) {
-        auto targetShape = tensor_from_dims(output_shape);
+        std::vector<int64_t> output_pattern(output_shape.size());
+        auto out_p = output_pattern.begin();
+        for (auto s = output_shape.begin(); s != output_shape.end() && out_p != output_pattern.end(); s++, out_p++) {
+            *out_p = *s;
+        }
 
         auto reshapeOutName = op->get_friendly_name() + "/Crop";
-        auto reshapePrim = cldnn::reshape(reshapeOutName, layerName, targetShape);
+        auto reshapePrim = cldnn::reshape(reshapeOutName, layerName, false, output_pattern, output_pshape);
         p.add_primitive(*op, reshapePrim);
         last_layer_primitive = reshapeOutName;
     }
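The new loop copies the (unsigned) output_shape dimensions into a signed int64_t pattern, which the pattern-based cldnn::reshape call then consumes together with the possibly dynamic output_pshape, replacing the fixed target tensor. For reference, the loop is equivalent to constructing the vector directly from iterators (a sketch, assuming output_shape holds size_t-like values):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Matches the hand-written copy loop above: each unsigned dimension is
    // converted to int64_t as the range constructor copies it.
    std::vector<int64_t> make_output_pattern(const std::vector<size_t>& output_shape) {
        return std::vector<int64_t>(output_shape.begin(), output_shape.end());
    }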