[Common] Implement SelectWithOneValueCondition common transformation (#14242)

* [Common] Implement SelectWithOneValueCondition common transformation

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Add transformation tests for select_with_one_value_condition

* Move call of SelectWithOneValueCondition transformation to MOC

* Apply code-review feedback

* Remove corner case with dynamic output shape for Select

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2022-11-29 15:40:04 +04:00 committed by GitHub
parent de33d6d38b
commit 5a59b2b8f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 261 additions and 0 deletions

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <openvino/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
#include <vector>
namespace ov {
namespace pass {
class TRANSFORMATIONS_API SelectWithOneValueCondition;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief SelectWithOneValueCondition transformation eliminates Select operation if the condition
* is constant and consists of al True or False elements
*/
class ov::pass::SelectWithOneValueCondition : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SelectWithOneValueCondition", "0");
SelectWithOneValueCondition();
};

View File

@ -53,6 +53,7 @@
#include <transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp>
#include <transformations/common_optimizations/reshape_sequence_fusion.hpp>
#include <transformations/common_optimizations/ric_fusion.hpp>
#include <transformations/common_optimizations/select_with_one_value_condition.hpp>
#include <transformations/common_optimizations/sequence_fusion.hpp>
#include <transformations/common_optimizations/shuffle_channels_fusion.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
@ -160,6 +161,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
auto eliminations = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(eliminations, EliminateUnsqueezeGather)
ADD_MATCHER(eliminations, NopElimination, m_use_shapes)
ADD_MATCHER(eliminations, SelectWithOneValueCondition)
eliminations->set_name("ov::pass::CommonEliminations");
manager.register_pass<ov::pass::ConstantFolding>();

View File

@ -0,0 +1,99 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include <memory>
#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
using namespace std;
using namespace ov;
using namespace ov::element;
using namespace ov::opset10;
using namespace ov::op::util;
ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
MATCHER_SCOPE(SelectWithOneValueCondition);
auto condition = pattern::wrap_type<Constant>();
auto then_branch = pattern::any_input();
auto else_branch = pattern::any_input();
auto select_pattern = make_shared<Select>(condition, then_branch, else_branch);
matcher_pass_callback callback = [=](pattern::Matcher& m) {
NodeRegistry copy_from;
NodeRegistry copy_to;
auto& pattern_map = m.get_pattern_value_map();
auto select_value = pattern_map.at(select_pattern);
auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr());
if (!select) {
return false;
}
auto condition_value = pattern_map.at(condition);
auto condition_const = std::dynamic_pointer_cast<Constant>(condition_value.get_node_shared_ptr());
if (!condition_const) {
return false;
}
if (condition_value.get_element_type() != element::boolean) {
return false;
}
// check if all elements in the condition to be true or false
// only in this case, one certain branch can be selected
auto cond_value = condition_const->get_vector<bool>();
if (cond_value.size() == 0) {
return false;
}
auto cond_elem = cond_value[0];
auto all_equal = all_of(cond_value.begin(), cond_value.end(), [=](bool v) {
return v == cond_elem;
});
if (!all_equal) {
return false;
}
// based on the condition value, mark the selected branch and skipped branch index
auto branch_index = cond_elem ? 1 : 2;
// based on the resulted shape and the shape of the skipped branch, perform further steps
auto select_shape = select->get_output_partial_shape(0);
auto branch_output = select->input_value(branch_index);
auto branch_output_shape = branch_output.get_partial_shape();
if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) {
// Broadcast is not needed if the select shape is exactly the same as the selected branch
select->output(0).replace(branch_output);
replace_output_update_name(select->output(0), branch_output);
branch_output.get_node_shared_ptr()->set_friendly_name(select->get_friendly_name());
} else if (select_shape.is_static()) {
// if the shape of the selected branch is not the same, it needs the broadcasting
NodeRegistry copy_to;
auto select_rank = select_shape.size();
vector<int32_t> select_shape_values(select_rank);
for (size_t i = 0; i < select_rank; ++i) {
select_shape_values[i] = static_cast<int32_t>(select_shape[i].get_length());
}
auto target_shape = copy_to.make<Constant>(element::i32, Shape{select_rank}, select_shape_values);
auto broadcast = copy_to.make<Broadcast>(branch_output, target_shape);
select->output(0).replace(broadcast->output(0));
broadcast->set_friendly_name(select->get_friendly_name());
copy_runtime_info(select, copy_to.get());
} else {
return false;
}
return true;
};
auto m = make_shared<pattern::Matcher>(select_pattern, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,130 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "openvino/opsets/opset10.hpp"
using namespace ov;
using namespace std;
using namespace testing;
using namespace ov::opset10;
using namespace ov::element;
namespace {
enum class SELECT_BRANCH { THEN_BRANCH, ELSE_BRANCH, NONE };
struct SelectWithOneValueConditionParams {
PartialShape x_shape;
PartialShape y_shape;
Shape cond_shape;
vector<bool> cond_values;
PartialShape select_shape;
SELECT_BRANCH which_branch;
};
shared_ptr<Model> gen_model(const SelectWithOneValueConditionParams& params) {
auto x = make_shared<Parameter>(f32, params.x_shape);
auto y = make_shared<Parameter>(f32, params.y_shape);
auto add = make_shared<Add>(x, y);
auto subtract = make_shared<Relu>(y);
auto condition = make_shared<Constant>(boolean, params.cond_shape, params.cond_values);
auto select = make_shared<Select>(condition, add, subtract);
auto res_model = make_shared<Model>(OutputVector{select}, ParameterVector{x, y});
return res_model;
}
shared_ptr<Model> gen_reference(const SelectWithOneValueConditionParams& params) {
auto cond_shape = params.cond_shape;
auto x = make_shared<Parameter>(f32, params.x_shape);
auto y = make_shared<Parameter>(f32, params.y_shape);
Output<Node> output;
if (params.which_branch == SELECT_BRANCH::NONE) {
return gen_model(params);
} else if (params.which_branch == SELECT_BRANCH::THEN_BRANCH) {
output = make_shared<Add>(x, y)->output(0);
} else {
output = make_shared<Relu>(y)->output(0);
}
if (!output.get_partial_shape().same_scheme(params.select_shape)) {
vector<int32_t> select_shape_values(params.select_shape.size());
for (size_t i = 0; i < params.select_shape.size(); ++i) {
select_shape_values[i] = static_cast<int32_t>(params.select_shape[i].get_length());
}
if (select_shape_values.size() > 0) {
auto target_shape = make_shared<Constant>(i32, Shape{select_shape_values.size()}, select_shape_values);
output = make_shared<Broadcast>(output, target_shape)->output(0);
}
}
return make_shared<Model>(OutputVector{output}, ParameterVector{x, y});
}
} // namespace
class SelectWithOneValueConditionTest : public WithParamInterface<SelectWithOneValueConditionParams>,
public TransformationTestsF {};
TEST_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTestPattern) {
const auto& p = GetParam();
{
function = gen_model(p);
manager.register_pass<pass::SelectWithOneValueCondition>();
}
function_ref = gen_reference(p);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
static const std::vector<SelectWithOneValueConditionParams> params = {
SelectWithOneValueConditionParams{PartialShape{2},
PartialShape{2},
Shape{1},
vector<bool>{false},
PartialShape{2},
SELECT_BRANCH::ELSE_BRANCH},
SelectWithOneValueConditionParams{PartialShape{2},
PartialShape{2},
Shape{1},
vector<bool>{true},
PartialShape{2},
SELECT_BRANCH::THEN_BRANCH},
// with Broadcast case - Subtract result needs to be broadcasted
SelectWithOneValueConditionParams{PartialShape{2, 2},
PartialShape{2},
Shape{1},
vector<bool>{false},
PartialShape{2, 2},
SELECT_BRANCH::ELSE_BRANCH},
// Select is not eliminated due to condition constant value
SelectWithOneValueConditionParams{PartialShape{2, 2},
PartialShape{2},
Shape{2, 2},
vector<bool>{false, true, false, false},
PartialShape{2, 2},
SELECT_BRANCH::NONE},
// The branch is not possible to select due to dynamic output shape for Select
SelectWithOneValueConditionParams{PartialShape{Dimension::dynamic(), 2},
PartialShape{2},
Shape{2},
vector<bool>{true, true},
PartialShape{Dimension::dynamic(), 2},
SELECT_BRANCH::NONE},
// The branch is not possible to select due to dynamic output shape for Select
SelectWithOneValueConditionParams{PartialShape{2},
PartialShape{Dimension::dynamic(), 2},
Shape{2},
vector<bool>{true, true},
PartialShape{Dimension::dynamic(), 2},
SELECT_BRANCH::NONE},
};
INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params));