Handle Reshape's special zero in SimplifySecondInputOfReshape (#20785)
* Handle Reshape's special zero in SimplifySecondInputOfReshape SimplifySecondInputOfReshape detects ShapeOf->Gather->Concat subgraphs on Reshape's second input and replaces ShapeOf->Gather with a Constant with zero(s). Currently it works only with Reshapes that have special_zero set to true, but it can work for Reshapes with special_zero == false if non-Gather inputs to Concat are Constants and don't contain any zero. Ticket: CVS-123434 * fix no default output
This commit is contained in:
parent
e6ff5edc3d
commit
4084e9acc0
@ -200,7 +200,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
matcher_pass_callback callback = [=](Matcher& m) {
|
matcher_pass_callback callback = [=](Matcher& m) {
|
||||||
auto node = m.get_match_root();
|
auto node = m.get_match_root();
|
||||||
const auto reshape = as_type_ptr<v1::Reshape>(node);
|
const auto reshape = as_type_ptr<v1::Reshape>(node);
|
||||||
if (!reshape || reshape->get_special_zero() == false) {
|
if (!reshape) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,7 +219,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
|
|
||||||
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
||||||
auto shape_of = gather->get_input_node_shared_ptr(0);
|
auto shape_of = gather->get_input_node_shared_ptr(0);
|
||||||
if (!is_type<v3::ShapeOf>(shape_of) && !is_type<v0::ShapeOf>(shape_of)) {
|
if (!is_type<op::util::ShapeOfBase>(shape_of)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return shape_of->input_value(0) == data;
|
return shape_of->input_value(0) == data;
|
||||||
@ -237,16 +237,15 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
gather_dims_expected_location += concat_input_shape[0];
|
gather_dims_expected_location += concat_input_shape[0];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool special_zero = reshape->get_special_zero();
|
||||||
|
|
||||||
// We need this check to avoid sequences shapeOf -> gather -> concat
|
// We need this check to avoid sequences shapeOf -> gather -> concat
|
||||||
// that change the arrangement of dimensions in the reshape pattern
|
// that change the arrangement of dimensions in the reshape pattern
|
||||||
for (auto& concat_input : new_concat_inputs) {
|
for (auto& concat_input : new_concat_inputs) {
|
||||||
if (const auto gather = as_type_ptr<op::util::GatherBase>(concat_input.get_node_shared_ptr())) {
|
auto node = concat_input.get_node_shared_ptr();
|
||||||
auto indices_constant = as_type_ptr<v0::Constant>(gather->get_input_node_shared_ptr(1));
|
if (ov::is_type<op::util::GatherBase>(node) &&
|
||||||
if (!indices_constant || !check_shape_of_gather(gather)) {
|
ov::is_type<v0::Constant>(node->get_input_node_shared_ptr(1)) && check_shape_of_gather(node)) {
|
||||||
update_expected_gather_location(gather);
|
auto indices_constant = as_type_ptr<v0::Constant>(node->get_input_node_shared_ptr(1));
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool gather_can_be_fused = true;
|
bool gather_can_be_fused = true;
|
||||||
const auto indices = indices_constant->cast_vector<std::int64_t>();
|
const auto indices = indices_constant->cast_vector<std::int64_t>();
|
||||||
for (size_t i = 0; i < indices.size(); ++i) {
|
for (size_t i = 0; i < indices.size(); ++i) {
|
||||||
@ -258,11 +257,21 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
|
|
||||||
if (gather_can_be_fused) {
|
if (gather_can_be_fused) {
|
||||||
const size_t num_of_unchanged_dimensions = indices.size();
|
const size_t num_of_unchanged_dimensions = indices.size();
|
||||||
const auto subgraph_et = gather->get_input_element_type(0);
|
const auto subgraph_et = node->get_input_element_type(0);
|
||||||
concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0});
|
concat_input = v0::Constant::create(subgraph_et, Shape{num_of_unchanged_dimensions}, {0});
|
||||||
gather_folded = true;
|
gather_folded = true;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (!special_zero) {
|
||||||
|
// If special zero is false - check if other inputs to Concat are Constants.
|
||||||
|
// If any of those Constants contain zero - return false.
|
||||||
|
auto constant = as_type_ptr<v0::Constant>(node);
|
||||||
|
if (!constant)
|
||||||
|
return false;
|
||||||
|
auto values = constant->cast_vector<int64_t>();
|
||||||
|
if (std::find(values.begin(), values.end(), 0) != values.end())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
update_expected_gather_location(concat_input);
|
update_expected_gather_location(concat_input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -275,7 +284,7 @@ pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
new_concat->set_friendly_name(concat->get_friendly_name());
|
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||||
copy_runtime_info(concat, new_concat);
|
copy_runtime_info(concat, new_concat);
|
||||||
|
|
||||||
const auto new_reshape = reshape->clone_with_new_inputs({reshape->input_value(0), new_concat});
|
const auto new_reshape = std::make_shared<v1::Reshape>(reshape->input_value(0), new_concat, true);
|
||||||
new_reshape->set_friendly_name(reshape->get_friendly_name());
|
new_reshape->set_friendly_name(reshape->get_friendly_name());
|
||||||
|
|
||||||
copy_runtime_info(reshape, new_reshape);
|
copy_runtime_info(reshape, new_reshape);
|
||||||
|
@ -611,3 +611,53 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
|
|||||||
}
|
}
|
||||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZero) {
|
||||||
|
PartialShape data_shape{1, 128, 12, 64};
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||||
|
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||||
|
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
|
||||||
|
manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||||
|
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
comparator.enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTestFalseSpecialZeroZeroDim) {
|
||||||
|
PartialShape data_shape{1, 0, 12, 64};
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||||
|
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_op, constant}, -1);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||||
|
model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
|
||||||
|
manager.register_pass<ov::pass::SimplifySecondInputOfReshape>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{3}, {0, 0, 768});
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||||
|
model_ref = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
comparator.enable(FunctionsComparator::ATTRIBUTES);
|
||||||
|
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user