[PT FE] Add quantized::conv2d and quantized::conv2d_relu (#18651)
* Add quantized conv2d * Fix schema * Remove mark_output * Remove tests from pre-commit
This commit is contained in:
parent
2dfb537bcb
commit
bc261424ef
95
src/frontends/pytorch/src/op/quantized_convnd.cpp
Normal file
95
src/frontends/pytorch/src/op/quantized_convnd.cpp
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||||
|
#include "openvino/op/add.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/convolution.hpp"
|
||||||
|
#include "openvino/op/group_conv.hpp"
|
||||||
|
#include "openvino/op/relu.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "utils_quantize.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace pytorch {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
using namespace ov::op;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
Output<ov::Node> translate_quantized_convnd_base(const NodeContext& context) {
|
||||||
|
auto input = context.get_input(0);
|
||||||
|
auto packed_params_node =
|
||||||
|
std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(context.get_input(1).get_node_shared_ptr());
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(packed_params_node, "Packed params input node type is required to be FrameworkNode.");
|
||||||
|
const auto& attrs = packed_params_node->get_attrs();
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK((attrs.find(PtFrameworkNode::op_type_key) != attrs.end()),
|
||||||
|
"Packed params input node does not contain information about op type.");
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK((attrs.at(PtFrameworkNode::op_type_key) == "prim::GetAttr"),
|
||||||
|
"Incorrect packed params input node operator type, expected prim::GetAttr.");
|
||||||
|
auto packed_params = packed_params_node->inputs();
|
||||||
|
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(packed_params.size() == 6,
|
||||||
|
"Packed parameters for quantized conv should contain 6 items.");
|
||||||
|
// Packed params: weight, bias, stride, padding, dilation, groups
|
||||||
|
auto weight = packed_params[0].get_source_output();
|
||||||
|
auto bias = packed_params[1].get_source_output();
|
||||||
|
auto strides = std::dynamic_pointer_cast<v0::Constant>(packed_params[2].get_source_output().get_node_shared_ptr())
|
||||||
|
->cast_vector<Strides::value_type>();
|
||||||
|
auto pads = std::dynamic_pointer_cast<v0::Constant>(packed_params[3].get_source_output().get_node_shared_ptr())
|
||||||
|
->cast_vector<CoordinateDiff::value_type>();
|
||||||
|
auto dilations = std::dynamic_pointer_cast<v0::Constant>(packed_params[4].get_source_output().get_node_shared_ptr())
|
||||||
|
->cast_vector<Strides::value_type>();
|
||||||
|
int64_t groups = std::dynamic_pointer_cast<v0::Constant>(packed_params[5].get_source_output().get_node_shared_ptr())
|
||||||
|
->cast_vector<int64_t>()[0];
|
||||||
|
|
||||||
|
auto pad_type = ov::op::PadType::EXPLICIT;
|
||||||
|
|
||||||
|
std::shared_ptr<ov::Node> conv;
|
||||||
|
if (groups == 1) {
|
||||||
|
conv = std::make_shared<v1::Convolution>(input, weight, strides, pads, pads, dilations, pad_type);
|
||||||
|
} else {
|
||||||
|
conv = std::make_shared<v1::GroupConvolution>(input,
|
||||||
|
reshape_kernel_for_group(context, weight, groups),
|
||||||
|
strides,
|
||||||
|
pads,
|
||||||
|
pads,
|
||||||
|
dilations,
|
||||||
|
pad_type);
|
||||||
|
}
|
||||||
|
auto bias_rank = bias.get_partial_shape().rank();
|
||||||
|
if (bias_rank == 1) {
|
||||||
|
bias = reshape_channelwise(context, bias, conv);
|
||||||
|
}
|
||||||
|
conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
|
||||||
|
|
||||||
|
return conv->output(0);
|
||||||
|
};
|
||||||
|
}; // namespace
|
||||||
|
|
||||||
|
OutputVector translate_quantized_convnd(const NodeContext& context) {
|
||||||
|
// "quantized::conv2d.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float
|
||||||
|
// output_scale, int output_zero_point) -> Tensor"
|
||||||
|
num_inputs_check(context, 4, 4);
|
||||||
|
auto scale = context.get_input(2);
|
||||||
|
auto zero_point = context.get_input(3);
|
||||||
|
return {quantize(context, translate_quantized_convnd_base(context), scale, zero_point, context.get_input(0))};
|
||||||
|
}
|
||||||
|
|
||||||
|
OutputVector translate_quantized_convnd_relu(const NodeContext& context) {
|
||||||
|
// "quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight,
|
||||||
|
// float output_scale, int output_zero_point) -> Tensor"
|
||||||
|
num_inputs_check(context, 4, 4);
|
||||||
|
auto scale = context.get_input(2);
|
||||||
|
auto zero_point = context.get_input(3);
|
||||||
|
auto conv = translate_quantized_convnd_base(context);
|
||||||
|
auto relu = context.mark_node(std::make_shared<v0::Relu>(conv));
|
||||||
|
return {quantize(context, relu->output(0), scale, zero_point, context.get_input(0))};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
} // namespace pytorch
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -37,7 +37,7 @@ OutputVector translate_quantized_linear(const NodeContext& context) {
|
|||||||
linear = context.mark_node(std::make_shared<ov::op::v1::Add>(linear, bias));
|
linear = context.mark_node(std::make_shared<ov::op::v1::Add>(linear, bias));
|
||||||
auto scale = context.get_input(2);
|
auto scale = context.get_input(2);
|
||||||
auto zero_point = context.get_input(3);
|
auto zero_point = context.get_input(3);
|
||||||
return {context.mark_output(quantize(context, linear, scale, zero_point, x))};
|
return {quantize(context, linear, scale, zero_point, x)};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -158,6 +158,8 @@ OP_CONVERTER(translate_var_mean);
|
|||||||
OP_CONVERTER(translate_where);
|
OP_CONVERTER(translate_where);
|
||||||
OP_CONVERTER(translate_zeros);
|
OP_CONVERTER(translate_zeros);
|
||||||
OP_CONVERTER(translate_zeros_like);
|
OP_CONVERTER(translate_zeros_like);
|
||||||
|
OP_CONVERTER(translate_quantized_convnd);
|
||||||
|
OP_CONVERTER(translate_quantized_convnd_relu);
|
||||||
OP_CONVERTER(translate_quantized_linear);
|
OP_CONVERTER(translate_quantized_linear);
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
@ -419,6 +421,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"prim::requires_grad", op::return_false_scalar},
|
{"prim::requires_grad", op::return_false_scalar},
|
||||||
{"prim::PythonOp", op::translate_pythonop},
|
{"prim::PythonOp", op::translate_pythonop},
|
||||||
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
|
{"prim::type", op::skip_node}, // Used with prim::device, pass PtFrameworkNode.
|
||||||
|
{"quantized::conv2d", op::translate_quantized_convnd},
|
||||||
|
{"quantized::conv2d_relu", op::translate_quantized_convnd_relu},
|
||||||
{"quantized::linear", op::translate_quantized_linear},
|
{"quantized::linear", op::translate_quantized_linear},
|
||||||
{"torchvision::deform_conv2d", op::translate_deform_conv},
|
{"torchvision::deform_conv2d", op::translate_deform_conv},
|
||||||
{"torchvision::nms", op::translate_nms},
|
{"torchvision::nms", op::translate_nms},
|
||||||
|
85
tests/layer_tests/pytorch_tests/test_quantized_convnd.py
Normal file
85
tests/layer_tests/pytorch_tests/test_quantized_convnd.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from openvino.frontend import FrontEndManager
|
||||||
|
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||||
|
from pytorch_layer_test_class import PytorchLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuantizedConv2D(PytorchLayerTest):
|
||||||
|
def _prepare_input(self):
|
||||||
|
return (np.random.randn(2, 3, 25, 25).astype(np.float32),)
|
||||||
|
|
||||||
|
def create_model(self, weights_shape, strides, pads, dilations, groups, bias, relu, scale, zero_point):
|
||||||
|
class quantized_conv2d(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(quantized_conv2d, self).__init__()
|
||||||
|
if not relu:
|
||||||
|
conv_func = torch.ao.nn.quantized.Conv2d
|
||||||
|
else:
|
||||||
|
conv_func = torch.ao.nn.intrinsic.quantized.ConvReLU2d
|
||||||
|
self.conv = conv_func(
|
||||||
|
weights_shape[1] * groups,
|
||||||
|
weights_shape[0],
|
||||||
|
weights_shape[2:],
|
||||||
|
strides,
|
||||||
|
pads,
|
||||||
|
dilations,
|
||||||
|
groups,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
if bias:
|
||||||
|
torch.nn.init.normal_(self.conv.bias())
|
||||||
|
self.conv.scale = float(scale)
|
||||||
|
self.conv.zero_point = int(zero_point)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_quantized = torch.quantize_per_tensor(x, 1.0, 0, torch.quint8)
|
||||||
|
conv = self.conv(x_quantized)
|
||||||
|
return torch.dequantize(conv).contiguous()
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
if not relu:
|
||||||
|
op_name = "quantized::conv2d"
|
||||||
|
else:
|
||||||
|
op_name = "quantized::conv2d_relu"
|
||||||
|
|
||||||
|
return quantized_conv2d(), ref_net, op_name
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"params",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
{"weights_shape": [1, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
|
||||||
|
marks=pytest.mark.xfail(
|
||||||
|
reason="Output channels equal to 1 creates output that fails to cast to contiguous."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 1},
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 2, "pads": 0, "dilations": 1, "groups": 1},
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 1, "dilations": 1, "groups": 1},
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": 0, "dilations": 2, "groups": 1},
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [0, 1], "dilations": 1, "groups": 1},
|
||||||
|
{"weights_shape": [2, 3, 3, 3], "strides": 1, "pads": [1, 0], "dilations": 1, "groups": 1},
|
||||||
|
{"weights_shape": [3, 1, 3, 3], "strides": 1, "pads": 0, "dilations": 1, "groups": 3},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("bias", [True, False])
|
||||||
|
@pytest.mark.parametrize("relu", [True, False])
|
||||||
|
@pytest.mark.parametrize("scale", [1, 0.3, 1.3])
|
||||||
|
@pytest.mark.parametrize("zero_point", [0, 1])
|
||||||
|
@pytest.mark.nightly
|
||||||
|
# @pytest.mark.precommit Test disabled due to sporadic issues
|
||||||
|
def test_quantized_conv2d(self, params, bias, relu, scale, zero_point, ie_device, precision, ir_version):
|
||||||
|
self._test(
|
||||||
|
*self.create_model(**params, bias=bias, relu=relu, scale=scale, zero_point=zero_point),
|
||||||
|
ie_device,
|
||||||
|
precision,
|
||||||
|
ir_version,
|
||||||
|
trace_model=True,
|
||||||
|
freeze_model=False
|
||||||
|
)
|
@ -44,7 +44,7 @@ class TestQuantizedLinear(PytorchLayerTest):
|
|||||||
@pytest.mark.parametrize("zero_point", [0, 1])
|
@pytest.mark.parametrize("zero_point", [0, 1])
|
||||||
@pytest.mark.parametrize("trace", [True, False])
|
@pytest.mark.parametrize("trace", [True, False])
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
@pytest.mark.precommit
|
# @pytest.mark.precommit Test disabled due to sporadic issues
|
||||||
def test_quantized_linear(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
|
def test_quantized_linear(self, params, scale, zero_point, trace, ie_device, precision, ir_version):
|
||||||
input_shape = params.get("input_shape")
|
input_shape = params.get("input_shape")
|
||||||
weight_shape = params.get("weight_shape")
|
weight_shape = params.get("weight_shape")
|
||||||
|
Loading…
Reference in New Issue
Block a user