diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index bbecf3014d8..696e08048d3 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -95,6 +95,9 @@ pt_to_ov_type_map = { "torch.IntTensor": OVType.i32, "torch.LongTensor": OVType.i64, "torch.BoolTensor": OVType.boolean, + "torch.quint8": OVType.u8, + "torch.qint8": OVType.i8, + "torch.qint32": OVType.i32 } diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index d6bc13fa7b9..ae9e894a30b 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -25,6 +25,7 @@ #include "transforms/aten_index_put_replacer.hpp" #include "transforms/aten_index_replacer.hpp" #include "transforms/aten_stack_list_construct_replacer.hpp" +#include "transforms/dequantize_node_remover.hpp" #include "transforms/dict_resolver.hpp" #include "transforms/einsum_list_construct.hpp" #include "transforms/index_loop_getitem_replacer.hpp" @@ -35,6 +36,7 @@ #include "transforms/prim_list_tuple_construct_replacer.hpp" #include "transforms/prim_list_unpack_replacer.hpp" #include "transforms/prim_tuple_unpack_parameter_replacer.hpp" +#include "transforms/quantized_node_remover.hpp" #include "transforms/rfftn_complex_replacer.hpp" #include "transforms/string_equality_replacer.hpp" #include "transforms/tuple_unpack_replacer.hpp" @@ -182,6 +184,8 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/op/quantize.cpp b/src/frontends/pytorch/src/op/quantize.cpp new file mode 100644 index 00000000000..166839011c8 --- /dev/null +++ b/src/frontends/pytorch/src/op/quantize.cpp @@ -0,0 
+1,37 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "utils_quantize.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_quantize_per_tensor(const NodeContext& context) { + num_inputs_check(context, 4, 4); + const auto input = context.get_input(0); + const auto scale = context.get_input(1); + const auto zero_point = context.get_input(2); + const auto dtype = convert_dtype(context.const_input(3)); + return {quantize(context, input, scale, zero_point, dtype, QuantizedPtNodeType::QUANTIZE_PER_TENSOR)}; +} + +OutputVector translate_quantize_per_channel(const NodeContext& context) { + num_inputs_check(context, 5, 5); + const auto input = context.get_input(0); + const auto scales = context.get_input(1); + const auto zero_points = context.get_input(2); + const auto axis = context.get_input(3); + const auto dtype = convert_dtype(context.const_input(4)); + return {quantize(context, input, scales, zero_points, axis, dtype, QuantizedPtNodeType::QUANTIZE_PER_CHANNEL)}; +} + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 5c21e57ed5d..56f3249112d 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -106,6 +106,8 @@ OP_CONVERTER(translate_pad); OP_CONVERTER(translate_pairwise_distance); OP_CONVERTER(translate_pow); OP_CONVERTER(translate_pythonop); +OP_CONVERTER(translate_quantize_per_channel); +OP_CONVERTER(translate_quantize_per_tensor); OP_CONVERTER(translate_range_length); OP_CONVERTER(translate_reciprocal); OP_CONVERTER(translate_relu6); @@ -326,6 +328,8 @@ const std::map get_supported_ops() { {"aten::pairwise_distance", op::translate_pairwise_distance}, {"aten::permute", 
op::translate_1to1_match_2_inputs}, {"aten::pow", op::translate_pow}, + {"aten::quantize_per_channel", op::translate_quantize_per_channel}, + {"aten::quantize_per_tensor", op::translate_quantize_per_tensor}, {"aten::reciprocal", op::translate_reciprocal}, {"aten::relu", op::translate_1to1_match_1_inputs}, {"aten::relu_", op::inplace_op>}, diff --git a/src/frontends/pytorch/src/transforms/dequantize_node_remover.cpp b/src/frontends/pytorch/src/transforms/dequantize_node_remover.cpp new file mode 100644 index 00000000000..93cc100f5e0 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/dequantize_node_remover.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "dequantize_node_remover.hpp" + +#include +#include + +#include "openvino/core/rt_info.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" +#include "utils_quantize.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +DequantizeNodeRemover::DequantizeNodeRemover() { + auto dequantize_node = ov::pass::pattern::wrap_type(); + + ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { + auto dequantize_node = cast_fw_node(m.get_match_root(), "aten::dequantize"); + if (!dequantize_node) + return false; + + auto dequantized_input = dequantize_node->input_value(0); + dequantize_node->output(0).replace(dequantized_input); + return true; + }; + + auto m = std::make_shared(dequantize_node, + "ov::frontend::pytorch::pass::DequantizeNodeRemover"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/dequantize_node_remover.hpp b/src/frontends/pytorch/src/transforms/dequantize_node_remover.hpp new file mode 100644 index 00000000000..a986774ee32 --- /dev/null +++ 
b/src/frontends/pytorch/src/transforms/dequantize_node_remover.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +/** + * Dequantize Node Remover + * Replacer finds the unconverted dequantize ops and removes them. + * This matches the behavior of OV's LPT. + */ +class DequantizeNodeRemover : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::DequantizeNodeRemover"); + DequantizeNodeRemover(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/quantized_node_remover.cpp b/src/frontends/pytorch/src/transforms/quantized_node_remover.cpp new file mode 100644 index 00000000000..2eb5b012d4b --- /dev/null +++ b/src/frontends/pytorch/src/transforms/quantized_node_remover.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "quantized_node_remover.hpp" + +#include +#include + +#include "openvino/core/rt_info.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" +#include "utils_quantize.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +QuantizedNodeRemover::QuantizedNodeRemover() { + auto quantized_pt_node = ov::pass::pattern::wrap_type(); + + ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) { + auto quantized_pt_node = cast_quantized_fw_node(m.get_match_root()); + if (!quantized_pt_node) + return false; + + auto quantized_input = quantized_pt_node->input_value(0); + quantized_pt_node->output(0).replace(quantized_input); + return true; + }; + + auto m = std::make_shared(quantized_pt_node, + 
"ov::frontend::pytorch::pass::QuantizedNodeRemover"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/quantized_node_remover.hpp b/src/frontends/pytorch/src/transforms/quantized_node_remover.hpp new file mode 100644 index 00000000000..0a4ec91f433 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/quantized_node_remover.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +/** + * Quantized Node Remover + * Removes QuantizedNodes from the graph. + * These nodes are created in translation processes to propagate scale/zero_point information, + * and are not needed in the final graph. + */ +class QuantizedNodeRemover : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::QuantizedNodeRemover"); + QuantizedNodeRemover(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 7a24d9e447c..53d78ceec1f 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -135,15 +135,20 @@ std::shared_ptr numel(const NodeContext& context, const Output& x) { }; namespace { -const std::unordered_map TORCH_TO_OV_TYPE{{0, element::u8}, - {1, element::i8}, - {2, element::i16}, - {3, element::i32}, - {4, element::i64}, - {5, element::f16}, - {6, element::f32}, - {7, element::f64}, - {11, element::boolean}}; +const std::unordered_map TORCH_TO_OV_TYPE{ + {0, element::u8}, + {1, element::i8}, + {2, element::i16}, + {3, element::i32}, + {4, element::i64}, + {5, element::f16}, + {6, element::f32}, + {7, element::f64}, + {11, 
element::boolean}, + {12, element::i8}, // quantized i8 + {13, element::u8}, // quantized u8 + {14, element::i32} // quantized i32 +}; const std::unordered_map TORCH_AUTO_PAD_TO_OV{{"valid", ov::op::PadType::VALID}, {"same", ov::op::PadType::SAME_UPPER}}; diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp new file mode 100644 index 00000000000..025a0b9f1ea --- /dev/null +++ b/src/frontends/pytorch/src/utils_quantize.cpp @@ -0,0 +1,180 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "utils_quantize.hpp" + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/fake_quantize.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_elements_update.hpp" +#include "openvino/op/subtract.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { + +using namespace ov::op; + +ov::Output quantize(const NodeContext& context, + std::shared_ptr input, + std::shared_ptr scale, + std::shared_ptr zero_point, + std::shared_ptr axis, + ov::element::Type dtype, + QuantizedPtNodeType quantization_type) { + if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) { + const auto input_convert = context.mark_node(std::make_shared(input, element::f32)); + const auto scale_convert = context.mark_node(std::make_shared(scale, element::f32)); + const auto zero_point_convert = context.mark_node(std::make_shared(zero_point, element::f32)); + + int64_t out_low_i64, out_high_i64; + if (dtype == element::u8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } else if (dtype == element::i8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } else { // i32 + 
out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } + int64_t levels = out_high_i64 - out_low_i64 + 1; + const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); + const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); + const auto out_low_normalized = context.mark_node(std::make_shared(out_low, zero_point_convert)); + const auto out_high_normalized = + context.mark_node(std::make_shared(out_high, zero_point_convert)); + + const auto bound_low = context.mark_node(std::make_shared(scale_convert, out_low_normalized)); + const auto bound_high = context.mark_node(std::make_shared(scale_convert, out_high_normalized)); + + const auto quantized_input = context.mark_node( + std::make_shared(input_convert, bound_low, bound_high, bound_low, bound_high, levels)); + + return context.mark_node(std::make_shared(quantization_type, + context, + quantized_input, + scale_convert, + zero_point_convert, + dtype)); + } else if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) { + FRONT_END_OP_CONVERSION_CHECK(axis, "Axis cannot be null for quantize_per_channel."); + const auto input_convert = context.mark_node(std::make_shared(input, element::f32)); + const auto scales_convert = context.mark_node(std::make_shared(scale, element::f32)); + const auto zero_points_convert = context.mark_node(std::make_shared(zero_point, element::f32)); + const auto axis_convert = context.mark_node(std::make_shared(axis, element::i32)); + + const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); + const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); + const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); + + int64_t out_low_i64, out_high_i64; + if (dtype == element::u8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + 
out_high_i64 = (int64_t)std::numeric_limits::max(); + } else if (dtype == element::i8) { + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } else { // i32 + out_low_i64 = (int64_t)std::numeric_limits::lowest(); + out_high_i64 = (int64_t)std::numeric_limits::max(); + } + int64_t levels = out_high_i64 - out_low_i64 + 1; + const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64})); + const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64})); + + const auto rank = std::get<1>(get_shape_rank(context, input_convert)); + const auto ones = context.mark_node(std::make_shared(one, rank)); + const auto normalized_axis = normalize_axis(context, axis_convert, input_convert); + const auto new_shape = + context.mark_node(std::make_shared(ones, normalized_axis, neg_one, zero)); + + const auto scale_bc = context.mark_node(std::make_shared(scales_convert, new_shape, false)); + const auto zero_point_bc = + context.mark_node(std::make_shared(zero_points_convert, new_shape, false)); + + const auto out_low_normalized = context.mark_node(std::make_shared(out_low, zero_point_bc)); + const auto out_high_normalized = context.mark_node(std::make_shared(out_high, zero_point_bc)); + + const auto bound_low = context.mark_node(std::make_shared(scale_bc, out_low_normalized)); + const auto bound_high = context.mark_node(std::make_shared(scale_bc, out_high_normalized)); + + const auto quantized_input = context.mark_node( + std::make_shared(input_convert, bound_low, bound_high, bound_low, bound_high, levels)); + + return context.mark_node(std::make_shared(quantization_type, + context, + quantized_input, + scale_bc, + zero_point_bc, + dtype)); + } + FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize."); +} + +// ======================================================== + +ov::Output quantize(const NodeContext& context, + ov::Output 
input, + ov::Output scale, + ov::Output zero_point, + ov::element::Type dtype, + QuantizedPtNodeType quantization_type) { + return quantize(context, + input.get_node_shared_ptr(), + scale.get_node_shared_ptr(), + zero_point.get_node_shared_ptr(), + nullptr, + dtype, + quantization_type); +} + +ov::Output quantize(const NodeContext& context, + ov::Output input, + ov::Output scale, + ov::Output zero_point, + ov::Output axis, + ov::element::Type dtype, + QuantizedPtNodeType quantization_type) { + return quantize(context, + input.get_node_shared_ptr(), + scale.get_node_shared_ptr(), + zero_point.get_node_shared_ptr(), + axis.get_node_shared_ptr(), + dtype, + quantization_type); +} + +std::shared_ptr cast_quantized_fw_node(ov::Output node) { + auto quant_node = std::dynamic_pointer_cast(node.get_node_shared_ptr()); + if (!quant_node) { + return nullptr; + } + const auto& attrs = quant_node->get_attrs(); + if (attrs.find(QuantizedPtNode::quantized_node_type_key) == attrs.end()) { + return nullptr; + } + return quant_node; +} + +std::shared_ptr cast_quantized_fw_node(ov::Output node, const std::string& type) { + auto quant_node = std::dynamic_pointer_cast(node.get_node_shared_ptr()); + if (!quant_node) { + return nullptr; + } + const auto& attrs = quant_node->get_attrs(); + if (attrs.find(QuantizedPtNode::quantized_node_type_key) == attrs.end() || + attrs.at(QuantizedPtNode::quantized_node_type_key) != type) { + return nullptr; + } + return quant_node; +} + +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/utils_quantize.hpp b/src/frontends/pytorch/src/utils_quantize.hpp new file mode 100644 index 00000000000..816f8b30de4 --- /dev/null +++ b/src/frontends/pytorch/src/utils_quantize.hpp @@ -0,0 +1,115 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "pt_framework_node.hpp" + +namespace ov { 
+namespace frontend { +namespace pytorch { + +enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL }; + +class QuantizedPtNode : public PtFrameworkNode { +public: + OPENVINO_OP("QuantizedPtNode", "util", ::ov::frontend::pytorch::PtFrameworkNode); + static constexpr const char* quantized_node_type_key = "QuantizedPtTypeName"; + static constexpr const char* quantize_per_tensor = "quantize_per_tensor"; + static constexpr const char* quantize_per_channel = "quantize_per_channel"; + + QuantizedPtNode(const QuantizedPtNodeType type, + const NodeContext& context, + const ov::Output input, + const ov::Output scale, + const ov::Output zero_point, + element::Type& dtype) + : PtFrameworkNode(context.get_decoder(), {input}, 1, false), + type(type), + scale(scale.get_node_shared_ptr()), + zero_point(zero_point.get_node_shared_ptr()), + axis(nullptr) { + ov::op::util::FrameworkNodeAttrs attrs = get_attrs(); + if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) { + attrs[quantized_node_type_key] = quantize_per_tensor; + } else if (type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) { + FRONT_END_OP_CONVERSION_CHECK(false, "quantize_per_channel requires axis to be provided."); + } else { + FRONT_END_OP_CONVERSION_CHECK(false, "Unknown QuantizedPtNodeType: ", type); + } + set_attrs(attrs); + this->dtype = dtype; + } + + QuantizedPtNode(const QuantizedPtNodeType type, + const NodeContext& context, + const ov::Output input, + const ov::Output scale, + const ov::Output zero_point, + const ov::Output axis, + element::Type& dtype) + : PtFrameworkNode(context.get_decoder(), {input}, 1, false), + type(type), + scale(scale.get_node_shared_ptr()), + zero_point(zero_point.get_node_shared_ptr()), + axis(axis.get_node_shared_ptr()) { + ov::op::util::FrameworkNodeAttrs attrs = get_attrs(); + if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) { + attrs[quantized_node_type_key] = quantize_per_tensor; + } else if (type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) { + 
attrs[quantized_node_type_key] = quantize_per_channel; + } else { + FRONT_END_OP_CONVERSION_CHECK(false, "Unknown QuantizedPtNodeType: ", type); + } + set_attrs(attrs); + this->dtype = dtype; + } + + const std::shared_ptr get_scale() { + return scale; + } + const std::shared_ptr get_zero_point() { + return zero_point; + } + const std::shared_ptr get_axis() { + return axis; + } + const QuantizedPtNodeType get_type() { + return type; + } + const element::Type get_dtype() { + return dtype; + } + +private: + const QuantizedPtNodeType type; + std::shared_ptr scale; + std::shared_ptr zero_point; + std::shared_ptr axis; + element::Type dtype; +}; + +/** + * Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode. + */ +ov::Output quantize(const NodeContext& context, + ov::Output input, + ov::Output scale, + ov::Output zero_point, + ov::element::Type dtype, + QuantizedPtNodeType quantization_type); +ov::Output quantize(const NodeContext& context, + ov::Output input, + ov::Output scale, + ov::Output zero_point, + ov::Output axis, + ov::element::Type dtype, + QuantizedPtNodeType quantization_type); + +std::shared_ptr cast_quantized_fw_node(ov::Output node); +std::shared_ptr cast_quantized_fw_node(ov::Output node, const std::string& type); +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/tests/layer_tests/pytorch_tests/test_quantize.py b/tests/layer_tests/pytorch_tests/test_quantize.py new file mode 100644 index 00000000000..e9ab802280d --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_quantize.py @@ -0,0 +1,87 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + +class aten_quantize_per_tensor_aten_dequantize(torch.nn.Module): + def __init__(self, scale, zero_point, dtype) -> None: + torch.nn.Module.__init__(self) + self.scale = scale + self.zero_point 
= zero_point + self.dtype = dtype + + def forward(self, input_tensor): + quantized_tensor = torch.quantize_per_tensor(input_tensor, self.scale, self.zero_point, self.dtype) + dequantized_tensor = torch.dequantize(quantized_tensor) + return dequantized_tensor + +class aten_quantize_per_channel_aten_dequantize(torch.nn.Module): + def __init__(self, scales, zero_points, dtype, axis) -> None: + torch.nn.Module.__init__(self) + self.scales = torch.Tensor(scales) + self.zero_points = torch.Tensor(zero_points) + self.dtype = dtype + self.axis = axis + def forward(self, input_tensor): + quantized_tensor = torch.quantize_per_channel(input_tensor, self.scales, self.zero_points, self.axis, self.dtype) + dequantized_tensor = torch.dequantize(quantized_tensor) + return dequantized_tensor + +class TestQuantizePerTensorDequantize(PytorchLayerTest): + def _prepare_input(self): + return (np.array(5.00 * np.random.rand(100, 100) + 5.00, dtype=np.float32),) + + @pytest.mark.parametrize("scale", [ + 1.0, 0.21, 0.62 + ]) + @pytest.mark.parametrize("zero_point", [ + 0, 4, -7 + ]) + @pytest.mark.parametrize("dtype", [ + torch.quint8, + torch.qint8, + pytest.param(torch.qint32, marks=pytest.mark.skip( + reason="Not supported with FakeQuantize.")) + ]) + @pytest.mark.nightly + # @pytest.mark.precommit - sporadic issue + def test_quantize_per_tensor_dequantize(self, scale, zero_point, dtype, ie_device, precision, ir_version): + if dtype == torch.quint8: zero_point = abs(zero_point) + self._test(aten_quantize_per_tensor_aten_dequantize(scale, zero_point, dtype), None, ["aten::quantize_per_tensor", "aten::dequantize"], + ie_device, precision, ir_version, ) + +class TestQuantizePerChannelDequantize(PytorchLayerTest): + def _prepare_input(self): + return (np.array(5.00 * np.random.rand(5, 6, 7, 8) + 5.00, dtype=np.float32),) + + @pytest.mark.parametrize("scales", [ + np.array([1.0, 0.21, 0.62, 0.5], dtype=np.float32), + np.array([0.21, 0.62, 0.5, 1.0], dtype=np.float32), + np.array([0.62, 0.5, 
1.0, 0.21], dtype=np.float32), + np.array([0.5, 1.0, 0.21, 0.62], dtype=np.float32), + ]) + @pytest.mark.parametrize("zero_points", [ + np.array([0, 4, 2, 1], dtype=np.int32), + np.array([0, 1, 2, 3], dtype=np.int32), + np.array([0, 0, 0, 0], dtype=np.int32), + np.array([-1, 0, -4, 5], dtype=np.int32), + ]) + @pytest.mark.parametrize("dtype", [ + torch.quint8, + torch.qint8, + pytest.param(torch.qint32, marks=pytest.mark.skip( + reason="Not supported with FakeQuantize.")) + ]) + @pytest.mark.parametrize("axis", [ + 0, 1, 2, 3 + ]) + @pytest.mark.nightly + # @pytest.mark.precommit - conversion issue + def test_quantize_per_channel_dequantize(self, scales, zero_points, dtype, axis, ie_device, precision, ir_version): + if dtype == torch.quint8: zero_points = abs(zero_points) + self._test(aten_quantize_per_channel_aten_dequantize(scales, zero_points, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"], + ie_device, precision, ir_version, )