[Common] Implement SelectWithOneValueCondition common transformation (#14242)
* [Common] Implement SelectWithOneValueCondition common transformation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Add transformation tests for select_with_one_value_condition * Move call of SelectWithOneValueCondition transformation to MOC * Apply code-review feedback * Remove corner case with dynamic output shape for Select Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
de33d6d38b
commit
5a59b2b8f2
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API SelectWithOneValueCondition;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SelectWithOneValueCondition transformation eliminates Select operation if the condition
|
||||
* is constant and consists of al True or False elements
|
||||
*/
|
||||
|
||||
class ov::pass::SelectWithOneValueCondition : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("SelectWithOneValueCondition", "0");
|
||||
SelectWithOneValueCondition();
|
||||
};
|
@ -53,6 +53,7 @@
|
||||
#include <transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp>
|
||||
#include <transformations/common_optimizations/reshape_sequence_fusion.hpp>
|
||||
#include <transformations/common_optimizations/ric_fusion.hpp>
|
||||
#include <transformations/common_optimizations/select_with_one_value_condition.hpp>
|
||||
#include <transformations/common_optimizations/sequence_fusion.hpp>
|
||||
#include <transformations/common_optimizations/shuffle_channels_fusion.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
@ -160,6 +161,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
auto eliminations = manager.register_pass<ov::pass::GraphRewrite>();
|
||||
ADD_MATCHER(eliminations, EliminateUnsqueezeGather)
|
||||
ADD_MATCHER(eliminations, NopElimination, m_use_shapes)
|
||||
ADD_MATCHER(eliminations, SelectWithOneValueCondition)
|
||||
eliminations->set_name("ov::pass::CommonEliminations");
|
||||
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
|
@ -0,0 +1,99 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::element;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::op::util;
|
||||
|
||||
ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
|
||||
MATCHER_SCOPE(SelectWithOneValueCondition);
|
||||
|
||||
auto condition = pattern::wrap_type<Constant>();
|
||||
auto then_branch = pattern::any_input();
|
||||
auto else_branch = pattern::any_input();
|
||||
auto select_pattern = make_shared<Select>(condition, then_branch, else_branch);
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
NodeRegistry copy_from;
|
||||
NodeRegistry copy_to;
|
||||
auto& pattern_map = m.get_pattern_value_map();
|
||||
auto select_value = pattern_map.at(select_pattern);
|
||||
auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr());
|
||||
if (!select) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto condition_value = pattern_map.at(condition);
|
||||
auto condition_const = std::dynamic_pointer_cast<Constant>(condition_value.get_node_shared_ptr());
|
||||
if (!condition_const) {
|
||||
return false;
|
||||
}
|
||||
if (condition_value.get_element_type() != element::boolean) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if all elements in the condition to be true or false
|
||||
// only in this case, one certain branch can be selected
|
||||
auto cond_value = condition_const->get_vector<bool>();
|
||||
if (cond_value.size() == 0) {
|
||||
return false;
|
||||
}
|
||||
auto cond_elem = cond_value[0];
|
||||
auto all_equal = all_of(cond_value.begin(), cond_value.end(), [=](bool v) {
|
||||
return v == cond_elem;
|
||||
});
|
||||
if (!all_equal) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// based on the condition value, mark the selected branch and skipped branch index
|
||||
auto branch_index = cond_elem ? 1 : 2;
|
||||
|
||||
// based on the resulted shape and the shape of the skipped branch, perform further steps
|
||||
auto select_shape = select->get_output_partial_shape(0);
|
||||
auto branch_output = select->input_value(branch_index);
|
||||
auto branch_output_shape = branch_output.get_partial_shape();
|
||||
|
||||
if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) {
|
||||
// Broadcast is not needed if the select shape is exactly the same as the selected branch
|
||||
select->output(0).replace(branch_output);
|
||||
replace_output_update_name(select->output(0), branch_output);
|
||||
branch_output.get_node_shared_ptr()->set_friendly_name(select->get_friendly_name());
|
||||
} else if (select_shape.is_static()) {
|
||||
// if the shape of the selected branch is not the same, it needs the broadcasting
|
||||
NodeRegistry copy_to;
|
||||
auto select_rank = select_shape.size();
|
||||
vector<int32_t> select_shape_values(select_rank);
|
||||
for (size_t i = 0; i < select_rank; ++i) {
|
||||
select_shape_values[i] = static_cast<int32_t>(select_shape[i].get_length());
|
||||
}
|
||||
|
||||
auto target_shape = copy_to.make<Constant>(element::i32, Shape{select_rank}, select_shape_values);
|
||||
auto broadcast = copy_to.make<Broadcast>(branch_output, target_shape);
|
||||
select->output(0).replace(broadcast->output(0));
|
||||
broadcast->set_friendly_name(select->get_friendly_name());
|
||||
copy_runtime_info(select, copy_to.get());
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(select_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,130 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace std;
|
||||
using namespace testing;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::element;
|
||||
|
||||
namespace {
|
||||
enum class SELECT_BRANCH { THEN_BRANCH, ELSE_BRANCH, NONE };
|
||||
|
||||
struct SelectWithOneValueConditionParams {
|
||||
PartialShape x_shape;
|
||||
PartialShape y_shape;
|
||||
Shape cond_shape;
|
||||
vector<bool> cond_values;
|
||||
PartialShape select_shape;
|
||||
SELECT_BRANCH which_branch;
|
||||
};
|
||||
|
||||
shared_ptr<Model> gen_model(const SelectWithOneValueConditionParams& params) {
|
||||
auto x = make_shared<Parameter>(f32, params.x_shape);
|
||||
auto y = make_shared<Parameter>(f32, params.y_shape);
|
||||
auto add = make_shared<Add>(x, y);
|
||||
auto subtract = make_shared<Relu>(y);
|
||||
auto condition = make_shared<Constant>(boolean, params.cond_shape, params.cond_values);
|
||||
auto select = make_shared<Select>(condition, add, subtract);
|
||||
auto res_model = make_shared<Model>(OutputVector{select}, ParameterVector{x, y});
|
||||
return res_model;
|
||||
}
|
||||
|
||||
shared_ptr<Model> gen_reference(const SelectWithOneValueConditionParams& params) {
|
||||
auto cond_shape = params.cond_shape;
|
||||
auto x = make_shared<Parameter>(f32, params.x_shape);
|
||||
auto y = make_shared<Parameter>(f32, params.y_shape);
|
||||
Output<Node> output;
|
||||
if (params.which_branch == SELECT_BRANCH::NONE) {
|
||||
return gen_model(params);
|
||||
} else if (params.which_branch == SELECT_BRANCH::THEN_BRANCH) {
|
||||
output = make_shared<Add>(x, y)->output(0);
|
||||
} else {
|
||||
output = make_shared<Relu>(y)->output(0);
|
||||
}
|
||||
|
||||
if (!output.get_partial_shape().same_scheme(params.select_shape)) {
|
||||
vector<int32_t> select_shape_values(params.select_shape.size());
|
||||
for (size_t i = 0; i < params.select_shape.size(); ++i) {
|
||||
select_shape_values[i] = static_cast<int32_t>(params.select_shape[i].get_length());
|
||||
}
|
||||
|
||||
if (select_shape_values.size() > 0) {
|
||||
auto target_shape = make_shared<Constant>(i32, Shape{select_shape_values.size()}, select_shape_values);
|
||||
output = make_shared<Broadcast>(output, target_shape)->output(0);
|
||||
}
|
||||
}
|
||||
|
||||
return make_shared<Model>(OutputVector{output}, ParameterVector{x, y});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class SelectWithOneValueConditionTest : public WithParamInterface<SelectWithOneValueConditionParams>,
|
||||
public TransformationTestsF {};
|
||||
|
||||
TEST_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTestPattern) {
|
||||
const auto& p = GetParam();
|
||||
{
|
||||
function = gen_model(p);
|
||||
manager.register_pass<pass::SelectWithOneValueCondition>();
|
||||
}
|
||||
|
||||
function_ref = gen_reference(p);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
static const std::vector<SelectWithOneValueConditionParams> params = {
|
||||
SelectWithOneValueConditionParams{PartialShape{2},
|
||||
PartialShape{2},
|
||||
Shape{1},
|
||||
vector<bool>{false},
|
||||
PartialShape{2},
|
||||
SELECT_BRANCH::ELSE_BRANCH},
|
||||
SelectWithOneValueConditionParams{PartialShape{2},
|
||||
PartialShape{2},
|
||||
Shape{1},
|
||||
vector<bool>{true},
|
||||
PartialShape{2},
|
||||
SELECT_BRANCH::THEN_BRANCH},
|
||||
// with Broadcast case - Subtract result needs to be broadcasted
|
||||
SelectWithOneValueConditionParams{PartialShape{2, 2},
|
||||
PartialShape{2},
|
||||
Shape{1},
|
||||
vector<bool>{false},
|
||||
PartialShape{2, 2},
|
||||
SELECT_BRANCH::ELSE_BRANCH},
|
||||
// Select is not eliminated due to condition constant value
|
||||
SelectWithOneValueConditionParams{PartialShape{2, 2},
|
||||
PartialShape{2},
|
||||
Shape{2, 2},
|
||||
vector<bool>{false, true, false, false},
|
||||
PartialShape{2, 2},
|
||||
SELECT_BRANCH::NONE},
|
||||
// The branch is not possible to select due to dynamic output shape for Select
|
||||
SelectWithOneValueConditionParams{PartialShape{Dimension::dynamic(), 2},
|
||||
PartialShape{2},
|
||||
Shape{2},
|
||||
vector<bool>{true, true},
|
||||
PartialShape{Dimension::dynamic(), 2},
|
||||
SELECT_BRANCH::NONE},
|
||||
// The branch is not possible to select due to dynamic output shape for Select
|
||||
SelectWithOneValueConditionParams{PartialShape{2},
|
||||
PartialShape{Dimension::dynamic(), 2},
|
||||
Shape{2},
|
||||
vector<bool>{true, true},
|
||||
PartialShape{Dimension::dynamic(), 2},
|
||||
SELECT_BRANCH::NONE},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params));
|
Loading…
Reference in New Issue
Block a user