Eliminate no-op elementwise operations (#8840)
* Eliminate no-op elementwise operations This change adds EliminateEltwise pass to NopElimination. EliminateEltwise removes: - Subtract with zero - Multiply with one - Divide by one * add Add support to EliminateEltwise * fix unit test * use get_all_data_elements_bitwise_identical instead of get_single_value * fix are_all_data_elements_bitwise_identical for constant created from HostTensor * fix lpt tests * check for mul in is_dequantization_subgraph function * optimize fetching constant value * apply review comments
This commit is contained in:
parent
cde61f2411
commit
8ba339d16f
@ -21,6 +21,7 @@ class TRANSFORMATIONS_API EliminateConvertNonZero;
|
||||
class TRANSFORMATIONS_API EliminateConcat;
|
||||
class TRANSFORMATIONS_API EliminateSplit;
|
||||
class TRANSFORMATIONS_API EliminateTranspose;
|
||||
class TRANSFORMATIONS_API EliminateEltwise;
|
||||
class TRANSFORMATIONS_API NopElimination;
|
||||
|
||||
} // namespace pass
|
||||
@ -86,6 +87,15 @@ public:
|
||||
EliminateTranspose();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief EliminateEltwise eliminates eltwise ops that do nothing
|
||||
*/
|
||||
class ngraph::pass::EliminateEltwise: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
EliminateEltwise();
|
||||
};
|
||||
|
||||
class ngraph::pass::NopElimination: public GraphRewrite {
|
||||
public:
|
||||
|
@ -210,6 +210,10 @@ TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_ind
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(
|
||||
const ngraph::Output<ngraph::Node>& shape_source, const std::vector<size_t>& indices);
|
||||
|
||||
TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise, const Output<Node>& constant, const Output<Node>& non_constant_input);
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -18,32 +18,6 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MultiplyConvolutionFusion, "MultiplyConvolutionFusion", 0);
|
||||
|
||||
static bool is_dequantization_subgraph(const ngraph::Output<ngraph::Node>& multiply) {
|
||||
auto inputs = multiply.get_node()->input_values();
|
||||
const auto subtract = std::find_if(inputs.begin(), inputs.end(),
|
||||
[] (const ngraph::Output<ngraph::Node>& n) -> bool {
|
||||
return ov::is_type<ngraph::opset8::Subtract>(n.get_node());
|
||||
});
|
||||
if (subtract != inputs.end())
|
||||
inputs = subtract->get_node()->input_values();
|
||||
const auto first_convert = std::find_if(inputs.begin(), inputs.end(),
|
||||
[] (const ngraph::Output<ngraph::Node>& n) -> bool {
|
||||
if (ov::is_type<ngraph::opset8::Convert>(n.get_node())) {
|
||||
const auto input = n.get_node()->input_value(0);
|
||||
return ov::is_type<ngraph::opset8::Convert>(input.get_node());
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (first_convert == inputs.end())
|
||||
return false;
|
||||
const auto second_convert = first_convert->get_node()->input_value(0);
|
||||
const auto& first_convert_src_type = second_convert.get_element_type();
|
||||
const auto& first_convert_dest_type = first_convert->get_element_type();
|
||||
const auto second_convert_src_type = second_convert.get_node()->input_value(0).get_element_type();
|
||||
return (first_convert_src_type == ngraph::element::i8 || first_convert_src_type == ngraph::element::u8) &&
|
||||
first_convert_dest_type == second_convert_src_type;
|
||||
}
|
||||
|
||||
ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() {
|
||||
MATCHER_SCOPE(MultiplyConvolutionFusion);
|
||||
auto input_pattern = pattern::any_input();
|
||||
@ -57,7 +31,7 @@ ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() {
|
||||
|
||||
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
|
||||
// since that breaks low precision transformations
|
||||
if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
return false;
|
||||
|
||||
const auto& weights = pattern_to_output.at(weights_pattern);
|
||||
@ -110,7 +84,7 @@ ngraph::pass::MultiplyGroupConvolutionFusion::MultiplyGroupConvolutionFusion() {
|
||||
|
||||
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
|
||||
// since that breaks low precision transformations
|
||||
if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
return false;
|
||||
|
||||
const auto& weights = pattern_to_output.at(weights_pattern);
|
||||
@ -174,7 +148,7 @@ ngraph::pass::MultiplyConvolutionBackpropDataFusion::MultiplyConvolutionBackprop
|
||||
|
||||
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
|
||||
// since that breaks low precision transformations
|
||||
if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
return false;
|
||||
|
||||
const auto& weights = pattern_to_output.at(weights_pattern);
|
||||
@ -240,7 +214,7 @@ ngraph::pass::MultiplyGroupConvolutionBackpropDataFusion::MultiplyGroupConvoluti
|
||||
|
||||
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
|
||||
// since that breaks low precision transformations
|
||||
if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
|
||||
return false;
|
||||
|
||||
const auto& weights = pattern_to_output.at(weights_pattern);
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/util.hpp>
|
||||
#include <ngraph/log.hpp>
|
||||
#include <transformations/common_optimizations/nop_elimination.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
using namespace std;
|
||||
@ -503,6 +504,34 @@ pass::EliminateTranspose::EliminateTranspose() {
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(pass::EliminateEltwise, "EliminateEltwise", 0);
|
||||
|
||||
pass::EliminateEltwise::EliminateEltwise() {
|
||||
MATCHER_SCOPE(EliminateEltwise);
|
||||
auto input = pattern::any_input();
|
||||
auto constant_pattern = pattern::wrap_type<opset8::Constant>();
|
||||
auto eltwise_pattern = pattern::wrap_type<opset8::Add,
|
||||
opset8::Subtract,
|
||||
opset8::Multiply,
|
||||
opset8::Divide>({input, constant_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto eltwise = pattern_map.at(eltwise_pattern).get_node_shared_ptr();
|
||||
auto non_const_input = pattern_map.at(input);
|
||||
auto constant = pattern_map.at(constant_pattern);
|
||||
|
||||
if (!op::util::can_eliminate_eltwise_node(eltwise, constant, non_const_input)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return replace_output_update_name(eltwise->output(0), non_const_input);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(eltwise_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
|
||||
|
||||
ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
|
||||
@ -513,6 +542,7 @@ ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
|
||||
add_matcher<EliminateConcat>();
|
||||
add_matcher<EliminateSplit>();
|
||||
add_matcher<EliminateTranspose>();
|
||||
add_matcher<EliminateEltwise>();
|
||||
|
||||
// shape-dependent transformations
|
||||
if (use_shape_for_elimination) {
|
||||
|
@ -4,6 +4,8 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/op_conversions/convert_divide.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
@ -28,25 +30,30 @@ bool convert_divide(std::shared_ptr<ngraph::Node> node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ngraph::Output<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
|
||||
std::shared_ptr<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
|
||||
ngraph::op::Constant::create(div->get_input_element_type(1), ngraph::Shape{}, {-1}));
|
||||
|
||||
if (std::dynamic_pointer_cast<ngraph::op::Constant>(div->get_input_node_shared_ptr(1))) {
|
||||
if (auto const_pow = ngraph::get_constant_from_source(pow)) {
|
||||
pow = const_pow;
|
||||
} else {
|
||||
NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get_node();
|
||||
NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get();
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
ngraph::copy_runtime_info(div, pow.get_node_shared_ptr());
|
||||
ngraph::copy_runtime_info(div, pow);
|
||||
}
|
||||
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
|
||||
|
||||
mul->set_friendly_name(div->get_friendly_name());
|
||||
ngraph::copy_runtime_info(div, mul);
|
||||
ngraph::replace_node(div, mul);
|
||||
// if Divide is an inverse, then we don't need the Multiply
|
||||
if (ngraph::op::util::can_eliminate_eltwise_node(mul, mul->input_value(0), mul->input_value(1))) {
|
||||
pow->set_friendly_name(div->get_friendly_name());
|
||||
ngraph::replace_node(div, pow);
|
||||
} else {
|
||||
mul->set_friendly_name(div->get_friendly_name());
|
||||
ngraph::copy_runtime_info(div, mul);
|
||||
ngraph::replace_node(div, mul);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
@ -74,4 +81,4 @@ ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() {
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
}
|
||||
|
@ -203,6 +203,132 @@ void visit_shape_path(const std::shared_ptr<ov::Node>& node,
|
||||
}
|
||||
}
|
||||
|
||||
bool is_dequantization_subgraph(const Output<Node>& node) {
|
||||
if (!is_type<opset8::Multiply>(node.get_node())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto mul_inputs = node.get_node()->input_values();
|
||||
Node* sub = nullptr;
|
||||
Node* convert = nullptr;
|
||||
|
||||
if (is_type<opset8::Subtract>(mul_inputs[0].get_node())) {
|
||||
sub = mul_inputs[0].get_node();
|
||||
} else if (is_type<opset8::Convert>(mul_inputs[0].get_node())) {
|
||||
convert = mul_inputs[0].get_node();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sub) {
|
||||
auto sub_inputs = sub->input_values();
|
||||
if (is_type<opset8::Convert>(sub_inputs[0].get_node())) {
|
||||
convert = sub_inputs[0].get_node();
|
||||
}
|
||||
}
|
||||
|
||||
if (!convert) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_type = convert->get_input_element_type(0);
|
||||
auto output_type = convert->get_output_element_type(0);
|
||||
return input_type.is_integral() && output_type.is_real();
|
||||
}
|
||||
|
||||
bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise, const Output<Node>& constant, const Output<Node>& non_constant_input) {
|
||||
if (!is_type<opset8::Add>(eltwise) &&
|
||||
!is_type<opset8::Subtract>(eltwise) &&
|
||||
!is_type<opset8::Multiply>(eltwise) &&
|
||||
!is_type<opset8::Divide>(eltwise)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_dequantization_subgraph(eltwise)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if constant has a single value with either 0 (for Add, Subtract) or 1 (for Multiply, Divide)
|
||||
auto constant_ptr = std::dynamic_pointer_cast<opset8::Constant>(constant.get_node_shared_ptr());
|
||||
if (!constant_ptr) {
|
||||
return false;
|
||||
}
|
||||
if (!constant_ptr->get_all_data_elements_bitwise_identical()) {
|
||||
return false;
|
||||
}
|
||||
float actual_const = 0;
|
||||
const void* data_ptr = constant_ptr->get_data_ptr();
|
||||
switch (constant_ptr->get_element_type()) {
|
||||
case element::f32:
|
||||
actual_const = reinterpret_cast<const float*>(data_ptr)[0];
|
||||
break;
|
||||
case element::i32:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const int32_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::u32:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const uint32_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::i64:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const int64_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::u64:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const uint64_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::i8:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const int8_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::u8:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const uint8_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::i16:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const int16_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::u16:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const uint16_t*>(data_ptr)[0]);
|
||||
break;
|
||||
case element::f64:
|
||||
actual_const = static_cast<float>(reinterpret_cast<const double*>(data_ptr)[0]);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
float expected_const = 0;
|
||||
if (is_type<opset8::Multiply>(eltwise) ||
|
||||
is_type<opset8::Divide>(eltwise)) {
|
||||
expected_const = 1;
|
||||
}
|
||||
if (actual_const != expected_const) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// fuse uncoditionally if constant is a scalar
|
||||
const auto& constant_shape = constant.get_shape();
|
||||
if (ov::is_scalar(constant_shape)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto& input_shape = non_constant_input.get_partial_shape();
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// cannot fuse if constant extends input's rank
|
||||
auto input_rank = static_cast<size_t>(input_shape.rank().get_length());
|
||||
auto constant_rank = constant_shape.size();
|
||||
if (input_rank < constant_rank) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// cannot fuse if constant makes input to be broadcasted, e.g.
|
||||
// Multiply(input{2, 1, 5}, constant{1, 5, 1}) -> {2, 5, 5}
|
||||
for (size_t i = 0; i < constant_rank; i++) {
|
||||
auto constant_dim = constant_shape[constant_rank - i - 1];
|
||||
if (constant_dim != 1 && input_shape[input_rank - i - 1] != constant_dim) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -48,8 +48,8 @@ ov::op::v0::Constant::Constant(const shared_ptr<ngraph::runtime::Tensor>& tensor
|
||||
constructor_validate_and_infer_types();
|
||||
allocate_buffer();
|
||||
tensor->read(get_data_ptr_nc(), tensor->get_size_in_bytes());
|
||||
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
||||
}
|
||||
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -56,3 +57,35 @@ TEST(attributes, constant_op_identical_elements) {
|
||||
EXPECT_EQ(data, g_data);
|
||||
ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical());
|
||||
}
|
||||
|
||||
TEST(attributes, constant_op_from_host_tensor_different_elements) {
|
||||
vector<int64_t> data{5, 4, 3, 2, 1, 0};
|
||||
auto tensor = std::make_shared<runtime::HostTensor>(element::i64, Shape{2, 3}, &data[0]);
|
||||
auto k = make_shared<op::v0::Constant>(tensor);
|
||||
ASSERT_FALSE(k->get_all_data_elements_bitwise_identical());
|
||||
NodeBuilder builder(k);
|
||||
auto g_k = ov::as_type_ptr<op::v0::Constant>(builder.create());
|
||||
g_k->validate_and_infer_types();
|
||||
ASSERT_TRUE(g_k);
|
||||
EXPECT_EQ(k->get_element_type(), g_k->get_element_type());
|
||||
EXPECT_EQ(k->get_shape(), g_k->get_shape());
|
||||
vector<int64_t> g_data = g_k->get_vector<int64_t>();
|
||||
EXPECT_EQ(data, g_data);
|
||||
ASSERT_FALSE(g_k->get_all_data_elements_bitwise_identical());
|
||||
}
|
||||
|
||||
TEST(attributes, constant_op_from_host_tensor_identical_elements) {
|
||||
vector<int64_t> data{5, 5, 5, 5, 5, 5};
|
||||
auto tensor = std::make_shared<runtime::HostTensor>(element::i64, Shape{2, 3}, &data[0]);
|
||||
auto k = make_shared<op::v0::Constant>(tensor);
|
||||
ASSERT_TRUE(k->get_all_data_elements_bitwise_identical());
|
||||
NodeBuilder builder(k);
|
||||
auto g_k = ov::as_type_ptr<op::v0::Constant>(builder.create());
|
||||
g_k->validate_and_infer_types();
|
||||
ASSERT_TRUE(g_k);
|
||||
EXPECT_EQ(k->get_element_type(), g_k->get_element_type());
|
||||
EXPECT_EQ(k->get_shape(), g_k->get_shape());
|
||||
vector<int64_t> g_data = g_k->get_vector<int64_t>();
|
||||
EXPECT_EQ(data, g_data);
|
||||
ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical());
|
||||
}
|
||||
|
@ -38,8 +38,31 @@ TEST_F(TransformationTestsF, ConvertDivide) {
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideInverse) {
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
|
||||
auto divide = std::make_shared<ngraph::opset1::Divide>(divide_constant, data);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertDivide>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||
auto constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1.0});
|
||||
auto pow = std::make_shared<ngraph::opset1::Power>(data, constant);
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{pow}, ngraph::ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideNegative) {
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{3, 1, 2});
|
||||
@ -58,6 +81,7 @@ TEST_F(TransformationTestsF, ConvertDivideNegative) {
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideScalar) {
|
||||
@ -84,6 +108,7 @@ TEST_F(TransformationTestsF, ConvertDivideScalar) {
|
||||
|
||||
NGRAPH_CHECK(mul->get_output_partial_shape(0).rank().get_length() == 0);
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
|
||||
@ -103,6 +128,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
|
||||
@ -122,6 +148,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertDivideFP16ShapeOfSubgraphNegative) {
|
||||
|
@ -18,8 +18,11 @@
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
#include <ngraph_functions/builders.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
@ -817,4 +820,215 @@ TEST(nop_elimination, gather_3d_indices_constant_axis_1) {
|
||||
check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{1}, 2, 0);
|
||||
check_usecase(PartialShape{1, 16}, i32, multiout, std::vector<int64_t>{0, 0}, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using namespace helpers;
|
||||
|
||||
struct ShapeParams {
|
||||
PartialShape shape1;
|
||||
Shape shape2;
|
||||
bool swap_inputs;
|
||||
bool can_fuse;
|
||||
};
|
||||
|
||||
enum class ConstantKind {
|
||||
ZERO,
|
||||
ONE,
|
||||
RANDOM,
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& os, ConstantKind kind) {
|
||||
switch (kind) {
|
||||
case ConstantKind::ZERO:
|
||||
os << "zero";
|
||||
break;
|
||||
case ConstantKind::ONE:
|
||||
os << "one";
|
||||
break;
|
||||
case ConstantKind::RANDOM:
|
||||
os << "random";
|
||||
break;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
struct TypeParams {
|
||||
EltwiseTypes op_type;
|
||||
ConstantKind constant_kind;
|
||||
bool can_fuse;
|
||||
};
|
||||
|
||||
using EliminateEltwiseParams = std::tuple<ShapeParams, TypeParams, element::Type>;
|
||||
|
||||
class EliminateEltwiseTests: public testing::WithParamInterface<EliminateEltwiseParams>, virtual public TransformationTestsF {
|
||||
public:
|
||||
static std::string get_test_case_name(testing::TestParamInfo<EliminateEltwiseParams> info) {
|
||||
const auto& shape_params = std::get<0>(info.param);
|
||||
const auto& type_params = std::get<1>(info.param);
|
||||
const auto& element_type = std::get<2>(info.param);
|
||||
std::ostringstream result;
|
||||
result << type_params.op_type
|
||||
<< "_input1=" << shape_params.shape1
|
||||
<< "_input2=" << shape_params.shape2
|
||||
<< "_swap_inputs=" << std::boolalpha << shape_params.swap_inputs
|
||||
<< "_constant=" << type_params.constant_kind
|
||||
<< "_type=" << element_type;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
|
||||
auto params = GetParam();
|
||||
const auto& shape_params = std::get<0>(params);
|
||||
const auto& type_params = std::get<1>(params);
|
||||
const auto& type = std::get<2>(params);
|
||||
const auto& shape1 = shape_params.shape1;
|
||||
const auto& shape2 = shape_params.shape2;
|
||||
bool swap_inputs = shape_params.swap_inputs;
|
||||
bool can_fuse = shape_params.can_fuse && type_params.can_fuse;
|
||||
|
||||
auto parameter = make_shared<op::Parameter>(type, shape1);
|
||||
std::shared_ptr<Node> constant;
|
||||
switch (type_params.constant_kind) {
|
||||
case ConstantKind::ZERO:
|
||||
constant = op::Constant::create(type, shape2, {0});
|
||||
break;
|
||||
case ConstantKind::ONE:
|
||||
constant = op::Constant::create(type, shape2, {1});
|
||||
break;
|
||||
case ConstantKind::RANDOM:
|
||||
constant = builder::makeConstant(type, shape2, {}, true, 20 /* upTo */, 2 /* startFrom */);
|
||||
break;
|
||||
}
|
||||
|
||||
shared_ptr<Node> A = parameter;
|
||||
shared_ptr<Node> B = constant;
|
||||
if (swap_inputs) {
|
||||
std::swap(A, B);
|
||||
if (type_params.op_type == EltwiseTypes::SUBTRACT ||
|
||||
type_params.op_type == EltwiseTypes::DIVIDE) {
|
||||
can_fuse = false;
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> node;
|
||||
switch (type_params.op_type) {
|
||||
case EltwiseTypes::ADD:
|
||||
node = make_shared<opset8::Add>(A, B);
|
||||
break;
|
||||
case EltwiseTypes::SUBTRACT:
|
||||
node = make_shared<opset8::Subtract>(A, B);
|
||||
break;
|
||||
case EltwiseTypes::MULTIPLY:
|
||||
node = make_shared<opset8::Multiply>(A, B);
|
||||
break;
|
||||
case EltwiseTypes::DIVIDE:
|
||||
node = make_shared<opset8::Divide>(A, B);
|
||||
break;
|
||||
default:
|
||||
ASSERT_FALSE(true) << "Invalid EltwiseType";
|
||||
}
|
||||
auto abs = make_shared<opset8::Abs>(node);
|
||||
function = make_shared<Function>(abs, ParameterVector{parameter});
|
||||
|
||||
manager.register_pass<pass::NopElimination>();
|
||||
|
||||
if (can_fuse) {
|
||||
auto abs = make_shared<opset8::Abs>(parameter);
|
||||
function_ref = make_shared<Function>(abs, ParameterVector{parameter});
|
||||
}
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
if (type == element::f32) {
|
||||
enable_accuracy_check();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ShapeParams> shape_params = {
|
||||
// input 1, input 2, swap inputs, can fuse
|
||||
{ Shape{}, Shape{}, false, true },
|
||||
{ Shape{}, Shape{}, true, true },
|
||||
{ Shape{5}, Shape{}, false, true },
|
||||
{ Shape{5}, Shape{1}, false, true },
|
||||
{ Shape{5}, Shape{5}, false, true },
|
||||
{ Shape{5}, Shape{5}, true, true },
|
||||
{ Shape{2, 3, 5}, Shape{}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1, 1}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1, 1, 1}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1, 5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1, 1, 5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{3, 5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{1, 3, 5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{2, 3, 5}, false, true },
|
||||
{ Shape{2, 3, 5}, Shape{2, 3, 5}, true, true },
|
||||
{ PartialShape::dynamic(), Shape{}, false, true },
|
||||
{ PartialShape::dynamic(3), Shape{}, false, true },
|
||||
{ PartialShape::dynamic(3), Shape{1}, false, true },
|
||||
{ PartialShape::dynamic(3), Shape{1, 1}, false, true },
|
||||
{ PartialShape::dynamic(3), Shape{1, 1, 1}, false, true },
|
||||
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 1}, false, true },
|
||||
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{3, 1}, false, true },
|
||||
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 3, 1}, false, true },
|
||||
// negative cases
|
||||
{ Shape{}, Shape{1}, false, false },
|
||||
{ Shape{}, Shape{2, 3}, false, false },
|
||||
{ Shape{5}, Shape{1, 1}, false, false },
|
||||
{ Shape{4, 1, 3}, Shape{2, 3}, false, false },
|
||||
{ Shape{1, 2, 3}, Shape{4, 2, 3}, false, false },
|
||||
{ Shape{1, 1, 3}, Shape{2, 1, 3}, false, false },
|
||||
{ Shape{1, 2, 3}, Shape{1, 1, 1, 1}, false, false },
|
||||
{ PartialShape::dynamic(), Shape{2, 3, 4}, false, false },
|
||||
{ PartialShape::dynamic(3), Shape{1, 2, 1}, false, false },
|
||||
{ PartialShape::dynamic(3), Shape{1, 1, 1, 1}, false, false },
|
||||
};
|
||||
|
||||
std::vector<TypeParams> type_params = {
|
||||
// op type, constant value, can fuse
|
||||
{ EltwiseTypes::ADD, ConstantKind::ZERO, true },
|
||||
{ EltwiseTypes::ADD, ConstantKind::RANDOM, false },
|
||||
{ EltwiseTypes::SUBTRACT, ConstantKind::ZERO, true },
|
||||
{ EltwiseTypes::SUBTRACT, ConstantKind::RANDOM, false },
|
||||
{ EltwiseTypes::MULTIPLY, ConstantKind::ONE, true },
|
||||
{ EltwiseTypes::MULTIPLY, ConstantKind::RANDOM, false },
|
||||
{ EltwiseTypes::DIVIDE, ConstantKind::ONE, true },
|
||||
{ EltwiseTypes::DIVIDE, ConstantKind::RANDOM, false },
|
||||
};
|
||||
|
||||
std::vector<element::Type> types{
|
||||
element::f32, element::f64,
|
||||
element::i32, element::u32,
|
||||
element::i64, element::u64,
|
||||
element::i8, element::u8,
|
||||
element::i16, element::u16,
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(EliminateEltwise, EliminateEltwiseTests,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(shape_params),
|
||||
::testing::ValuesIn(type_params),
|
||||
::testing::ValuesIn(types)),
|
||||
EliminateEltwiseTests::get_test_case_name);
|
||||
|
||||
|
||||
TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) {
|
||||
{
|
||||
auto constant = opset8::Constant::create(element::i8, Shape{}, {2});
|
||||
auto convert = make_shared<opset8::Convert>(constant, element::f32);
|
||||
auto sub = make_shared<opset8::Subtract>(convert, opset8::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto mul = make_shared<opset8::Multiply>(sub, opset8::Constant::create(element::f32, Shape{}, {1}));
|
||||
function = make_shared<Function>(mul, ParameterVector{});
|
||||
}
|
||||
{
|
||||
auto constant = opset8::Constant::create(element::i8, Shape{}, {2});
|
||||
auto convert = make_shared<opset8::Convert>(constant, element::f32);
|
||||
auto mul = make_shared<opset8::Multiply>(convert, opset8::Constant::create(element::f32, Shape{}, {1}));
|
||||
function_ref = make_shared<Function>(mul, ParameterVector{});
|
||||
}
|
||||
|
||||
manager.register_pass<pass::NopElimination>();
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
enable_accuracy_check();
|
||||
}
|
||||
|
@ -201,7 +201,11 @@ void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_functio
|
||||
for (auto param : ref_function->get_parameters()) {
|
||||
types.push_back(param->get_element_type());
|
||||
|
||||
InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), InferenceEngine::Layout::ANY);
|
||||
auto layout = InferenceEngine::Layout::ANY;
|
||||
if (ov::is_scalar(param->get_shape())) {
|
||||
layout = InferenceEngine::Layout::SCALAR;
|
||||
}
|
||||
InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout);
|
||||
const auto &input = FuncTestUtils::createAndFillBlob(td);
|
||||
const auto &input_size = input->byteSize();
|
||||
|
||||
@ -227,8 +231,8 @@ void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_functio
|
||||
IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size());
|
||||
auto * ref = reinterpret_cast<float *>(ref_outputs[i].second.data());
|
||||
auto * out = reinterpret_cast<float *>(outputs[i].second.data());
|
||||
IE_ASSERT(ref_outputs[i].second.size() / 8);
|
||||
size_t size = ref_outputs[i].second.size() / sizeof(float);
|
||||
IE_ASSERT(size > 0);
|
||||
Compare<float, float>(ref, out, size, 1e-5);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user