[GPU] Support dynamic shape for one-hot and reshape operations (#13516)

This commit is contained in:
Kelvin Choi 2022-10-27 09:12:17 +09:00 committed by GitHub
parent 154850e8ca
commit a7a14a89c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 9 deletions

View File

@ -104,6 +104,9 @@ std::string one_hot_inst::to_string(one_hot_node const& node) {
one_hot_inst::typed_primitive_inst(network& network, one_hot_node const& node) : parent(network, node) {
auto input_layout = node.input().get_output_layout();
if (input_layout.is_dynamic())
return;
const auto& input_sizes = input_layout.get_tensor();
const auto& output_sizes = argument.shape;

View File

@ -200,7 +200,7 @@ std::string reorder_inst::to_string(reorder_node const& node) {
}
reorder_inst::typed_primitive_inst(network& network, reorder_node const& node)
: parent(network, node, !node.can_be_optimized() && !node.is_dynamic()) {
: parent(network, node, (!node.can_be_optimized() && node.get_output_layout().is_static()) ? true : false) {
if (node.can_be_optimized())
reuse_input();

View File

@ -60,7 +60,14 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
// we return output_partial_shape taken from the original model intead of something like PartialShape::dynamic(rank)
// as ngraph may refine output shape using interval arithmetic
if ((memory_deps.empty() && prim->output_pattern.empty()) || input_layout.is_dynamic()) {
return { layout{prim->output_partial_shape, input_layout.data_type, format::adjust_to_rank(input_layout.format, prim->output_partial_shape.size())} };
if (prim->output_partial_shape.size() > 0) {
auto fm = format::adjust_to_rank(input_layout.format, prim->output_partial_shape.size());
return { layout{prim->output_partial_shape, input_layout.data_type, fm} };
} else if (prim->output_shape != tensor()) {
return { layout{input_layout.data_type, input_layout.format, prim->output_shape} };
} else {
OPENVINO_ASSERT("There are no output pattern, predefined output partial shape, and output shape!");
}
}
ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>()
@ -78,16 +85,19 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
case reshape::reshape_mode::base: {
ov::op::v1::Reshape op;
op.set_special_zero(prim->special_zero);
op.set_friendly_name(prim->id.c_str());
shape_infer(&op, input_shapes, output_shapes, const_data);
break;
}
case reshape::reshape_mode::squeeze: {
ov::op::v0::Squeeze op;
op.set_friendly_name(prim->id.c_str());
shape_infer(&op, input_shapes, output_shapes, const_data);
break;
}
case reshape::reshape_mode::unsqueeze: {
ov::op::v0::Unsqueeze op;
op.set_friendly_name(prim->id.c_str());
shape_infer(&op, input_shapes, output_shapes, const_data);
break;
}
@ -114,7 +124,7 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
run_shape_infer(prim->mode);
}
return { layout{output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
return { layout {output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
}
template std::vector<layout> reshape_inst::calc_output_layouts<ov::PartialShape>(reshape_node const& node, const kernel_impl_params& impl_param);

View File

@ -52,10 +52,7 @@ static void CreateOneHotOp(Program& p, const std::shared_ptr<ngraph::op::v1::One
int64_t depth = depth_value_node->cast_vector<int64_t>()[0];
auto out_pshape = op->get_output_partial_shape(0);
if (out_pshape.is_dynamic()) {
IE_THROW() << "OneHot doesn't support dynamic shapes yet";
}
auto out_tensor = tensor_from_dims(out_pshape.to_shape());
cldnn::tensor out_tensor = out_pshape.is_static() ? tensor_from_dims(out_pshape.to_shape()) : cldnn::tensor{};
auto oneHotPrim = cldnn::one_hot(layerName,
inputPrimitives[0],

View File

@ -227,9 +227,14 @@ static void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v
// Reshape in case of deleting of axis
if (!shrink_axis_mask.empty()) {
auto targetShape = tensor_from_dims(output_shape);
std::vector<int64_t> output_pattern(output_shape.size());
auto out_p = output_pattern.begin();
for (auto s = output_shape.begin(); s != output_shape.end() && out_p != output_pattern.end(); s++, out_p++) {
*out_p = *s;
}
auto reshapeOutName = op->get_friendly_name() + "/Crop";
auto reshapePrim = cldnn::reshape(reshapeOutName, layerName, targetShape);
auto reshapePrim = cldnn::reshape(reshapeOutName, layerName, false, output_pattern, output_pshape);
p.add_primitive(*op, reshapePrim);
last_layer_primitive = reshapeOutName;
}