Eliminate no-op elementwise operations (#8840)

* Eliminate no-op elementwise operations

This change adds an EliminateEltwise pass to NopElimination.
EliminateEltwise removes:
- Subtract with zero
- Multiply by one
- Divide by one

* Add support for Add to EliminateEltwise

* Fix unit test

* Use get_all_data_elements_bitwise_identical instead of get_single_value

* Fix are_all_data_elements_bitwise_identical for a Constant created from a HostTensor

* Fix LPT tests

* Check for Multiply in the is_dequantization_subgraph function

* Optimize fetching the constant value

* Apply review comments
Mateusz Tabaka 2022-01-13 18:05:47 +01:00, committed by GitHub
parent cde61f2411
commit 8ba339d16f
11 changed files with 471 additions and 42 deletions
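
For context, a minimal sketch of the elimination in action, using the opset8 and pass::Manager APIs that appear throughout this diff (the standalone main() is illustrative, not part of the commit):

    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset8.hpp>
    #include <ngraph/pass/manager.hpp>
    #include <transformations/common_optimizations/nop_elimination.hpp>

    int main() {
        using namespace ngraph;
        // x * 1 is a no-op: EliminateEltwise replaces the Multiply with x.
        auto x = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 3});
        auto one = opset8::Constant::create(element::f32, Shape{}, {1});
        auto mul = std::make_shared<opset8::Multiply>(x, one);
        auto f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{x});

        pass::Manager manager;
        manager.register_pass<pass::NopElimination>();
        manager.run_passes(f);
        // f's result is now fed directly by x; the Multiply is gone.
    }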

View File

@@ -21,6 +21,7 @@ class TRANSFORMATIONS_API EliminateConvertNonZero;
class TRANSFORMATIONS_API EliminateConcat;
class TRANSFORMATIONS_API EliminateSplit;
class TRANSFORMATIONS_API EliminateTranspose;
class TRANSFORMATIONS_API EliminateEltwise;
class TRANSFORMATIONS_API NopElimination;
} // namespace pass
@@ -86,6 +87,15 @@ public:
EliminateTranspose();
};
/**
* @ingroup ie_transformation_common_api
* @brief EliminateEltwise eliminates eltwise ops that do nothing
*/
class ngraph::pass::EliminateEltwise: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
EliminateEltwise();
};
class ngraph::pass::NopElimination: public GraphRewrite {
public:

View File

@@ -210,6 +210,10 @@ TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_ind
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> node_to_get_shape_value_of_indices_from_shape_source(
const ngraph::Output<ngraph::Node>& shape_source, const std::vector<size_t>& indices);
TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output<Node>& node);
TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise, const Output<Node>& constant, const Output<Node>& non_constant_input);
} // namespace util
} // namespace op
} // namespace ngraph
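
The new helper can also be called directly. A small sketch, assuming only the declaration above (the wrapper function name is hypothetical):

    #include <memory>
    #include <ngraph/opsets/opset8.hpp>
    #include <transformations/utils/utils.hpp>

    bool add_of_scalar_zero_is_noop() {
        using namespace ngraph;
        auto x = std::make_shared<opset8::Parameter>(element::f32, Shape{4});
        auto zero = opset8::Constant::create(element::f32, Shape{}, {0});
        auto add = std::make_shared<opset8::Add>(x, zero);
        // True: the constant is a scalar 0, the identity for Add, and it
        // neither extends x's rank nor broadcasts it.
        return op::util::can_eliminate_eltwise_node(add, zero, x);
    }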

View File

@@ -18,32 +18,6 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::MultiplyConvolutionFusion, "MultiplyConvolutionFusion", 0);
-static bool is_dequantization_subgraph(const ngraph::Output<ngraph::Node>& multiply) {
-auto inputs = multiply.get_node()->input_values();
-const auto subtract = std::find_if(inputs.begin(), inputs.end(),
-[] (const ngraph::Output<ngraph::Node>& n) -> bool {
-return ov::is_type<ngraph::opset8::Subtract>(n.get_node());
-});
-if (subtract != inputs.end())
-inputs = subtract->get_node()->input_values();
-const auto first_convert = std::find_if(inputs.begin(), inputs.end(),
-[] (const ngraph::Output<ngraph::Node>& n) -> bool {
-if (ov::is_type<ngraph::opset8::Convert>(n.get_node())) {
-const auto input = n.get_node()->input_value(0);
-return ov::is_type<ngraph::opset8::Convert>(input.get_node());
-}
-return false;
-});
-if (first_convert == inputs.end())
-return false;
-const auto second_convert = first_convert->get_node()->input_value(0);
-const auto& first_convert_src_type = second_convert.get_element_type();
-const auto& first_convert_dest_type = first_convert->get_element_type();
-const auto second_convert_src_type = second_convert.get_node()->input_value(0).get_element_type();
-return (first_convert_src_type == ngraph::element::i8 || first_convert_src_type == ngraph::element::u8) &&
-first_convert_dest_type == second_convert_src_type;
-}
ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() {
MATCHER_SCOPE(MultiplyConvolutionFusion);
auto input_pattern = pattern::any_input();
@@ -57,7 +31,7 @@ ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() {
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
// since that breaks low precision transformations
-if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
+if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
return false;
const auto& weights = pattern_to_output.at(weights_pattern);
@@ -110,7 +84,7 @@ ngraph::pass::MultiplyGroupConvolutionFusion::MultiplyGroupConvolutionFusion() {
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
// since that breaks low precision transformations
-if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
+if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
return false;
const auto& weights = pattern_to_output.at(weights_pattern);
@@ -174,7 +148,7 @@ ngraph::pass::MultiplyConvolutionBackpropDataFusion::MultiplyConvolutionBackprop
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
// since that breaks low precision transformations
-if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
+if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
return false;
const auto& weights = pattern_to_output.at(weights_pattern);
@@ -240,7 +214,7 @@ ngraph::pass::MultiplyGroupConvolutionBackpropDataFusion::MultiplyGroupConvoluti
// Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph
// since that breaks low precision transformations
-if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
+if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern)))
return false;
const auto& weights = pattern_to_output.at(weights_pattern);
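
For reference, a sketch of the dequantization pattern the now-shared helper recognizes — an integral-to-real Convert, an optional zero-point Subtract, then the scale Multiply (the builder function below is illustrative, not from the commit):

    #include <memory>
    #include <ngraph/opsets/opset8.hpp>

    std::shared_ptr<ngraph::Node> make_dequantization_subgraph() {
        using namespace ngraph;
        auto q = opset8::Constant::create(element::i8, Shape{}, {2});
        auto convert = std::make_shared<opset8::Convert>(q, element::f32);
        auto zero_point = opset8::Constant::create(element::f32, Shape{}, {0});
        auto sub = std::make_shared<opset8::Subtract>(convert, zero_point);
        auto scale = opset8::Constant::create(element::f32, Shape{}, {0.5});
        auto mul = std::make_shared<opset8::Multiply>(sub, scale);
        // op::util::is_dequantization_subgraph(mul) returns true here, so the
        // fusions above (and EliminateEltwise) leave the subgraph intact.
        return mul;
    }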

View File

@@ -12,6 +12,7 @@
#include <ngraph/util.hpp>
#include <ngraph/log.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
using namespace std;
@@ -503,6 +504,34 @@ pass::EliminateTranspose::EliminateTranspose() {
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(pass::EliminateEltwise, "EliminateEltwise", 0);
pass::EliminateEltwise::EliminateEltwise() {
MATCHER_SCOPE(EliminateEltwise);
auto input = pattern::any_input();
auto constant_pattern = pattern::wrap_type<opset8::Constant>();
auto eltwise_pattern = pattern::wrap_type<opset8::Add,
opset8::Subtract,
opset8::Multiply,
opset8::Divide>({input, constant_pattern});
matcher_pass_callback callback = [=](pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto eltwise = pattern_map.at(eltwise_pattern).get_node_shared_ptr();
auto non_const_input = pattern_map.at(input);
auto constant = pattern_map.at(constant_pattern);
if (!op::util::can_eliminate_eltwise_node(eltwise, constant, non_const_input)) {
return false;
}
return replace_output_update_name(eltwise->output(0), non_const_input);
};
auto m = std::make_shared<pattern::Matcher>(eltwise_pattern, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0);
ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
@@ -513,6 +542,7 @@ ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
add_matcher<EliminateConcat>();
add_matcher<EliminateSplit>();
add_matcher<EliminateTranspose>();
add_matcher<EliminateEltwise>();
// shape-dependent transformations
if (use_shape_for_elimination) {

View File

@@ -4,6 +4,8 @@
#include "itt.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <vector>
@@ -28,25 +30,30 @@ bool convert_divide(std::shared_ptr<ngraph::Node> node) {
return false;
}
-ngraph::Output<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
+std::shared_ptr<ngraph::Node> pow = std::make_shared<ngraph::opset1::Power>(div->input_value(1),
ngraph::op::Constant::create(div->get_input_element_type(1), ngraph::Shape{}, {-1}));
if (std::dynamic_pointer_cast<ngraph::op::Constant>(div->get_input_node_shared_ptr(1))) {
if (auto const_pow = ngraph::get_constant_from_source(pow)) {
pow = const_pow;
} else {
-NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get_node();
+NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get();
return false;
}
} else {
-ngraph::copy_runtime_info(div, pow.get_node_shared_ptr());
+ngraph::copy_runtime_info(div, pow);
}
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
-mul->set_friendly_name(div->get_friendly_name());
-ngraph::copy_runtime_info(div, mul);
-ngraph::replace_node(div, mul);
+// if Divide is an inverse, then we don't need the Multiply
+if (ngraph::op::util::can_eliminate_eltwise_node(mul, mul->input_value(0), mul->input_value(1))) {
+pow->set_friendly_name(div->get_friendly_name());
+ngraph::replace_node(div, pow);
+} else {
+mul->set_friendly_name(div->get_friendly_name());
+ngraph::copy_runtime_info(div, mul);
+ngraph::replace_node(div, mul);
+}
return true;
}
} // namespace
@@ -74,4 +81,4 @@ ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() {
auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);
this->register_matcher(m, callback);
-}
+}
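
For intuition: ConvertDivide rewrites a / b as a * pow(b, -1), so when the dividend is the constant 1 the trailing Multiply is an identity, which is exactly what the can_eliminate_eltwise_node check above detects. A quick numeric check of that identity (plain C++, not part of the commit):

    #include <cassert>
    #include <cmath>

    int main() {
        double b = 4.0;
        // a / b == a * b^-1; with a == 1 the Multiply can be dropped.
        assert(std::fabs(1.0 / b - 1.0 * std::pow(b, -1.0)) < 1e-12);
        assert(std::fabs(1.0 / b - std::pow(b, -1.0)) < 1e-12);
    }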

View File

@@ -203,6 +203,132 @@ void visit_shape_path(const std::shared_ptr<ov::Node>& node,
}
}
bool is_dequantization_subgraph(const Output<Node>& node) {
if (!is_type<opset8::Multiply>(node.get_node())) {
return false;
}
auto mul_inputs = node.get_node()->input_values();
Node* sub = nullptr;
Node* convert = nullptr;
if (is_type<opset8::Subtract>(mul_inputs[0].get_node())) {
sub = mul_inputs[0].get_node();
} else if (is_type<opset8::Convert>(mul_inputs[0].get_node())) {
convert = mul_inputs[0].get_node();
} else {
return false;
}
if (sub) {
auto sub_inputs = sub->input_values();
if (is_type<opset8::Convert>(sub_inputs[0].get_node())) {
convert = sub_inputs[0].get_node();
}
}
if (!convert) {
return false;
}
auto input_type = convert->get_input_element_type(0);
auto output_type = convert->get_output_element_type(0);
return input_type.is_integral() && output_type.is_real();
}
bool can_eliminate_eltwise_node(const std::shared_ptr<Node>& eltwise, const Output<Node>& constant, const Output<Node>& non_constant_input) {
if (!is_type<opset8::Add>(eltwise) &&
!is_type<opset8::Subtract>(eltwise) &&
!is_type<opset8::Multiply>(eltwise) &&
!is_type<opset8::Divide>(eltwise)) {
return false;
}
if (is_dequantization_subgraph(eltwise)) {
return false;
}
// check if constant has a single value with either 0 (for Add, Subtract) or 1 (for Multiply, Divide)
auto constant_ptr = std::dynamic_pointer_cast<opset8::Constant>(constant.get_node_shared_ptr());
if (!constant_ptr) {
return false;
}
if (!constant_ptr->get_all_data_elements_bitwise_identical()) {
return false;
}
float actual_const = 0;
const void* data_ptr = constant_ptr->get_data_ptr();
switch (constant_ptr->get_element_type()) {
case element::f32:
actual_const = reinterpret_cast<const float*>(data_ptr)[0];
break;
case element::i32:
actual_const = static_cast<float>(reinterpret_cast<const int32_t*>(data_ptr)[0]);
break;
case element::u32:
actual_const = static_cast<float>(reinterpret_cast<const uint32_t*>(data_ptr)[0]);
break;
case element::i64:
actual_const = static_cast<float>(reinterpret_cast<const int64_t*>(data_ptr)[0]);
break;
case element::u64:
actual_const = static_cast<float>(reinterpret_cast<const uint64_t*>(data_ptr)[0]);
break;
case element::i8:
actual_const = static_cast<float>(reinterpret_cast<const int8_t*>(data_ptr)[0]);
break;
case element::u8:
actual_const = static_cast<float>(reinterpret_cast<const uint8_t*>(data_ptr)[0]);
break;
case element::i16:
actual_const = static_cast<float>(reinterpret_cast<const int16_t*>(data_ptr)[0]);
break;
case element::u16:
actual_const = static_cast<float>(reinterpret_cast<const uint16_t*>(data_ptr)[0]);
break;
case element::f64:
actual_const = static_cast<float>(reinterpret_cast<const double*>(data_ptr)[0]);
break;
default:
return false;
}
float expected_const = 0;
if (is_type<opset8::Multiply>(eltwise) ||
is_type<opset8::Divide>(eltwise)) {
expected_const = 1;
}
if (actual_const != expected_const) {
return false;
}
// fuse unconditionally if the constant is a scalar
const auto& constant_shape = constant.get_shape();
if (ov::is_scalar(constant_shape)) {
return true;
}
const auto& input_shape = non_constant_input.get_partial_shape();
if (input_shape.rank().is_dynamic()) {
return false;
}
// cannot fuse if constant extends input's rank
auto input_rank = static_cast<size_t>(input_shape.rank().get_length());
auto constant_rank = constant_shape.size();
if (input_rank < constant_rank) {
return false;
}
// cannot fuse if the constant would cause the input to be broadcast, e.g.
// Multiply(input{2, 1, 5}, constant{1, 5, 1}) -> {2, 5, 5}
for (size_t i = 0; i < constant_rank; i++) {
auto constant_dim = constant_shape[constant_rank - i - 1];
if (constant_dim != 1 && input_shape[input_rank - i - 1] != constant_dim) {
return false;
}
}
return true;
}
} // namespace util
} // namespace op
} // namespace ngraph
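
The final loop above compares trailing dimensions. A standalone, simplified illustration of that rule on plain shape vectors (not the commit's code):

    #include <cassert>
    #include <vector>

    static bool shapes_allow_elimination(const std::vector<size_t>& input,
                                         const std::vector<size_t>& constant) {
        if (constant.size() > input.size())
            return false;  // the constant would extend the input's rank
        for (size_t i = 0; i < constant.size(); i++) {
            size_t c = constant[constant.size() - i - 1];
            if (c != 1 && input[input.size() - i - 1] != c)
                return false;  // the result would broadcast the input
        }
        return true;
    }

    int main() {
        assert(shapes_allow_elimination({2, 3, 5}, {1, 1, 5}));   // ok to fuse
        assert(!shapes_allow_elimination({2, 1, 5}, {1, 5, 1}));  // -> {2, 5, 5}
        assert(!shapes_allow_elimination({5}, {1, 1}));           // rank extended
    }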

View File

@@ -48,8 +48,8 @@ ov::op::v0::Constant::Constant(const shared_ptr<ngraph::runtime::Tensor>& tensor
constructor_validate_and_infer_types();
allocate_buffer();
tensor->read(get_data_ptr_nc(), tensor->get_size_in_bytes());
+m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}
-m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
constructor_validate_and_infer_types();
}

View File

@@ -9,6 +9,7 @@
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "util/visitor.hpp"
using namespace std;
@@ -56,3 +57,35 @@ TEST(attributes, constant_op_identical_elements) {
EXPECT_EQ(data, g_data);
ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical());
}
TEST(attributes, constant_op_from_host_tensor_different_elements) {
vector<int64_t> data{5, 4, 3, 2, 1, 0};
auto tensor = std::make_shared<runtime::HostTensor>(element::i64, Shape{2, 3}, &data[0]);
auto k = make_shared<op::v0::Constant>(tensor);
ASSERT_FALSE(k->get_all_data_elements_bitwise_identical());
NodeBuilder builder(k);
auto g_k = ov::as_type_ptr<op::v0::Constant>(builder.create());
g_k->validate_and_infer_types();
ASSERT_TRUE(g_k);
EXPECT_EQ(k->get_element_type(), g_k->get_element_type());
EXPECT_EQ(k->get_shape(), g_k->get_shape());
vector<int64_t> g_data = g_k->get_vector<int64_t>();
EXPECT_EQ(data, g_data);
ASSERT_FALSE(g_k->get_all_data_elements_bitwise_identical());
}
TEST(attributes, constant_op_from_host_tensor_identical_elements) {
vector<int64_t> data{5, 5, 5, 5, 5, 5};
auto tensor = std::make_shared<runtime::HostTensor>(element::i64, Shape{2, 3}, &data[0]);
auto k = make_shared<op::v0::Constant>(tensor);
ASSERT_TRUE(k->get_all_data_elements_bitwise_identical());
NodeBuilder builder(k);
auto g_k = ov::as_type_ptr<op::v0::Constant>(builder.create());
g_k->validate_and_infer_types();
ASSERT_TRUE(g_k);
EXPECT_EQ(k->get_element_type(), g_k->get_element_type());
EXPECT_EQ(k->get_shape(), g_k->get_shape());
vector<int64_t> g_data = g_k->get_vector<int64_t>();
EXPECT_EQ(data, g_data);
ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical());
}

View File

@@ -38,8 +38,31 @@ TEST_F(TransformationTestsF, ConvertDivide) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideInverse) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto divide = std::make_shared<ngraph::opset1::Divide>(divide_constant, data);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertDivide>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1.0});
auto pow = std::make_shared<ngraph::opset1::Power>(data, constant);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{pow}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideNegative) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{3, 1, 2});
@@ -58,6 +81,7 @@ TEST_F(TransformationTestsF, ConvertDivideNegative) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideScalar) {
@@ -84,6 +108,7 @@ TEST_F(TransformationTestsF, ConvertDivideScalar) {
NGRAPH_CHECK(mul->get_output_partial_shape(0).rank().get_length() == 0);
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
@@ -103,6 +128,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
@@ -122,6 +148,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertDivideFP16ShapeOfSubgraphNegative) {

View File

@@ -18,8 +18,11 @@
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
#include <ngraph_functions/builders.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include <ngraph/pass/manager.hpp>
@@ -817,4 +820,215 @@ TEST(nop_elimination, gather_3d_indices_constant_axis_1) {
check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector<int64_t>{1}, 2, 0);
check_usecase(PartialShape{1, 16}, i32, multiout, std::vector<int64_t>{0, 0}, 0, 1);
}
}
}
using namespace helpers;
struct ShapeParams {
PartialShape shape1;
Shape shape2;
bool swap_inputs;
bool can_fuse;
};
enum class ConstantKind {
ZERO,
ONE,
RANDOM,
};
static std::ostream& operator<<(std::ostream& os, ConstantKind kind) {
switch (kind) {
case ConstantKind::ZERO:
os << "zero";
break;
case ConstantKind::ONE:
os << "one";
break;
case ConstantKind::RANDOM:
os << "random";
break;
}
return os;
}
struct TypeParams {
EltwiseTypes op_type;
ConstantKind constant_kind;
bool can_fuse;
};
using EliminateEltwiseParams = std::tuple<ShapeParams, TypeParams, element::Type>;
class EliminateEltwiseTests: public testing::WithParamInterface<EliminateEltwiseParams>, virtual public TransformationTestsF {
public:
static std::string get_test_case_name(testing::TestParamInfo<EliminateEltwiseParams> info) {
const auto& shape_params = std::get<0>(info.param);
const auto& type_params = std::get<1>(info.param);
const auto& element_type = std::get<2>(info.param);
std::ostringstream result;
result << type_params.op_type
<< "_input1=" << shape_params.shape1
<< "_input2=" << shape_params.shape2
<< "_swap_inputs=" << std::boolalpha << shape_params.swap_inputs
<< "_constant=" << type_params.constant_kind
<< "_type=" << element_type;
return result.str();
}
};
TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
auto params = GetParam();
const auto& shape_params = std::get<0>(params);
const auto& type_params = std::get<1>(params);
const auto& type = std::get<2>(params);
const auto& shape1 = shape_params.shape1;
const auto& shape2 = shape_params.shape2;
bool swap_inputs = shape_params.swap_inputs;
bool can_fuse = shape_params.can_fuse && type_params.can_fuse;
auto parameter = make_shared<op::Parameter>(type, shape1);
std::shared_ptr<Node> constant;
switch (type_params.constant_kind) {
case ConstantKind::ZERO:
constant = op::Constant::create(type, shape2, {0});
break;
case ConstantKind::ONE:
constant = op::Constant::create(type, shape2, {1});
break;
case ConstantKind::RANDOM:
constant = builder::makeConstant(type, shape2, {}, true, 20 /* upTo */, 2 /* startFrom */);
break;
}
shared_ptr<Node> A = parameter;
shared_ptr<Node> B = constant;
if (swap_inputs) {
std::swap(A, B);
if (type_params.op_type == EltwiseTypes::SUBTRACT ||
type_params.op_type == EltwiseTypes::DIVIDE) {
can_fuse = false;
}
}
shared_ptr<Node> node;
switch (type_params.op_type) {
case EltwiseTypes::ADD:
node = make_shared<opset8::Add>(A, B);
break;
case EltwiseTypes::SUBTRACT:
node = make_shared<opset8::Subtract>(A, B);
break;
case EltwiseTypes::MULTIPLY:
node = make_shared<opset8::Multiply>(A, B);
break;
case EltwiseTypes::DIVIDE:
node = make_shared<opset8::Divide>(A, B);
break;
default:
ASSERT_FALSE(true) << "Invalid EltwiseType";
}
auto abs = make_shared<opset8::Abs>(node);
function = make_shared<Function>(abs, ParameterVector{parameter});
manager.register_pass<pass::NopElimination>();
if (can_fuse) {
auto abs = make_shared<opset8::Abs>(parameter);
function_ref = make_shared<Function>(abs, ParameterVector{parameter});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
if (type == element::f32) {
enable_accuracy_check();
}
}
std::vector<ShapeParams> shape_params = {
// input 1, input 2, swap inputs, can fuse
{ Shape{}, Shape{}, false, true },
{ Shape{}, Shape{}, true, true },
{ Shape{5}, Shape{}, false, true },
{ Shape{5}, Shape{1}, false, true },
{ Shape{5}, Shape{5}, false, true },
{ Shape{5}, Shape{5}, true, true },
{ Shape{2, 3, 5}, Shape{}, false, true },
{ Shape{2, 3, 5}, Shape{1}, false, true },
{ Shape{2, 3, 5}, Shape{1, 1}, false, true },
{ Shape{2, 3, 5}, Shape{1, 1, 1}, false, true },
{ Shape{2, 3, 5}, Shape{5}, false, true },
{ Shape{2, 3, 5}, Shape{1, 5}, false, true },
{ Shape{2, 3, 5}, Shape{1, 1, 5}, false, true },
{ Shape{2, 3, 5}, Shape{3, 5}, false, true },
{ Shape{2, 3, 5}, Shape{1, 3, 5}, false, true },
{ Shape{2, 3, 5}, Shape{2, 3, 5}, false, true },
{ Shape{2, 3, 5}, Shape{2, 3, 5}, true, true },
{ PartialShape::dynamic(), Shape{}, false, true },
{ PartialShape::dynamic(3), Shape{}, false, true },
{ PartialShape::dynamic(3), Shape{1}, false, true },
{ PartialShape::dynamic(3), Shape{1, 1}, false, true },
{ PartialShape::dynamic(3), Shape{1, 1, 1}, false, true },
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 1}, false, true },
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{3, 1}, false, true },
{ PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 3, 1}, false, true },
// negative cases
{ Shape{}, Shape{1}, false, false },
{ Shape{}, Shape{2, 3}, false, false },
{ Shape{5}, Shape{1, 1}, false, false },
{ Shape{4, 1, 3}, Shape{2, 3}, false, false },
{ Shape{1, 2, 3}, Shape{4, 2, 3}, false, false },
{ Shape{1, 1, 3}, Shape{2, 1, 3}, false, false },
{ Shape{1, 2, 3}, Shape{1, 1, 1, 1}, false, false },
{ PartialShape::dynamic(), Shape{2, 3, 4}, false, false },
{ PartialShape::dynamic(3), Shape{1, 2, 1}, false, false },
{ PartialShape::dynamic(3), Shape{1, 1, 1, 1}, false, false },
};
std::vector<TypeParams> type_params = {
// op type, constant value, can fuse
{ EltwiseTypes::ADD, ConstantKind::ZERO, true },
{ EltwiseTypes::ADD, ConstantKind::RANDOM, false },
{ EltwiseTypes::SUBTRACT, ConstantKind::ZERO, true },
{ EltwiseTypes::SUBTRACT, ConstantKind::RANDOM, false },
{ EltwiseTypes::MULTIPLY, ConstantKind::ONE, true },
{ EltwiseTypes::MULTIPLY, ConstantKind::RANDOM, false },
{ EltwiseTypes::DIVIDE, ConstantKind::ONE, true },
{ EltwiseTypes::DIVIDE, ConstantKind::RANDOM, false },
};
std::vector<element::Type> types{
element::f32, element::f64,
element::i32, element::u32,
element::i64, element::u64,
element::i8, element::u8,
element::i16, element::u16,
};
INSTANTIATE_TEST_SUITE_P(EliminateEltwise, EliminateEltwiseTests,
::testing::Combine(
::testing::ValuesIn(shape_params),
::testing::ValuesIn(type_params),
::testing::ValuesIn(types)),
EliminateEltwiseTests::get_test_case_name);
TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) {
{
auto constant = opset8::Constant::create(element::i8, Shape{}, {2});
auto convert = make_shared<opset8::Convert>(constant, element::f32);
auto sub = make_shared<opset8::Subtract>(convert, opset8::Constant::create(element::f32, Shape{}, {0}));
auto mul = make_shared<opset8::Multiply>(sub, opset8::Constant::create(element::f32, Shape{}, {1}));
function = make_shared<Function>(mul, ParameterVector{});
}
{
auto constant = opset8::Constant::create(element::i8, Shape{}, {2});
auto convert = make_shared<opset8::Convert>(constant, element::f32);
auto mul = make_shared<opset8::Multiply>(convert, opset8::Constant::create(element::f32, Shape{}, {1}));
function_ref = make_shared<Function>(mul, ParameterVector{});
}
manager.register_pass<pass::NopElimination>();
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
}

View File

@@ -201,7 +201,11 @@ void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_functio
for (auto param : ref_function->get_parameters()) {
types.push_back(param->get_element_type());
-InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), InferenceEngine::Layout::ANY);
+auto layout = InferenceEngine::Layout::ANY;
+if (ov::is_scalar(param->get_shape())) {
+layout = InferenceEngine::Layout::SCALAR;
+}
+InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout);
const auto &input = FuncTestUtils::createAndFillBlob(td);
const auto &input_size = input->byteSize();
@@ -227,8 +231,8 @@
IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size());
auto * ref = reinterpret_cast<float *>(ref_outputs[i].second.data());
auto * out = reinterpret_cast<float *>(outputs[i].second.data());
-IE_ASSERT(ref_outputs[i].second.size() / 8);
size_t size = ref_outputs[i].second.size() / sizeof(float);
+IE_ASSERT(size > 0);
Compare<float, float>(ref, out, size, 1e-5);
}
}
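
Presumably both fixes in this file exist so that accuracy_check can handle the new scalar test cases: a scalar f32 output occupies 4 bytes, so the old IE_ASSERT(bytes / 8) evaluated 4 / 8 == 0 and fired. A tiny sketch of the arithmetic (hypothetical, outside the commit):

    #include <cassert>
    #include <cstddef>

    int main() {
        std::size_t bytes = sizeof(float);         // one scalar f32 output: 4 bytes
        // old check: IE_ASSERT(bytes / 8);        // 4 / 8 == 0 -> assert fails
        std::size_t size = bytes / sizeof(float);  // element count, not bytes / 8
        assert(size > 0);                          // new check passes: size == 1
    }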