[PT FE] Add quantized::conv2d and quantized::conv2d_relu (#18651)

* Add quantized conv2d

* Fix schema

* Remove mark_output

* Remove tests from pre-commit
Mateusz Mikolajczyk 2023-07-20 17:35:11 +02:00 committed by GitHub
parent 2dfb537bcb
commit bc261424ef
5 changed files with 186 additions and 2 deletions
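
For context (not part of the diff): a minimal sketch of the PyTorch modules that trace to the two newly supported ops. The torch.ao module paths below are the same ones the new layer test uses; the printed graphs should contain the quantized::conv2d / quantized::conv2d_relu calls this commit translates.

import torch

# Quantized conv module; its traced graph should contain quantized::conv2d.
conv = torch.ao.nn.quantized.Conv2d(3, 2, kernel_size=3)
# Fused conv+ReLU variant; traces to quantized::conv2d_relu.
conv_relu = torch.ao.nn.intrinsic.quantized.ConvReLU2d(3, 2, kernel_size=3)

x = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=1.0, zero_point=0, dtype=torch.quint8)
print(torch.jit.trace(conv, x).graph)       # should show ... = quantized::conv2d(...)
print(torch.jit.trace(conv_relu, x).graph)  # should show ... = quantized::conv2d_relu(...)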

View File

@@ -0,0 +1,95 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/relu.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

namespace {
Output<ov::Node> translate_quantized_convnd_base(const NodeContext& context) {
    auto input = context.get_input(0);
    // The packed parameters arrive as a prim::GetAttr FrameworkNode whose inputs
    // carry the unpacked weight, bias, and convolution attributes.
    auto packed_params_node =
        std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(context.get_input(1).get_node_shared_ptr());
    FRONT_END_OP_CONVERSION_CHECK(packed_params_node, "Packed params input node type is required to be FrameworkNode.");
    const auto& attrs = packed_params_node->get_attrs();
    FRONT_END_OP_CONVERSION_CHECK((attrs.find(PtFrameworkNode::op_type_key) != attrs.end()),
                                  "Packed params input node does not contain information about op type.");
    FRONT_END_OP_CONVERSION_CHECK((attrs.at(PtFrameworkNode::op_type_key) == "prim::GetAttr"),
                                  "Incorrect packed params input node operator type, expected prim::GetAttr.");
    auto packed_params = packed_params_node->inputs();
    FRONT_END_OP_CONVERSION_CHECK(packed_params.size() == 6,
                                  "Packed parameters for quantized conv should contain 6 items.");
    // Packed params: weight, bias, stride, padding, dilation, groups
    auto weight = packed_params[0].get_source_output();
    auto bias = packed_params[1].get_source_output();
    auto strides = std::dynamic_pointer_cast<v0::Constant>(packed_params[2].get_source_output().get_node_shared_ptr())
                       ->cast_vector<Strides::value_type>();
    auto pads = std::dynamic_pointer_cast<v0::Constant>(packed_params[3].get_source_output().get_node_shared_ptr())
                    ->cast_vector<CoordinateDiff::value_type>();
    auto dilations = std::dynamic_pointer_cast<v0::Constant>(packed_params[4].get_source_output().get_node_shared_ptr())
                         ->cast_vector<Strides::value_type>();
    int64_t groups = std::dynamic_pointer_cast<v0::Constant>(packed_params[5].get_source_output().get_node_shared_ptr())
                         ->cast_vector<int64_t>()[0];

    auto pad_type = ov::op::PadType::EXPLICIT;
    std::shared_ptr<ov::Node> conv;
    if (groups == 1) {
        conv = std::make_shared<v1::Convolution>(input, weight, strides, pads, pads, dilations, pad_type);
    } else {
        conv = std::make_shared<v1::GroupConvolution>(input,
                                                      reshape_kernel_for_group(context, weight, groups),
                                                      strides,
                                                      pads,
                                                      pads,
                                                      dilations,
                                                      pad_type);
    }
    // A rank-1 bias has to be reshaped so it broadcasts over the channel dimension.
    auto bias_rank = bias.get_partial_shape().rank();
    if (bias_rank == 1) {
        bias = reshape_channelwise(context, bias, conv);
    }
    conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
    return conv->output(0);
}
} // namespace

OutputVector translate_quantized_convnd(const NodeContext& context) {
    // "quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float
    // output_scale, int output_zero_point) -> Tensor"
    num_inputs_check(context, 4, 4);
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
    // Requantize the convolution result to the requested output scale/zero point.
    return {quantize(context, translate_quantized_convnd_base(context), scale, zero_point, context.get_input(0))};
}

OutputVector translate_quantized_convnd_relu(const NodeContext& context) {
    // "quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
    // float output_scale, int output_zero_point) -> Tensor"
    num_inputs_check(context, 4, 4);
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
    auto conv = translate_quantized_convnd_base(context);
    auto relu = context.mark_node(std::make_shared<v0::Relu>(conv));
    return {quantize(context, relu->output(0), scale, zero_point, context.get_input(0))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
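
A quick end-to-end sketch for exercising the new translator (not part of the commit; assumes an OpenVINO release where openvino.convert_model accepts a traced TorchScript module):

import torch
import openvino as ov

class QConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Fused quantized conv+ReLU; traces to quantized::conv2d_relu.
        self.conv = torch.ao.nn.intrinsic.quantized.ConvReLU2d(3, 2, kernel_size=3)

    def forward(self, x):
        xq = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
        return torch.dequantize(self.conv(xq))

example = torch.randn(1, 3, 8, 8)
traced = torch.jit.trace(QConvRelu(), example)
ov_model = ov.convert_model(traced, example_input=example)  # hits translate_quantized_convnd_relu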

View File

@@ -37,7 +37,7 @@ OutputVector translate_quantized_linear(const NodeContext& context) {
    linear = context.mark_node(std::make_shared<ov::op::v1::Add>(linear, bias));
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
-   return {context.mark_output(quantize(context, linear, scale, zero_point, x))};
+   return {quantize(context, linear, scale, zero_point, x)};
};
} // namespace op

View File

@@ -158,6 +158,8 @@ OP_CONVERTER(translate_var_mean);
OP_CONVERTER(translate_where);
OP_CONVERTER(translate_zeros);
OP_CONVERTER(translate_zeros_like);
+OP_CONVERTER(translate_quantized_convnd);
+OP_CONVERTER(translate_quantized_convnd_relu);
OP_CONVERTER(translate_quantized_linear);
} // namespace op
@@ -419,6 +421,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"prim::requires_grad", op::return_false_scalar},
        {"prim::PythonOp", op::translate_pythonop},
        {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+       {"quantized::conv2d", op::translate_quantized_convnd},
+       {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
        {"quantized::linear", op::translate_quantized_linear},
        {"torchvision::deform_conv2d", op::translate_deform_conv},
        {"torchvision::nms", op::translate_nms},

View File

@@ -0,0 +1,85 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import torch

from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from pytorch_layer_test_class import PytorchLayerTest


class TestQuantizedConv2D(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(2, 3, 25, 25).astype(np.float32),)

    def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
        class quantized_conv2d(torch.nn.Module):
            def __init__(self):
                super(quantized_conv2d, self).__init__()
                # Pick the plain quantized conv or the fused conv+ReLU module.
                if not relu:
                    conv_func = torch.ao.nn.quantized.Conv2d
                else:
                    conv_func = torch.ao.nn.intrinsic.quantized.ConvReLU2d
                self.conv = conv_func(
                    weights_shape[1] * groups,
                    weights_shape[0],
                    weights_shape[2:],
                    strides,
                    pads,
                    dilations,
                    groups,
                    bias,
                )
                if bias:
                    torch.nn.init.normal_(self.conv.bias())
                # Output quantization params are plain attributes on quantized modules.
                self.conv.scale = float(scale)
                self.conv.zero_point = int(zero_point)

            def forward(self, x):
                x_quantized = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
                conv = self.conv(x_quantized)
                return torch.dequantize(conv).contiguous()

        ref_net = None
        if not relu:
            op_name = "quantized::conv2d"
        else:
            op_name = "quantized::conv2d_relu"
        return quantized_conv2d(), ref_net, op_name

    @pytest.mark.parametrize(
        "params",
        [
            pytest.param(
                {"weights_shape": [1, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
                marks=pytest.mark.xfail(
                    reason="Output channels equal to 1 creates output that fails to cast to contiguous."
                ),
            ),
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 2, "pads": 0, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 1, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 2, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [0, 1], "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [1, 0], "dilations": 1, "groups": 1},
            {"weights_shape": [3, 1, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 3},
        ],
    )
    @pytest.mark.parametrize("bias", [True, False])
    @pytest.mark.parametrize("relu", [True, False])
    @pytest.mark.parametrize("scale", [1, 0.3, 1.3])
    @pytest.mark.parametrize("zero_point", [0, 1])
    @pytest.mark.nightly
    # @pytest.mark.precommit Test disabled due to sporadic issues
    def test_quantized_conv2d(self, params, bias, relu, scale, zero_point, ie_device, precision, ir_version):
        self._test(
            *self.create_model(**params, bias=bias, relu=relu, scale=scale, zero_point=zero_point),
            ie_device,
            precision,
            ir_version,
            trace_model=True,
            freeze_model=False,
        )

View File

@@ -44,7 +44,7 @@ class TestQuantizedLinear(PytorchLayerTest):
    @pytest.mark.parametrize("zero_point", [0, 1])
    @pytest.mark.parametrize("trace", [True, False])
    @pytest.mark.nightly
-   @pytest.mark.precommit
+   # @pytest.mark.precommit Test disabled due to sporadic issues
    def test_quantized_linear(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
        input_shape = params.get("input_shape")
        weight_shape = params.get("weight_shape")