[PT FE] Add quantized::conv2d and quantized::conv2d_relu (#18651)
* Add quantized conv2d
* Fix schema
* Remove mark_output
* Remove tests from pre-commit
parent 2dfb537bcb
commit bc261424ef
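For context (not part of the diff): a minimal sketch of the TorchScript pattern the new translators handle, modelled on the layer test added below. The module configuration, shapes, and scale/zero-point values are illustrative only.

import torch

class QuantizedConvModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torch.ao.nn.quantized.Conv2d stores weight/bias as Conv2dPackedParamsBase
        self.conv = torch.ao.nn.quantized.Conv2d(3, 2, 3)
        self.conv.scale = 0.3      # becomes the output_scale input of quantized::conv2d
        self.conv.zero_point = 1   # becomes the output_zero_point input

    def forward(self, x):
        xq = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
        return torch.dequantize(self.conv(xq))

traced = torch.jit.trace(QuantizedConvModel(), torch.randn(2, 3, 25, 25))
print(traced.graph)  # contains a quantized::conv2d node fed by prim::GetAttr packed params

The traced graph is what the translator below unpacks: the packed-params input arrives as a prim::GetAttr FrameworkNode whose six inputs are weight, bias, stride, padding, dilation, and groups.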
src/frontends/pytorch/src/op/quantized_convnd.cpp (new file, 95 lines)
@@ -0,0 +1,95 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/relu.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

namespace {
Output<ov::Node> translate_quantized_convnd_base(const NodeContext& context) {
    auto input = context.get_input(0);
    auto packed_params_node =
        std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(context.get_input(1).get_node_shared_ptr());
    FRONT_END_OP_CONVERSION_CHECK(packed_params_node, "Packed params input node type is required to be FrameworkNode.");
    const auto& attrs = packed_params_node->get_attrs();
    FRONT_END_OP_CONVERSION_CHECK((attrs.find(PtFrameworkNode::op_type_key) != attrs.end()),
                                  "Packed params input node does not contain information about op type.");
    FRONT_END_OP_CONVERSION_CHECK((attrs.at(PtFrameworkNode::op_type_key) == "prim::GetAttr"),
                                  "Incorrect packed params input node operator type, expected prim::GetAttr.");
    auto packed_params = packed_params_node->inputs();

    FRONT_END_OP_CONVERSION_CHECK(packed_params.size() == 6,
                                  "Packed parameters for quantized conv should contain 6 items.");
    // Packed params: weight, bias, stride, padding, dilation, groups
    auto weight = packed_params[0].get_source_output();
    auto bias = packed_params[1].get_source_output();
    auto strides = std::dynamic_pointer_cast<v0::Constant>(packed_params[2].get_source_output().get_node_shared_ptr())
                       ->cast_vector<Strides::value_type>();
    auto pads = std::dynamic_pointer_cast<v0::Constant>(packed_params[3].get_source_output().get_node_shared_ptr())
                    ->cast_vector<CoordinateDiff::value_type>();
    auto dilations = std::dynamic_pointer_cast<v0::Constant>(packed_params[4].get_source_output().get_node_shared_ptr())
                         ->cast_vector<Strides::value_type>();
    int64_t groups = std::dynamic_pointer_cast<v0::Constant>(packed_params[5].get_source_output().get_node_shared_ptr())
                         ->cast_vector<int64_t>()[0];

    auto pad_type = ov::op::PadType::EXPLICIT;

    std::shared_ptr<ov::Node> conv;
    if (groups == 1) {
        conv = std::make_shared<v1::Convolution>(input, weight, strides, pads, pads, dilations, pad_type);
    } else {
        conv = std::make_shared<v1::GroupConvolution>(input,
                                                      reshape_kernel_for_group(context, weight, groups),
                                                      strides,
                                                      pads,
                                                      pads,
                                                      dilations,
                                                      pad_type);
    }
    auto bias_rank = bias.get_partial_shape().rank();
    if (bias_rank == 1) {
        bias = reshape_channelwise(context, bias, conv);
    }
    conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));

    return conv->output(0);
};
};  // namespace

OutputVector translate_quantized_convnd(const NodeContext& context) {
    // "quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float
    // output_scale, int output_zero_point) -> Tensor"
    num_inputs_check(context, 4, 4);
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
    return {quantize(context, translate_quantized_convnd_base(context), scale, zero_point, context.get_input(0))};
}

OutputVector translate_quantized_convnd_relu(const NodeContext& context) {
    // "quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
    // float output_scale, int output_zero_point) -> Tensor"
    num_inputs_check(context, 4, 4);
    auto scale = context.get_input(2);
    auto zero_point = context.get_input(3);
    auto conv = translate_quantized_convnd_base(context);
    auto relu = context.mark_node(std::make_shared<v0::Relu>(conv));
    return {quantize(context, relu->output(0), scale, zero_point, context.get_input(0))};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
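A rough sketch (not part of this commit) of how these translators are exercised through the PyTorch frontend API; the names follow the imports in the new layer test, the `traced` module comes from the sketch above, and the exact conversion flow may differ from what the test harness does internally.

from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder

fem = FrontEndManager()
fe = fem.load_by_framework("pytorch")        # frontend containing the new translators
decoder = TorchScriptPythonDecoder(traced)   # wraps the traced quantized model
ov_model = fe.convert(fe.load(decoder))      # quantized::conv2d -> translate_quantized_convnd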
@@ -37,7 +37,7 @@ OutputVector translate_quantized_linear(const NodeContext& context) {
     linear = context.mark_node(std::make_shared<ov::op::v1::Add>(linear, bias));
     auto scale = context.get_input(2);
     auto zero_point = context.get_input(3);
-    return {context.mark_output(quantize(context, linear, scale, zero_point, x))};
+    return {quantize(context, linear, scale, zero_point, x)};
 };

 }  // namespace op
@@ -158,6 +158,8 @@ OP_CONVERTER(translate_var_mean);
 OP_CONVERTER(translate_where);
 OP_CONVERTER(translate_zeros);
 OP_CONVERTER(translate_zeros_like);
+OP_CONVERTER(translate_quantized_convnd);
+OP_CONVERTER(translate_quantized_convnd_relu);
 OP_CONVERTER(translate_quantized_linear);

 }  // namespace op
@@ -419,6 +421,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"prim::requires_grad", op::return_false_scalar},
         {"prim::PythonOp", op::translate_pythonop},
         {"prim::type", op::skip_node},  // Used with prim::device, pass PtFrameworkNode.
+        {"quantized::conv2d", op::translate_quantized_convnd},
+        {"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
         {"quantized::linear", op::translate_quantized_linear},
         {"torchvision::deform_conv2d", op::translate_deform_conv},
         {"torchvision::nms", op::translate_nms},
tests/layer_tests/pytorch_tests/test_quantized_convnd.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import numpy as np
import torch

from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
from pytorch_layer_test_class import PytorchLayerTest


class TestQuantizedConv2D(PytorchLayerTest):
    def _prepare_input(self):
        return (np.random.randn(2, 3, 25, 25).astype(np.float32),)

    def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
        class quantized_conv2d(torch.nn.Module):
            def __init__(self):
                super(quantized_conv2d, self).__init__()
                if not relu:
                    conv_func = torch.ao.nn.quantized.Conv2d
                else:
                    conv_func = torch.ao.nn.intrinsic.quantized.ConvReLU2d
                self.conv = conv_func(
                    weights_shape[1] * groups,
                    weights_shape[0],
                    weights_shape[2:],
                    strides,
                    pads,
                    dilations,
                    groups,
                    bias,
                )
                if bias:
                    torch.nn.init.normal_(self.conv.bias())
                self.conv.scale = float(scale)
                self.conv.zero_point = int(zero_point)

            def forward(self, x):
                x_quantized = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
                conv = self.conv(x_quantized)
                return torch.dequantize(conv).contiguous()

        ref_net = None
        if not relu:
            op_name = "quantized::conv2d"
        else:
            op_name = "quantized::conv2d_relu"

        return quantized_conv2d(), ref_net, op_name

    @pytest.mark.parametrize(
        "params",
        [
            pytest.param(
                {"weights_shape": [1, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
                marks=pytest.mark.xfail(
                    reason="Output channels equal to 1 creates output that fails to cast to contiguous."
                ),
            ),
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 2, "pads": 0, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 1, "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 2, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [0, 1], "dilations": 1, "groups": 1},
            {"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [1, 0], "dilations": 1, "groups": 1},
            {"weights_shape": [3, 1, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 3},
        ],
    )
    @pytest.mark.parametrize("bias", [True, False])
    @pytest.mark.parametrize("relu", [True, False])
    @pytest.mark.parametrize("scale", [1, 0.3, 1.3])
    @pytest.mark.parametrize("zero_point", [0, 1])
    @pytest.mark.nightly
    # @pytest.mark.precommit Test disabled due to sporadic issues
    def test_quantized_conv2d(self, params, bias, relu, scale, zero_point, ie_device, precision, ir_version):
        self._test(
            *self.create_model(**params, bias=bias, relu=relu, scale=scale, zero_point=zero_point),
            ie_device,
            precision,
            ir_version,
            trace_model=True,
            freeze_model=False
        )
@@ -44,7 +44,7 @@ class TestQuantizedLinear(PytorchLayerTest):
     @pytest.mark.parametrize("zero_point", [0, 1])
     @pytest.mark.parametrize("trace", [True, False])
     @pytest.mark.nightly
-    @pytest.mark.precommit
+    # @pytest.mark.precommit Test disabled due to sporadic issues
     def test_quantized_linear(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
         input_shape = params.get("input_shape")
         weight_shape = params.get("weight_shape")