Create reshape with scalar when shape is empty
This commit is contained in:
parent
c54eb185c0
commit
0f901f419a
@ -253,10 +253,17 @@ ov::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (shape_end != input.get_shape()) {
|
if (shape_end != input.get_shape()) {
|
||||||
input = std::make_shared<ov::opset1::Reshape>(
|
if (shape_end == ov::Shape{}) {
|
||||||
input,
|
input = std::make_shared<ov::opset1::Reshape>(
|
||||||
ov::opset1::Constant::create(ov::element::i64, ov::Shape{shape_end.size()}, shape_end),
|
input,
|
||||||
true);
|
ov::opset1::Constant::create(ov::element::i64, ov::Shape{1}, {0}),
|
||||||
|
false);
|
||||||
|
} else {
|
||||||
|
input = std::make_shared<ov::opset1::Reshape>(
|
||||||
|
input,
|
||||||
|
ov::opset1::Constant::create(ov::element::i64, ov::Shape{shape_end.size()}, shape_end),
|
||||||
|
true);
|
||||||
|
}
|
||||||
new_ops.push_back(input.get_node_shared_ptr());
|
new_ops.push_back(input.get_node_shared_ptr());
|
||||||
}
|
}
|
||||||
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
|
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
|
||||||
|
Loading…
Reference in New Issue
Block a user