Eliminate dequantization shift when zero point == 0 (#13353)

Ticket: 91111
Author: Mateusz Tabaka, 2022-10-27 01:37:13 +02:00 (committed by GitHub)
parent 554af81085
commit 154850e8ca
7 changed files with 205 additions and 80 deletions
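
For context, ONNX DequantizeLinear computes y = (x - x_zero_point) * x_scale. When the zero point is absent or equal to zero, the subtraction is an identity, so the importer and the EliminateEltwise pass below can drop the shift entirely and emit just Multiply(Convert(x), scale). A minimal standalone sketch of the arithmetic (hypothetical helper names, not code from this commit):

#include <cstdint>

// Full dequantization: y = (x - zero_point) * scale
float dequantize(std::uint8_t x, float scale, std::uint8_t zero_point) {
    return (static_cast<float>(x) - static_cast<float>(zero_point)) * scale;
}

// With zero_point == 0 the shift contributes nothing, so the Subtract
// node that would implement it can be eliminated from the graph:
float dequantize_no_shift(std::uint8_t x, float scale) {
    return static_cast<float>(x) * scale;
}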


@@ -272,46 +272,51 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
 } // namespace fq_decomposition
 bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
-    auto layer = ov::as_type_ptr<opset1::FakeQuantize>(m.get_match_root());
-    if (!layer || !NetworkHelper::isQuantizeSupported(layer)) {
+    auto node = ov::as_type_ptr<opset1::FakeQuantize>(m.get_match_root());
+    if (!node || !NetworkHelper::isQuantizeSupported(node)) {
         return false;
     }
-    layer = NetworkHelper::fuseConvert(layer);
+    auto layer = NetworkHelper::fuseConvert(node);
+    bool rewritten = layer.get() != node.get();
+    if (rewritten) {
+        register_new_node(layer);
+    }
     if (NetworkHelper::isConstantPath(layer)) {
-        return false;
+        return rewritten;
     }
     auto attribute = getAttributeFromOutput<PrecisionsAttribute>(layer->output(0));
     if (attribute.empty() || (attribute.as<PrecisionsAttribute>().value().empty())) {
-        return false;
+        return rewritten;
     }
     const ngraph::element::Type outputPrecision = layer->get_output_element_type(0);
     if (DataPrecision::isSupported(outputPrecision)) {
         const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantizationBelow(layer);
         if (dequantization.empty()) {
-            return false;
+            return rewritten;
         }
         const DataPrecision expectedDataPrecision = fq_decomposition::getDataPrecisionByOutputPortAndFakeQuantize(layer);
         // TODO: need test to compose FakeQuantize
         if ((expectedDataPrecision.precision == element::undefined) || (expectedDataPrecision.precision == outputPrecision)) {
-            return false;
+            return rewritten;
         }
         layer = NetworkHelper::composeFakeQuantize(layer, defaultPrecisions);
         if (layer == nullptr) {
-            return false;
+            return rewritten;
         }
     }
     if (!QuantizationDetails::outputLayoutIsSupported(layer)) {
-        return false;
+        return rewritten;
     }
     if (!QuantizationDetails::isSupportedLevel(layer->get_levels())) {
-        return false;
+        return rewritten;
     }
     DataPrecision dataPrecision = fq_decomposition::getDataPrecisionByOutputPort(layer);
@@ -343,7 +348,7 @@ bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& c
     // FakeQuantize operations are combined in supported cascade (per tensor quantization)
     if (!intervalsAlignment.empty() && (intervalsAlignment.as<IntervalsAlignmentAttribute>().value().minLevels <= 2ul)) {
-        return false;
+        return rewritten;
     }
     // if IntervalsAlignment attribute is defined then, the attribute defines decomposition parameters,
@@ -396,6 +401,8 @@ bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& c
         }
     }
+    // clear the node that was produced by fuseConvert
+    clear_new_nodes();
     auto QDQ = fq_decomposition::decomposeFakeQuantize(
         this,
         layer,
@@ -407,7 +414,7 @@ bool FakeQuantizeDecompositionTransformation::transform(TransformationContext& c
     std::shared_ptr<ngraph::Node> dequantize = std::get<0>(QDQ);
     std::shared_ptr<ngraph::Node> newFakeQuantize = std::get<1>(QDQ);
     if (dequantize == nullptr || newFakeQuantize == nullptr) {
-        return false;
+        return rewritten;
     }
     updateOutput(context, dequantize, newFakeQuantize);
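
A note on the control-flow change above: NetworkHelper::fuseConvert may rewrite the graph before any of the early-exit checks run, so unconditionally returning false from those exits would hide that modification from the pass manager. The new rewritten flag records whether fuseConvert produced a new node, and every bail-out now propagates it. Condensed sketch of the pattern (names taken from the diff):

auto layer = NetworkHelper::fuseConvert(node);  // may already rewrite the graph
bool rewritten = layer.get() != node.get();     // true iff a new node was produced
if (rewritten) {
    register_new_node(layer);
}
// ...every later early exit returns rewritten instead of false, so the
// fuseConvert rewrite is reported even when the decomposition stops early.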


@@ -7,6 +7,7 @@
 #include <ngraph/log.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pattern/op/or.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
 #include <ngraph/util.hpp>
 #include <numeric>
@@ -696,12 +697,15 @@ pass::EliminateEltwise::EliminateEltwise() {
     auto constant_pattern = pattern::wrap_type<opset8::Constant>();
     auto eltwise_pattern =
         pattern::wrap_type<opset8::Add, opset8::Subtract, opset8::Multiply, opset8::Divide>({input, constant_pattern});
+    auto subtract_pattern =
+        pattern::wrap_type<opset8::Subtract>({input, pattern::wrap_type<opset8::Convert>({constant_pattern})});
+    auto root = make_shared<pattern::op::Or>(OutputVector{eltwise_pattern, subtract_pattern});
     matcher_pass_callback callback = [=](pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
-        auto eltwise = pattern_map.at(eltwise_pattern).get_node_shared_ptr();
-        auto non_const_input = pattern_map.at(input);
-        auto constant = pattern_map.at(constant_pattern);
+        auto eltwise = m.get_match_root();
+        const auto& non_const_input = pattern_map.at(input);
+        const auto& constant = pattern_map.at(constant_pattern);
         if (!op::util::can_eliminate_eltwise_node(eltwise, constant, non_const_input)) {
             return false;
@@ -709,7 +713,7 @@ pass::EliminateEltwise::EliminateEltwise() {
         return replace_output_update_name(eltwise->output(0), non_const_input);
     };
-    auto m = make_shared<pattern::Matcher>(eltwise_pattern, matcher_name);
+    auto m = make_shared<pattern::Matcher>(root, matcher_name);
     this->register_matcher(m, callback);
 }
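
This matcher change is the transformation-side core of the commit: in addition to the plain Eltwise(input, constant) shape, EliminateEltwise now also matches Subtract(input, Convert(constant)), which is exactly what a dequantization shift looks like when the zero point is stored in a narrow integer type and converted to f32. A hypothetical usage sketch of a subgraph the extended pattern can now fold (opset8 builders as in the diff):

// x - Convert(u8 zero constant) == x, so the Subtract can be removed
auto x = std::make_shared<opset8::Parameter>(element::f32, Shape{2, 2});
auto zp = opset8::Constant::create(element::u8, Shape{}, {0});
auto zp_f32 = std::make_shared<opset8::Convert>(zp, element::f32);
auto sub = std::make_shared<opset8::Subtract>(x, zp_f32);
// can_eliminate_eltwise_node checks that the constant behind the Convert is
// all zeros; the callback then replaces sub with x via replace_output_update_name.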


@@ -20,18 +20,17 @@ namespace ngraph {
 namespace onnx_import {
 namespace op {
 namespace detail {
-Output<ngraph::Node> get_zero_point(const OutputVector& inputs) {
+std::shared_ptr<ngraph::Node> get_zero_point(const OutputVector& inputs) {
     if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2])) {
-        auto zero_point = inputs[2];
+        const auto& zero_point = inputs[2];
         if (zero_point.get_element_type() != element::f32) {
-            zero_point = std::make_shared<default_opset::Convert>(zero_point, element::f32);
+            return std::make_shared<default_opset::Convert>(zero_point, element::f32);
         }
-        return zero_point;
-    } else {
-        return default_opset::Constant::create(element::f32, Shape{}, {0});
+        return zero_point.get_node_shared_ptr();
     }
+    return nullptr;
 }
 } // namespace detail
 namespace set_1 {
@@ -42,18 +41,22 @@ OutputVector dequantize_linear(const Node& node) {
                      "The DequantizeLinear op expects 2 required and one optional input. Got: ",
                      inputs.size());
-    const auto x = inputs[0];
-    const auto scale = inputs[1];
+    const auto& x = inputs[0];
+    const auto& scale = inputs[1];
     const auto zero_point = detail::get_zero_point(inputs);
     common::validate_scalar_input("Dequantization scale", scale.get_node_shared_ptr(), {element::f32});
-    common::validate_scalar_input("Zero point", zero_point.get_node_shared_ptr());
     const auto converted_x = std::make_shared<default_opset::Convert>(x, element::f32);
-    return {
-        std::make_shared<default_opset::Multiply>(std::make_shared<default_opset::Subtract>(converted_x, zero_point),
-                                                  scale)};
+    if (zero_point) {
+        common::validate_scalar_input("Zero point", zero_point);
+        return {std::make_shared<default_opset::Multiply>(
+            std::make_shared<default_opset::Subtract>(converted_x, zero_point),
+            scale)};
+    } else {
+        return {std::make_shared<default_opset::Multiply>(converted_x, scale)};
+    }
 }
 } // namespace set_1
@@ -99,9 +102,10 @@ void validate_zero_point(const Output<ngraph::Node> zero_point, const Output<ngr
     }
 }
-std::shared_ptr<ngraph::Node> reshape_input(const Output<ngraph::Node> input,
+std::shared_ptr<ngraph::Node> reshape_input(const Output<ngraph::Node>& input,
                                             const int64_t axis,
                                             const PartialShape& x_shape) {
+    // these reshapes make sure that dequantization happens over the specified axis
     auto input_rank = input.get_partial_shape().rank();
     // Do not reshape input, if it contains a scalar value
@@ -130,29 +134,29 @@ std::shared_ptr<ngraph::Node> reshape_input(const Output<ngraph::Node> input,
     return std::make_shared<default_opset::Reshape>(input, target_shape, true);
 }
-OutputVector dequantize_linear(Output<ngraph::Node> x,
-                               Output<ngraph::Node> scale,
-                               Output<ngraph::Node> zero_point,
+OutputVector dequantize_linear(const Output<ngraph::Node>& x,
+                               const Output<ngraph::Node>& scale,
+                               const std::shared_ptr<ngraph::Node>& zero_point,
                                int64_t axis,
-                               Node node) {
-    const auto x_shape = x.get_partial_shape();
+                               const Node& node) {
+    const auto& x_shape = x.get_partial_shape();
     NGRAPH_CHECK(x_shape.rank().is_static(), "Rank of the input data tensor has to be known (static).");
     axis = ngraph::normalize_axis(node.get_description(), axis, x_shape.rank());
     validate_scale(scale, x, axis);
-    validate_zero_point(zero_point, x, axis);
-    // these reshapes make sure that dequantization happens over the specified axis
-    scale = reshape_input(scale, axis, x_shape);
-    zero_point = reshape_input(zero_point, axis, x_shape);
+    const auto scale_reshaped = reshape_input(scale, axis, x_shape);
     const auto converted_x = std::make_shared<default_opset::Convert>(x, element::f32);
-    return {
-        std::make_shared<default_opset::Multiply>(std::make_shared<default_opset::Subtract>(converted_x, zero_point),
-                                                  scale)};
+    if (zero_point) {
+        validate_zero_point(zero_point, x, axis);
+        return {std::make_shared<default_opset::Multiply>(
+            std::make_shared<default_opset::Subtract>(converted_x, reshape_input(zero_point, axis, x_shape)),
+            scale_reshaped)};
+    } else {
+        return {std::make_shared<default_opset::Multiply>(converted_x, scale_reshaped)};
+    }
 }
 } // namespace detail
@@ -163,9 +167,9 @@ OutputVector dequantize_linear(const Node& node) {
                      "The DequantizeLinear op expects 2 required and one optional "
                      "input. Got: ",
                      inputs.size());
-    const auto x = inputs[0];
-    auto scale = inputs[1];
-    auto zero_point = op::detail::get_zero_point(inputs);
+    const auto& x = inputs[0];
+    const auto& scale = inputs[1];
+    const auto zero_point = op::detail::get_zero_point(inputs);
     // these reshapes make sure that dequantization happens over the specified axis
     return detail::dequantize_linear(x, scale, zero_point, node.get_attribute_value<int64_t>("axis", 1), node);
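
Net effect on the importer: get_zero_point now returns nullptr when the optional third input is missing, instead of fabricating a zero constant, so no Subtract is created that a later pass would have to clean up. (A zero point that is present but happens to be all zeros is still handled by the EliminateEltwise change above.) The two shapes the builders can emit, condensed from the diff:

const auto converted_x = std::make_shared<default_opset::Convert>(x, element::f32);
if (zero_point) {
    // y = (Convert(x) - zero_point) * scale
    return {std::make_shared<default_opset::Multiply>(
        std::make_shared<default_opset::Subtract>(converted_x, zero_point), scale)};
}
// y = Convert(x) * scale -- no shift node is created at all
return {std::make_shared<default_opset::Multiply>(converted_x, scale)};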


@@ -10,9 +10,6 @@
 namespace ngraph {
 namespace onnx_import {
 namespace op {
-namespace detail {
-Output<ngraph::Node> get_zero_point(const OutputVector& inputs);
-}
 namespace set_1 {
 OutputVector dequantize_linear(const Node& node);
@@ -21,11 +18,11 @@ OutputVector dequantize_linear(const Node& node);
 namespace set_13 {
 namespace detail {
-OutputVector dequantize_linear(Output<ngraph::Node> x,
-                               Output<ngraph::Node> scale,
-                               Output<ngraph::Node> zero_point,
+OutputVector dequantize_linear(const Output<ngraph::Node>& x,
+                               const Output<ngraph::Node>& scale,
+                               const std::shared_ptr<ngraph::Node>& zero_point,
                                int64_t axis,
-                               Node node);
+                               const Node& node);
 }
 OutputVector dequantize_linear(const Node& node);
 } // namespace set_13


@@ -0,0 +1,54 @@
+ir_version: 7
+producer_name: "ngraph ONNXImporter"
+graph {
+  node {
+    input: "x"
+    input: "x_scale"
+    output: "y"
+    name: "node1"
+    op_type: "DequantizeLinear"
+  }
+  name: "test"
+  input {
+    name: "x"
+    type {
+      tensor_type {
+        elem_type: 2
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "x_scale"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "y"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 13
+}


@@ -178,6 +178,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_scalar_zero_point) {
     test_case.run();
 }
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_no_zero_point) {
+    auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
+                                                                        SERIALIZED_ZOO,
+                                                                        "onnx/dequantize_linear_no_zero_point.onnx"));
+    auto test_case = test::TestCase(function, s_device);
+    test_case.add_input(std::vector<std::uint8_t>{19, 210, 21, 10});  // x
+    test_case.add_input(std::vector<float>{2.0f, 1.0f});              // scale
+    test_case.add_expected_output<float>(std::vector<float>{38, 210, 42, 10});
+    test_case.run();
+}
 NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_scalar_zero_scale_uint8) {
     auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
                                                                         SERIALIZED_ZOO,
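
A quick check of the expected values in the new test: without a zero point the op reduces to y = x * scale, with the per-channel scale {2, 1} broadcast along the default axis 1 of the 2x2 input, so y = {19*2, 210*1, 21*2, 10*1} = {38, 210, 42, 10}. A self-contained sketch of that arithmetic:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const std::vector<std::uint8_t> x{19, 210, 21, 10};  // 2x2 input, row-major
    const std::vector<float> scale{2.0f, 1.0f};          // broadcast along axis 1
    std::vector<float> y;
    for (std::size_t i = 0; i < x.size(); ++i)
        y.push_back(static_cast<float>(x[i]) * scale[i % scale.size()]);  // no shift
    assert((y == std::vector<float>{38.0f, 210.0f, 42.0f, 10.0f}));
    return 0;
}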


@@ -890,6 +890,35 @@ struct ShapeParams {
     bool can_fuse;
 };
+enum class OpType {
+    ADD,
+    SUBTRACT,
+    SUBTRACT_WITH_CONVERT,
+    MULTIPLY,
+    DIVIDE,
+};
+static std::ostream& operator<<(std::ostream& os, OpType kind) {
+    switch (kind) {
+        case OpType::ADD:
+            os << "add";
+            break;
+        case OpType::SUBTRACT:
+            os << "subtract";
+            break;
+        case OpType::SUBTRACT_WITH_CONVERT:
+            os << "subtract_with_convert";
+            break;
+        case OpType::MULTIPLY:
+            os << "multiply";
+            break;
+        case OpType::DIVIDE:
+            os << "divide";
+            break;
+    }
+    return os;
+}
 enum class ConstantKind {
     ZERO,
     ONE,
@@ -912,7 +941,7 @@ static std::ostream& operator<<(std::ostream& os, ConstantKind kind) {
 }
 struct TypeParams {
-    EltwiseTypes op_type;
+    OpType op_type;
     ConstantKind constant_kind;
     bool can_fuse;
 };
@@ -936,6 +965,14 @@ class EliminateEltwiseTests: public testing::WithParamInterface<EliminateEltwise
     }
 };
+std::vector<element::Type> types{
+    element::f32, element::f64,
+    element::i32, element::u32,
+    element::i64, element::u64,
+    element::i8, element::u8,
+    element::i16, element::u16,
+};
 TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
     auto params = GetParam();
     const auto& shape_params = std::get<0>(params);
@@ -947,45 +984,60 @@ TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
     bool can_fuse = shape_params.can_fuse && type_params.can_fuse;
     auto parameter = make_shared<op::Parameter>(type, shape1);
+    auto constant_type = type;
+    if (type_params.op_type == OpType::SUBTRACT_WITH_CONVERT) {
+        if (type == types[0])
+            constant_type = types[1];
+        else
+            constant_type = types[0];
+    }
     std::shared_ptr<Node> constant;
     switch (type_params.constant_kind) {
         case ConstantKind::ZERO:
-            constant = op::Constant::create(type, shape2, {0});
+            constant = op::Constant::create(constant_type, shape2, {0});
             break;
         case ConstantKind::ONE:
-            constant = op::Constant::create(type, shape2, {1});
+            constant = op::Constant::create(constant_type, shape2, {1});
             break;
         case ConstantKind::RANDOM:
-            constant = builder::makeConstant(type, shape2, {}, true, 20 /* upTo */, 2 /* startFrom */);
+            constant = builder::makeConstant(constant_type, shape2, {}, true, 20 /* upTo */, 2 /* startFrom */);
             break;
     }
+    if (type_params.op_type == OpType::SUBTRACT_WITH_CONVERT) {
+        constant = std::make_shared<opset8::Convert>(constant, type);
+    }
     shared_ptr<Node> A = parameter;
     shared_ptr<Node> B = constant;
     if (swap_inputs) {
         std::swap(A, B);
-        if (type_params.op_type == EltwiseTypes::SUBTRACT ||
-            type_params.op_type == EltwiseTypes::DIVIDE) {
+        if (type_params.op_type == OpType::SUBTRACT ||
+            type_params.op_type == OpType::SUBTRACT_WITH_CONVERT ||
+            type_params.op_type == OpType::DIVIDE) {
            can_fuse = false;
        }
    }
     shared_ptr<Node> node;
     switch (type_params.op_type) {
-        case EltwiseTypes::ADD:
+        case OpType::ADD:
             node = make_shared<opset8::Add>(A, B);
             break;
-        case EltwiseTypes::SUBTRACT:
+        case OpType::SUBTRACT:
+        case OpType::SUBTRACT_WITH_CONVERT:
             node = make_shared<opset8::Subtract>(A, B);
             break;
-        case EltwiseTypes::MULTIPLY:
+        case OpType::MULTIPLY:
             node = make_shared<opset8::Multiply>(A, B);
             break;
-        case EltwiseTypes::DIVIDE:
+        case OpType::DIVIDE:
             node = make_shared<opset8::Divide>(A, B);
             break;
         default:
-            ASSERT_FALSE(true) << "Invalid EltwiseType";
+            ASSERT_FALSE(true) << "Invalid OpType";
     }
     auto abs = make_shared<opset8::Abs>(node);
     function = make_shared<Function>(abs, ParameterVector{parameter});
@@ -1045,22 +1097,16 @@ std::vector<ShapeParams> shape_params = {
 std::vector<TypeParams> type_params = {
     // op type, constant value, can fuse
-    { EltwiseTypes::ADD, ConstantKind::ZERO, true },
-    { EltwiseTypes::ADD, ConstantKind::RANDOM, false },
-    { EltwiseTypes::SUBTRACT, ConstantKind::ZERO, true },
-    { EltwiseTypes::SUBTRACT, ConstantKind::RANDOM, false },
-    { EltwiseTypes::MULTIPLY, ConstantKind::ONE, true },
-    { EltwiseTypes::MULTIPLY, ConstantKind::RANDOM, false },
-    { EltwiseTypes::DIVIDE, ConstantKind::ONE, true },
-    { EltwiseTypes::DIVIDE, ConstantKind::RANDOM, false },
-};
-std::vector<element::Type> types{
-    element::f32, element::f64,
-    element::i32, element::u32,
-    element::i64, element::u64,
-    element::i8, element::u8,
-    element::i16, element::u16,
+    { OpType::ADD, ConstantKind::ZERO, true },
+    { OpType::ADD, ConstantKind::RANDOM, false },
+    { OpType::SUBTRACT, ConstantKind::ZERO, true },
+    { OpType::SUBTRACT, ConstantKind::RANDOM, false },
+    { OpType::SUBTRACT_WITH_CONVERT, ConstantKind::ZERO, true },
+    { OpType::SUBTRACT_WITH_CONVERT, ConstantKind::RANDOM, false },
+    { OpType::MULTIPLY, ConstantKind::ONE, true },
+    { OpType::MULTIPLY, ConstantKind::RANDOM, false },
+    { OpType::DIVIDE, ConstantKind::ONE, true },
+    { OpType::DIVIDE, ConstantKind::RANDOM, false },
 };
 INSTANTIATE_TEST_SUITE_P(EliminateEltwise, EliminateEltwiseTests,
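
The new SUBTRACT_WITH_CONVERT cases build the constant in a different element type than the parameter (whichever entry of types differs from type) and convert it back, so the test exercises the Subtract(input, Convert(constant)) branch of the matcher. Fusion is expected only for the all-zero constant and, as with plain SUBTRACT and DIVIDE, not when the inputs are swapped, since subtraction is not commutative. Condensed from the test body above:

// zero constant in a different type, converted back to the input type
auto constant = op::Constant::create(constant_type, shape2, {0});
auto converted = std::make_shared<opset8::Convert>(constant, type);
// should fold to `parameter` because the constant behind the Convert is zero
auto node = std::make_shared<opset8::Subtract>(parameter, converted);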