[Common] Implement SelectWithOneValueCondition common transformation (#14242)
* [Common] Implement SelectWithOneValueCondition common transformation Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Add transformation tests for select_with_one_value_condition * Move call of SelectWithOneValueCondition transformation to MOC * Apply code-review feedback * Remove corner case with dynamic output shape for Select Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
de33d6d38b
commit
5a59b2b8f2
@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <openvino/pass/graph_rewrite.hpp>
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API SelectWithOneValueCondition;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief SelectWithOneValueCondition transformation eliminates Select operation if the condition
|
||||||
|
* is constant and consists of al True or False elements
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ov::pass::SelectWithOneValueCondition : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("SelectWithOneValueCondition", "0");
|
||||||
|
SelectWithOneValueCondition();
|
||||||
|
};
|
@ -53,6 +53,7 @@
|
|||||||
#include <transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp>
|
#include <transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp>
|
||||||
#include <transformations/common_optimizations/reshape_sequence_fusion.hpp>
|
#include <transformations/common_optimizations/reshape_sequence_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/ric_fusion.hpp>
|
#include <transformations/common_optimizations/ric_fusion.hpp>
|
||||||
|
#include <transformations/common_optimizations/select_with_one_value_condition.hpp>
|
||||||
#include <transformations/common_optimizations/sequence_fusion.hpp>
|
#include <transformations/common_optimizations/sequence_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/shuffle_channels_fusion.hpp>
|
#include <transformations/common_optimizations/shuffle_channels_fusion.hpp>
|
||||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||||
@ -160,6 +161,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
|||||||
auto eliminations = manager.register_pass<ov::pass::GraphRewrite>();
|
auto eliminations = manager.register_pass<ov::pass::GraphRewrite>();
|
||||||
ADD_MATCHER(eliminations, EliminateUnsqueezeGather)
|
ADD_MATCHER(eliminations, EliminateUnsqueezeGather)
|
||||||
ADD_MATCHER(eliminations, NopElimination, m_use_shapes)
|
ADD_MATCHER(eliminations, NopElimination, m_use_shapes)
|
||||||
|
ADD_MATCHER(eliminations, SelectWithOneValueCondition)
|
||||||
eliminations->set_name("ov::pass::CommonEliminations");
|
eliminations->set_name("ov::pass::CommonEliminations");
|
||||||
|
|
||||||
manager.register_pass<ov::pass::ConstantFolding>();
|
manager.register_pass<ov::pass::ConstantFolding>();
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "openvino/core/rt_info.hpp"
|
||||||
|
#include "openvino/core/validation_util.hpp"
|
||||||
|
#include "openvino/opsets/opset10.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::element;
|
||||||
|
using namespace ov::opset10;
|
||||||
|
using namespace ov::op::util;
|
||||||
|
|
||||||
|
ov::pass::SelectWithOneValueCondition::SelectWithOneValueCondition() {
|
||||||
|
MATCHER_SCOPE(SelectWithOneValueCondition);
|
||||||
|
|
||||||
|
auto condition = pattern::wrap_type<Constant>();
|
||||||
|
auto then_branch = pattern::any_input();
|
||||||
|
auto else_branch = pattern::any_input();
|
||||||
|
auto select_pattern = make_shared<Select>(condition, then_branch, else_branch);
|
||||||
|
|
||||||
|
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
|
NodeRegistry copy_from;
|
||||||
|
NodeRegistry copy_to;
|
||||||
|
auto& pattern_map = m.get_pattern_value_map();
|
||||||
|
auto select_value = pattern_map.at(select_pattern);
|
||||||
|
auto select = std::dynamic_pointer_cast<Select>(select_value.get_node_shared_ptr());
|
||||||
|
if (!select) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto condition_value = pattern_map.at(condition);
|
||||||
|
auto condition_const = std::dynamic_pointer_cast<Constant>(condition_value.get_node_shared_ptr());
|
||||||
|
if (!condition_const) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (condition_value.get_element_type() != element::boolean) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if all elements in the condition to be true or false
|
||||||
|
// only in this case, one certain branch can be selected
|
||||||
|
auto cond_value = condition_const->get_vector<bool>();
|
||||||
|
if (cond_value.size() == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto cond_elem = cond_value[0];
|
||||||
|
auto all_equal = all_of(cond_value.begin(), cond_value.end(), [=](bool v) {
|
||||||
|
return v == cond_elem;
|
||||||
|
});
|
||||||
|
if (!all_equal) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// based on the condition value, mark the selected branch and skipped branch index
|
||||||
|
auto branch_index = cond_elem ? 1 : 2;
|
||||||
|
|
||||||
|
// based on the resulted shape and the shape of the skipped branch, perform further steps
|
||||||
|
auto select_shape = select->get_output_partial_shape(0);
|
||||||
|
auto branch_output = select->input_value(branch_index);
|
||||||
|
auto branch_output_shape = branch_output.get_partial_shape();
|
||||||
|
|
||||||
|
if (select_shape.is_static() && branch_output_shape.same_scheme(select_shape)) {
|
||||||
|
// Broadcast is not needed if the select shape is exactly the same as the selected branch
|
||||||
|
select->output(0).replace(branch_output);
|
||||||
|
replace_output_update_name(select->output(0), branch_output);
|
||||||
|
branch_output.get_node_shared_ptr()->set_friendly_name(select->get_friendly_name());
|
||||||
|
} else if (select_shape.is_static()) {
|
||||||
|
// if the shape of the selected branch is not the same, it needs the broadcasting
|
||||||
|
NodeRegistry copy_to;
|
||||||
|
auto select_rank = select_shape.size();
|
||||||
|
vector<int32_t> select_shape_values(select_rank);
|
||||||
|
for (size_t i = 0; i < select_rank; ++i) {
|
||||||
|
select_shape_values[i] = static_cast<int32_t>(select_shape[i].get_length());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto target_shape = copy_to.make<Constant>(element::i32, Shape{select_rank}, select_shape_values);
|
||||||
|
auto broadcast = copy_to.make<Broadcast>(branch_output, target_shape);
|
||||||
|
select->output(0).replace(broadcast->output(0));
|
||||||
|
broadcast->set_friendly_name(select->get_friendly_name());
|
||||||
|
copy_runtime_info(select, copy_to.get());
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = make_shared<pattern::Matcher>(select_pattern, matcher_name);
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -0,0 +1,130 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/select_with_one_value_condition.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "openvino/opsets/opset10.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace std;
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ov::opset10;
|
||||||
|
using namespace ov::element;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
enum class SELECT_BRANCH { THEN_BRANCH, ELSE_BRANCH, NONE };
|
||||||
|
|
||||||
|
struct SelectWithOneValueConditionParams {
|
||||||
|
PartialShape x_shape;
|
||||||
|
PartialShape y_shape;
|
||||||
|
Shape cond_shape;
|
||||||
|
vector<bool> cond_values;
|
||||||
|
PartialShape select_shape;
|
||||||
|
SELECT_BRANCH which_branch;
|
||||||
|
};
|
||||||
|
|
||||||
|
shared_ptr<Model> gen_model(const SelectWithOneValueConditionParams& params) {
|
||||||
|
auto x = make_shared<Parameter>(f32, params.x_shape);
|
||||||
|
auto y = make_shared<Parameter>(f32, params.y_shape);
|
||||||
|
auto add = make_shared<Add>(x, y);
|
||||||
|
auto subtract = make_shared<Relu>(y);
|
||||||
|
auto condition = make_shared<Constant>(boolean, params.cond_shape, params.cond_values);
|
||||||
|
auto select = make_shared<Select>(condition, add, subtract);
|
||||||
|
auto res_model = make_shared<Model>(OutputVector{select}, ParameterVector{x, y});
|
||||||
|
return res_model;
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_ptr<Model> gen_reference(const SelectWithOneValueConditionParams& params) {
|
||||||
|
auto cond_shape = params.cond_shape;
|
||||||
|
auto x = make_shared<Parameter>(f32, params.x_shape);
|
||||||
|
auto y = make_shared<Parameter>(f32, params.y_shape);
|
||||||
|
Output<Node> output;
|
||||||
|
if (params.which_branch == SELECT_BRANCH::NONE) {
|
||||||
|
return gen_model(params);
|
||||||
|
} else if (params.which_branch == SELECT_BRANCH::THEN_BRANCH) {
|
||||||
|
output = make_shared<Add>(x, y)->output(0);
|
||||||
|
} else {
|
||||||
|
output = make_shared<Relu>(y)->output(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!output.get_partial_shape().same_scheme(params.select_shape)) {
|
||||||
|
vector<int32_t> select_shape_values(params.select_shape.size());
|
||||||
|
for (size_t i = 0; i < params.select_shape.size(); ++i) {
|
||||||
|
select_shape_values[i] = static_cast<int32_t>(params.select_shape[i].get_length());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (select_shape_values.size() > 0) {
|
||||||
|
auto target_shape = make_shared<Constant>(i32, Shape{select_shape_values.size()}, select_shape_values);
|
||||||
|
output = make_shared<Broadcast>(output, target_shape)->output(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return make_shared<Model>(OutputVector{output}, ParameterVector{x, y});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
class SelectWithOneValueConditionTest : public WithParamInterface<SelectWithOneValueConditionParams>,
|
||||||
|
public TransformationTestsF {};
|
||||||
|
|
||||||
|
TEST_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTestPattern) {
|
||||||
|
const auto& p = GetParam();
|
||||||
|
{
|
||||||
|
function = gen_model(p);
|
||||||
|
manager.register_pass<pass::SelectWithOneValueCondition>();
|
||||||
|
}
|
||||||
|
|
||||||
|
function_ref = gen_reference(p);
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||||
|
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const std::vector<SelectWithOneValueConditionParams> params = {
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{2},
|
||||||
|
PartialShape{2},
|
||||||
|
Shape{1},
|
||||||
|
vector<bool>{false},
|
||||||
|
PartialShape{2},
|
||||||
|
SELECT_BRANCH::ELSE_BRANCH},
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{2},
|
||||||
|
PartialShape{2},
|
||||||
|
Shape{1},
|
||||||
|
vector<bool>{true},
|
||||||
|
PartialShape{2},
|
||||||
|
SELECT_BRANCH::THEN_BRANCH},
|
||||||
|
// with Broadcast case - Subtract result needs to be broadcasted
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{2, 2},
|
||||||
|
PartialShape{2},
|
||||||
|
Shape{1},
|
||||||
|
vector<bool>{false},
|
||||||
|
PartialShape{2, 2},
|
||||||
|
SELECT_BRANCH::ELSE_BRANCH},
|
||||||
|
// Select is not eliminated due to condition constant value
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{2, 2},
|
||||||
|
PartialShape{2},
|
||||||
|
Shape{2, 2},
|
||||||
|
vector<bool>{false, true, false, false},
|
||||||
|
PartialShape{2, 2},
|
||||||
|
SELECT_BRANCH::NONE},
|
||||||
|
// The branch is not possible to select due to dynamic output shape for Select
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{Dimension::dynamic(), 2},
|
||||||
|
PartialShape{2},
|
||||||
|
Shape{2},
|
||||||
|
vector<bool>{true, true},
|
||||||
|
PartialShape{Dimension::dynamic(), 2},
|
||||||
|
SELECT_BRANCH::NONE},
|
||||||
|
// The branch is not possible to select due to dynamic output shape for Select
|
||||||
|
SelectWithOneValueConditionParams{PartialShape{2},
|
||||||
|
PartialShape{Dimension::dynamic(), 2},
|
||||||
|
Shape{2},
|
||||||
|
vector<bool>{true, true},
|
||||||
|
PartialShape{Dimension::dynamic(), 2},
|
||||||
|
SELECT_BRANCH::NONE},
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(SelectWithOneValueConditionTest, SelectWithOneValueConditionTest, ValuesIn(params));
|
Loading…
Reference in New Issue
Block a user