diff --git a/src/frontends/pytorch/src/op/quantized_convnd.cpp b/src/frontends/pytorch/src/op/quantized_convnd.cpp
new file mode 100644
index 00000000000..37ab867d72a
--- /dev/null
+++ b/src/frontends/pytorch/src/op/quantized_convnd.cpp
@@ -0,0 +1,107 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convolution.hpp"
+#include "openvino/op/group_conv.hpp"
+#include "openvino/op/relu.hpp"
+#include "utils.hpp"
+#include "utils_quantize.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+
+namespace {
+// Builds the convolution plus bias addition shared by the quantized::conv2d and quantized::conv2d_relu
+// translators. Input 1 must be a prim::GetAttr framework node whose six inputs are the packed parameters;
+// stride, padding, dilation and groups must be produced by constants. The returned output is not quantized:
+// callers apply quantize() themselves.
+Output<ov::Node> translate_quantized_convnd_base(const NodeContext& context) {
+    auto input = context.get_input(0);
+    auto packed_params_node =
+        std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(context.get_input(1).get_node_shared_ptr());
+    FRONT_END_OP_CONVERSION_CHECK(packed_params_node, "Packed params input node type is required to be FrameworkNode.");
+    const auto& attrs = packed_params_node->get_attrs();
+    FRONT_END_OP_CONVERSION_CHECK((attrs.find(PtFrameworkNode::op_type_key) != attrs.end()),
+                                  "Packed params input node does not contain information about op type.");
+    FRONT_END_OP_CONVERSION_CHECK((attrs.at(PtFrameworkNode::op_type_key) == "prim::GetAttr"),
+                                  "Incorrect packed params input node operator type, expected prim::GetAttr.");
+    auto packed_params = packed_params_node->inputs();
+
+    FRONT_END_OP_CONVERSION_CHECK(packed_params.size() == 6,
+                                  "Packed parameters for quantized conv should contain 6 items.");
+    // A non-constant producer makes the cast yield null; fail the conversion instead of dereferencing it.
+    const auto get_constant = [&packed_params](size_t index) {
+        auto constant =
+            std::dynamic_pointer_cast<v0::Constant>(packed_params[index].get_source_output().get_node_shared_ptr());
+        FRONT_END_OP_CONVERSION_CHECK(constant, "Packed parameters for quantized conv are required to be constants.");
+        return constant;
+    };
+    // Packed params: weight, bias, stride, padding, dilation, groups
+    auto weight = packed_params[0].get_source_output();
+    auto bias = packed_params[1].get_source_output();
+    auto strides = get_constant(2)->cast_vector<Strides::value_type>();
+    auto pads = get_constant(3)->cast_vector<CoordinateDiff::value_type>();
+    auto dilations = get_constant(4)->cast_vector<Strides::value_type>();
+    const auto groups_values = get_constant(5)->cast_vector<int64_t>();
+    FRONT_END_OP_CONVERSION_CHECK(!groups_values.empty(), "Groups parameter for quantized conv should not be empty.");
+    const int64_t groups = groups_values[0];
+
+    auto pad_type = ov::op::PadType::EXPLICIT;
+
+    std::shared_ptr<ov::Node> conv;
+    if (groups == 1) {
+        conv = std::make_shared<v1::Convolution>(input, weight, strides, pads, pads, dilations, pad_type);
+    } else {
+        conv = std::make_shared<v1::GroupConvolution>(input,
+                                                      reshape_kernel_for_group(context, weight, groups),
+                                                      strides,
+                                                      pads,
+                                                      pads,
+                                                      dilations,
+                                                      pad_type);
+    }
+    // A rank-1 bias goes through reshape_channelwise against the convolution output before it is added.
+    auto bias_rank = bias.get_partial_shape().rank();
+    if (bias_rank == 1) {
+        bias = reshape_channelwise(context, bias, conv);
+    }
+    conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
+
+    return conv->output(0);
+}
+}  // namespace
+
+// Translates quantized::conv2d: convolution with bias, then quantize() with input 2 (scale) and input 3 (zero point).
+OutputVector translate_quantized_convnd(const NodeContext& context) {
+    // "quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float
+    // output_scale, int output_zero_point) -> Tensor"
+    num_inputs_check(context, 4, 4);
+    auto scale = context.get_input(2);
+    auto zero_point = context.get_input(3);
+    return {quantize(context, translate_quantized_convnd_base(context), scale, zero_point, context.get_input(0))};
+}
+
+// Translates quantized::conv2d_relu: same as quantized::conv2d, with a Relu inserted before quantize().
+OutputVector translate_quantized_convnd_relu(const NodeContext& context) {
+    // "quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
+    // float output_scale, int output_zero_point) -> Tensor"
+    num_inputs_check(context, 4, 4);
+    auto scale = context.get_input(2);
+    auto zero_point = context.get_input(3);
+    auto conv = translate_quantized_convnd_base(context);
+    auto relu = context.mark_node(std::make_shared<v0::Relu>(conv));
+    return {quantize(context, relu->output(0), scale, zero_point, context.get_input(0))};
+}
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op/quantized_linear.cpp b/src/frontends/pytorch/src/op/quantized_linear.cpp
index 13c19402b30..a69013f3fab 100644
--- a/src/frontends/pytorch/src/op/quantized_linear.cpp
+++ b/src/frontends/pytorch/src/op/quantized_linear.cpp
@@ -37,7 +37,7 @@ OutputVector translate_quantized_linear(const NodeContext& context) {
     linear = context.mark_node(std::make_shared<v1::Add>(linear, bias));
     auto scale = context.get_input(2);
     auto zero_point = context.get_input(3);
-    return {context.mark_output(quantize(context, linear, scale, zero_point, x))};
+    return {quantize(context, linear, scale, zero_point, x)};
 };
 
 }  // namespace op
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 77e5adc80c7..f39c8aadbbb 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -158,6 +158,8 @@ OP_CONVERTER(translate_var_mean);
 OP_CONVERTER(translate_where);
 OP_CONVERTER(translate_zeros);
 OP_CONVERTER(translate_zeros_like);
+OP_CONVERTER(translate_quantized_convnd);
+OP_CONVERTER(translate_quantized_convnd_relu);
 OP_CONVERTER(translate_quantized_linear);
 
 }  // namespace op
@@ -419,6 +421,8 @@ const std::map get_supported_ops() {
         {"prim::requires_grad", op::return_false_scalar},
         {"prim::PythonOp", op::translate_pythonop},
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+        {"quantized::conv2d", op::translate_quantized_convnd},
+        {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
         {"quantized::linear", op::translate_quantized_linear},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_convnd.py b/tests/layer_tests/pytorch_tests/test_quantized_convnd.py
new file mode 100644
index 00000000000..7424636eea3
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_quantized_convnd.py
@@ -0,0 +1,96 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+import torch
+
+from openvino.frontend import FrontEndManager
+from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestQuantizedConv2D(PytorchLayerTest):
+    """Layer test for converting quantized::conv2d and quantized::conv2d_relu."""
+
+    def _prepare_input(self):
+        """Return a one-element tuple holding a random float32 array of shape (2, 3, 25, 25)."""
+        return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
+
+    def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
+        """Build a module that quantizes its input, applies a quantized Conv2d/ConvReLU2d and dequantizes.
+
+        weights_shape is [out_channels, in_channels // groups, *kernel_size]; relu selects ConvReLU2d over
+        Conv2d; scale and zero_point are assigned to the convolution module's scale and zero_point attributes.
+        Returns (model, None, op_name), where op_name is the quantized op name matching the chosen module.
+        """
+        class quantized_conv2d(torch.nn.Module):
+            def __init__(self):
+                super(quantized_conv2d, self).__init__()
+                if not relu:
+                    conv_func = torch.ao.nn.quantized.Conv2d
+                else:
+                    conv_func = torch.ao.nn.intrinsic.quantized.ConvReLU2d
+                self.conv = conv_func(
+                    weights_shape[1] * groups,
+                    weights_shape[0],
+                    weights_shape[2:],
+                    strides,
+                    pads,
+                    dilations,
+                    groups,
+                    bias,
+                )
+                if bias:
+                    torch.nn.init.normal_(self.conv.bias())
+                self.conv.scale = float(scale)
+                self.conv.zero_point = int(zero_point)
+
+            def forward(self, x):
+                # Quantize the float input with scale 1.0 and zero point 0 before the quantized convolution.
+                x_quantized = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
+                conv = self.conv(x_quantized)
+                return torch.dequantize(conv).contiguous()
+
+        ref_net = None
+        if not relu:
+            op_name = "quantized::conv2d"
+        else:
+            op_name = "quantized::conv2d_relu"
+
+        return quantized_conv2d(), ref_net, op_name
+
+    @pytest.mark.parametrize(
+        "params",
+        [
+            pytest.param(
+                {"weights_shape": [1, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
+                marks=pytest.mark.xfail(
+                    reason="Output channels equal to 1 creates output that fails to cast to contiguous."
+                ),
+            ),
+            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
+            {"weights_shape": [2, 3, 3, 3], "strides": 2, "pads": 0, "dilations": 1, "groups": 1},
+            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 1, "dilations": 1, "groups": 1},
+            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 2, "groups": 1},
+            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [0, 1], "dilations": 1, "groups": 1},
+            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [1, 0], "dilations": 1, "groups": 1},
+            {"weights_shape": [3, 1, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 3},
+        ],
+    )
+    @pytest.mark.parametrize("bias", [True, False])
+    @pytest.mark.parametrize("relu", [True, False])
+    @pytest.mark.parametrize("scale", [1, 0.3, 1.3])
+    @pytest.mark.parametrize("zero_point", [0, 1])
+    @pytest.mark.nightly
+    # @pytest.mark.precommit Test disabled due to sporadic issues
+    def test_quantized_conv2d(self, params, bias, relu, scale, zero_point, ie_device, precision, ir_version):
+        """Run the layer test for one combination of convolution parameters, bias, relu, scale and zero point."""
+        self._test(
+            *self.create_model(**params, bias=bias, relu=relu, scale=scale, zero_point=zero_point),
+            ie_device,
+            precision,
+            ir_version,
+            trace_model=True,
+            freeze_model=False
+        )
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_linear.py b/tests/layer_tests/pytorch_tests/test_quantized_linear.py
index 21f1353eeef..cc30a313d31 100644
--- a/tests/layer_tests/pytorch_tests/test_quantized_linear.py
+++ b/tests/layer_tests/pytorch_tests/test_quantized_linear.py
@@ -44,7 +44,7 @@ class TestQuantizedLinear(PytorchLayerTest):
     @pytest.mark.parametrize("zero_point", [0, 1])
     @pytest.mark.parametrize("trace", [True, False])
     @pytest.mark.nightly
-    @pytest.mark.precommit
+    # @pytest.mark.precommit Test disabled due to sporadic issues
     def test_quantized_linear(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
         input_shape = params.get("input_shape")
         weight_shape = params.get("weight_shape")