From f96f021920f91a536831118a64b216887e109c0f Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Sat, 22 Jul 2023 16:47:44 +0200
Subject: [PATCH] [PT FE] Add helper for regular ops on quantized values
 (#18692)

* Add helper for regular ops

* Update op_table.cpp

* Support ops with more than 1 output

* Uncheck ops that return integer/boolean type
---
 src/frontends/pytorch/src/op_table.cpp        |  47 +++----
 src/frontends/pytorch/src/utils_quantize.cpp  |  86 +++++--------
 src/frontends/pytorch/src/utils_quantize.hpp  | 118 ++++++++++--------
 .../pytorch_tests/test_quantized_linear.py    |  40 +++++-
 4 files changed, 157 insertions(+), 134 deletions(-)
{"aten::index_select", op::translate_index_select}, @@ -314,10 +315,10 @@ const std::map get_supported_ops() { {"aten::masked_fill_", op::inplace_op}, {"aten::matmul", op::translate_1to1_match_2_inputs}, {"aten::max", op::translate_max}, - {"aten::max_pool1d", op::translate_max_poolnd}, - {"aten::max_pool2d", op::translate_max_poolnd}, - {"aten::max_pool3d", op::translate_max_poolnd}, - {"aten::mean", op::translate_mean}, + {"aten::max_pool1d", op::quantizable_op}, + {"aten::max_pool2d", op::quantizable_op}, + {"aten::max_pool3d", op::quantizable_op}, + {"aten::mean", op::quantizable_op}, {"aten::meshgrid", op::translate_meshgrid}, {"aten::min", op::translate_min}, {"aten::mm", op::translate_1to1_match_2_inputs}, @@ -364,7 +365,7 @@ const std::map get_supported_ops() { {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention}, {"aten::scatter", op::translate_scatter}, {"aten::scatter_", op::inplace_op}, - {"aten::select", op::translate_select}, + {"aten::select", op::quantizable_op}, {"aten::selu", op::translate_selu}, {"aten::selu_", op::inplace_op}, {"aten::sigmoid", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -377,12 +378,12 @@ const std::map get_supported_ops() { {"aten::sinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::sinh_", op::inplace_op>}, {"aten::size", op::translate_size}, - {"aten::slice", op::translate_slice}, + {"aten::slice", op::quantizable_op}, {"aten::softmax", op::translate_softmax}, {"aten::sort", op::translate_sort}, {"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::square", op::translate_square}, - {"aten::squeeze", op::translate_squeeze}, + {"aten::squeeze", op::quantizable_op}, {"aten::sub", op::translate_sub}, {"aten::sub_", op::inplace_op}, {"aten::sum", op::translate_sum}, @@ -395,7 +396,7 @@ const std::map get_supported_ops() { {"aten::tensor", op::translate_as_tensor}, {"aten::to", op::translate_to}, {"aten::topk", op::translate_topk}, - {"aten::transpose", op::translate_transpose}, + {"aten::transpose", op::quantizable_op}, {"aten::tril", op::translate_tril}, {"aten::tril_", op::inplace_op}, {"aten::triu", op::translate_triu}, @@ -404,8 +405,8 @@ const std::map get_supported_ops() { op::translate_1to1_match_2_inputs}, // TODO: overflow semantics is different {"aten::unflatten", op::translate_unflatten}, {"aten::unfold", op::translate_unfold}, - {"aten::unsqueeze", op::translate_1to1_match_2_inputs}, - {"aten::unsqueeze_", op::inplace_op>}, + {"aten::unsqueeze", op::quantizable_op>}, + {"aten::unsqueeze_", op::quantizable_op>>}, {"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d}, {"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d}, {"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa}, @@ -417,7 +418,7 @@ const std::map get_supported_ops() { {"aten::upsample_trilinear3d", op::translate_upsample_trilinear3d}, {"aten::var", op::translate_var}, {"aten::var_mean", op::translate_var_mean}, - {"aten::view", op::translate_reshape}, + {"aten::view", op::quantizable_op}, {"aten::where", op::translate_where}, {"aten::zero_", op::inplace_op}, {"aten::zeros", op::translate_zeros}, diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp index 7e666f5979d..a4132c00ec0 100644 --- a/src/frontends/pytorch/src/utils_quantize.cpp +++ b/src/frontends/pytorch/src/utils_quantize.cpp @@ -20,13 +20,13 @@ namespace pytorch { using namespace ov::op; -ov::Output quantize(const NodeContext& 
diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp
index 7e666f5979d..a4132c00ec0 100644
--- a/src/frontends/pytorch/src/utils_quantize.cpp
+++ b/src/frontends/pytorch/src/utils_quantize.cpp
@@ -20,13 +20,13 @@ namespace pytorch {
 
 using namespace ov::op;
 
-ov::Output<Node> quantize(const NodeContext& context,
-                          std::shared_ptr<ov::Node> input,
-                          std::shared_ptr<ov::Node> scale,
-                          std::shared_ptr<ov::Node> zero_point,
-                          std::shared_ptr<ov::Node> axis,
-                          ov::element::Type dtype,
-                          QuantizedPtNodeType quantization_type) {
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
     if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scale_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
@@ -63,7 +63,7 @@ ov::Output<Node> quantize(const NodeContext& context,
                                                                  zero_point_convert,
                                                                  dtype));
     } else if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) {
-        FRONT_END_OP_CONVERSION_CHECK(axis, "Axis cannot be null for quantize_per_channel.");
+        FRONT_END_OP_CONVERSION_CHECK(axis.get_node(), "Axis cannot be null for quantize_per_channel.");
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
         const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
@@ -119,46 +119,19 @@ ov::Output<Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
 }
 
-// ========================================================
-
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::element::Type dtype,
-                          QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    nullptr,
-                    dtype,
-                    quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
+    return quantize(context, input, scale, zero_point, Output<Node>(), dtype, quantization_type);
 }
 
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::Output<Node> axis,
-                          ov::element::Type dtype,
-                          QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    axis.get_node_shared_ptr(),
-                    dtype,
-                    quantization_type);
-}
-
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
+Output<Node> quantize(const NodeContext& context, const Output<Node>& input, const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
         return quantize(context,
-                        input.get_node_shared_ptr(),
+                        input,
                         quantized_pt_node->get_scale(),
                         quantized_pt_node->get_zero_point(),
                         quantized_pt_node->get_axis(),
@@ -168,17 +141,16 @@ ov::Output<Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::Output<Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
         return quantize(context,
-                        input.get_node_shared_ptr(),
-                        scale.get_node_shared_ptr(),
-                        zero_point.get_node_shared_ptr(),
+                        input,
+                        scale,
+                        zero_point,
                         quantized_pt_node->get_axis(),
                         quantized_pt_node->get_dtype(),
                         quantized_pt_node->get_type());
@@ -186,7 +158,7 @@ ov::Output<Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
     auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
     if (!quant_node) {
         return nullptr;
@@ -198,7 +170,7 @@ std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node)
     return quant_node;
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node, const std::string& type) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type) {
    auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
    if (!quant_node) {
        return nullptr;
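
With the overloads above taking ov::Output<Node> by const reference, the optional axis is now an empty Output<Node>() instead of a null shared_ptr, which is why the per-channel path checks axis.get_node(). A minimal sketch of how a converter could re-quantize its result like an existing quantized input, assuming a hypothetical translate_my_relu (not part of this patch) and the standard v0::Relu op:

    // Sketch only; translate_my_relu is a hypothetical converter.
    OutputVector translate_my_relu(const NodeContext& context) {
        const auto x = context.get_input(0);
        const auto res = context.mark_node(std::make_shared<v0::Relu>(x));
        // Reuses scale/zero_point/axis/dtype from the QuantizedPtNode that
        // produced x; fails the conversion check if x is not quantized.
        return {quantize(context, res, x)};
    }
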
diff --git a/src/frontends/pytorch/src/utils_quantize.hpp b/src/frontends/pytorch/src/utils_quantize.hpp
index 675ff924e29..de7cf2bfd3b 100644
--- a/src/frontends/pytorch/src/utils_quantize.hpp
+++ b/src/frontends/pytorch/src/utils_quantize.hpp
@@ -15,22 +15,19 @@ enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };
 
 class QuantizedPtNode : public PtFrameworkNode {
 public:
-    OPENVINO_OP("QuantizedPtNode", "util", ::ov::frontend::pytorch::PtFrameworkNode);
+    OPENVINO_OP("QuantizedPtNode", "util", PtFrameworkNode);
     static constexpr const char* quantized_node_type_key = "QuantizedPtTypeName";
     static constexpr const char* quantize_per_tensor = "quantize_per_tensor";
     static constexpr const char* quantize_per_channel = "quantize_per_channel";
 
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<Node> input,
-                    const ov::Output<Node> scale,
-                    const ov::Output<Node> zero_point,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(nullptr) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
@@ -45,16 +42,13 @@ public:
 
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<Node> input,
-                    const ov::Output<Node> scale,
-                    const ov::Output<Node> zero_point,
-                    const ov::Output<Node> axis,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
+                    const Output<Node> axis,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(axis.get_node_shared_ptr()) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point, axis}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
@@ -67,66 +61,86 @@ public:
         this->dtype = dtype;
     }
 
-    const std::shared_ptr<ov::Node> get_scale() {
-        return scale;
+    const Output<Node> get_scale() const {
+        return input_value(1);
     }
-    const std::shared_ptr<ov::Node> get_zero_point() {
-        return zero_point;
+    const Output<Node> get_zero_point() const {
+        return input_value(2);
     }
-    const std::shared_ptr<ov::Node> get_axis() {
-        return axis;
+    const Output<Node> get_axis() const {
+        if (inputs().size() < 4) {
+            return Output<Node>();
+        }
+        return input_value(3);
     }
-    const QuantizedPtNodeType get_type() {
+    const QuantizedPtNodeType get_type() const {
         return type;
     }
-    const element::Type get_dtype() {
+    const element::Type get_dtype() const {
         return dtype;
     }
 
 private:
     const QuantizedPtNodeType type;
-    std::shared_ptr<ov::Node> scale;
-    std::shared_ptr<ov::Node> zero_point;
-    std::shared_ptr<ov::Node> axis;
     element::Type dtype;
 };
 
 /**
  * Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode.
  */
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::element::Type dtype,
-                          QuantizedPtNodeType quantization_type);
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::Output<Node> axis,
-                          ov::element::Type dtype,
-                          QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
 
 /**
  * Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode.
  */
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> quantized_node);
+Output<Node> quantize(const NodeContext& context, const Output<Node>& input, const Output<Node>& quantized_node);
 
 /**
  * Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to
  * the new QuantizedPtNode.
  */
-ov::Output<Node> quantize(const NodeContext& context,
-                          ov::Output<Node> input,
-                          ov::Output<Node> scale,
-                          ov::Output<Node> zero_point,
-                          ov::Output<Node> quantized_node);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node);
+
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node);
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type);
+
+namespace op {
+/**
+ * Modifies a conversion function to support the quantized case. When the input is quantized, it is processed as a
+ * quantized op.
+ */
+template <OutputVector (*T)(const NodeContext&), size_t in_idx = 0, size_t out_idx = 0>
+OutputVector quantizable_op(const NodeContext& context) {
+    auto translation_res = T(context);
+    FRONT_END_OP_CONVERSION_CHECK(translation_res.size() > out_idx, "Not enough outputs to apply quantization.");
+    if (const auto quantized_pt_node = cast_quantized_fw_node(context.get_input(in_idx).get_node_shared_ptr())) {
+        return {quantize(context,
+                         translation_res[out_idx],
+                         quantized_pt_node->get_scale(),
+                         quantized_pt_node->get_zero_point(),
+                         quantized_pt_node->get_axis(),
+                         quantized_pt_node->get_dtype(),
+                         quantized_pt_node->get_type())};
+    }
+    return translation_res;
+}
+}  // namespace op
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node);
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node, const std::string& type);
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov
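
To make the template parameters of quantizable_op concrete: T is the wrapped translator, in_idx selects which input is inspected for a QuantizedPtNode, and out_idx selects which of the translator's outputs is re-quantized (both default to 0, and every registration in this patch relies on the defaults). Spelled out explicitly, the op_table entry for aten::max_pool2d is equivalent to the following sketch, assuming CreatorFunction (the value type of the map in op_table.cpp) accepts a function pointer:

    // Defaults written out for illustration only.
    CreatorFunction f = op::quantizable_op<op::translate_max_poolnd, /*in_idx=*/0, /*out_idx=*/0>;

When the checked input is not quantized, the wrapper is a no-op and simply forwards the translator's outputs unchanged.
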
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_linear.py b/tests/layer_tests/pytorch_tests/test_quantized_linear.py
index cc30a313d31..4041bd75dc6 100644
--- a/tests/layer_tests/pytorch_tests/test_quantized_linear.py
+++ b/tests/layer_tests/pytorch_tests/test_quantized_linear.py
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 
 from pytorch_layer_test_class import PytorchLayerTest
 
+
 class TestQuantizedLinear(PytorchLayerTest):
     def _prepare_input(self, input_shape=(2, 2)):
         return (np.random.randn(*input_shape).astype(np.float32),)
@@ -16,10 +17,12 @@ class TestQuantizedLinear(PytorchLayerTest):
         def __init__(self, weight_shape, is_bias, scale, zero_point):
             super(aten_quantized_linear, self).__init__()
             if is_bias:
-                self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], True)
+                self.linear = torch.ao.nn.quantized.Linear(
+                    weight_shape[-1], weight_shape[0], True)
                 torch.nn.init.normal_(self.linear.bias())
             else:
-                self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], False)
+                self.linear = torch.ao.nn.quantized.Linear(
+                    weight_shape[-1], weight_shape[0], False)
             self.linear.scale = float(scale)
             self.linear.zero_point = int(zero_point)
 
@@ -31,6 +34,31 @@ class TestQuantizedLinear(PytorchLayerTest):
 
         return aten_quantized_linear(weight_shape, is_bias, scale, zero_point), ref_net, "quantized::linear"
 
+    def create_hardtanh_model(self, weight_shape, is_bias, scale, zero_point, inplace):
+
+        class aten_quantized_linear(torch.nn.Module):
+            def __init__(self, weight_shape, is_bias, scale, zero_point, inplace):
+                super(aten_quantized_linear, self).__init__()
+                self.hardtanh = torch.nn.Hardtanh(inplace=inplace)
+                if is_bias:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], True)
+                    torch.nn.init.normal_(self.linear.bias())
+                else:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], False)
+                self.linear.scale = float(scale)
+                self.linear.zero_point = int(zero_point)
+
+            def forward(self, inp):
+                inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
+                inp_q = self.hardtanh(inp_q)
+                return torch.dequantize(self.linear(inp_q))
+
+        ref_net = None
+
+        return aten_quantized_linear(weight_shape, is_bias, scale, zero_point, inplace), ref_net, ["quantized::linear", "aten::hardtanh_" if inplace else "aten::hardtanh"]
+
     @pytest.mark.parametrize("params", [
         {'input_shape': [3, 9], 'weight_shape': [10, 9]},
         {'input_shape': [3, 9], 'weight_shape': [9]},
@@ -51,3 +79,11 @@ class TestQuantizedLinear(PytorchLayerTest):
         bias = params.get("bias", False)
         self._test(*self.create_model(weight_shape, bias, scale, zero_point), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"input_shape": input_shape}, trace_model=trace, freeze_model=False)
+
+    @pytest.mark.parametrize("trace", [True, False])
+    @pytest.mark.parametrize("inplace", [True, False])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_hardtanh_linear(self, trace, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_hardtanh_model([10, 9], True, 1, 0.3, inplace), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": [2, 3, 9]}, trace_model=trace, freeze_model=False)