[PT FE] Add helper for regular ops on quantized values (#18692)

* Add helper for regular ops

* Update op_table.cpp

* Support ops with more then 1 output

* Uncheck ops that return integer/boolean type
This commit is contained in:
Maxim Vafin 2023-07-22 16:47:44 +02:00 committed by GitHub
parent 7fc1fd155d
commit f96f021920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 157 additions and 134 deletions

View File

@ -6,6 +6,7 @@
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
#include "utils_quantize.hpp"
namespace ov {
namespace frontend {
@ -190,9 +191,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
{"aten::acosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
{"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acosh>>},
{"aten::adaptive_avg_pool2d", op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>},
{"aten::adaptive_avg_pool3d", op::translate_adaptive_avg_pool3d},
{"aten::adaptive_max_pool2d", op::translate_adaptive_max_pool2d},
{"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>>},
{"aten::adaptive_avg_pool3d", op::quantizable_op<op::translate_adaptive_avg_pool3d>},
{"aten::adaptive_max_pool2d", op::quantizable_op<op::translate_adaptive_max_pool2d>},
{"aten::add", op::translate_add},
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
@ -211,9 +212,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
{"aten::atanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
{"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
{"aten::avg_pool1d", op::translate_avg_poolnd},
{"aten::avg_pool2d", op::translate_avg_poolnd},
{"aten::avg_pool3d", op::translate_avg_poolnd},
{"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
{"aten::baddbmm", op::translate_addmm},
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bitwise_not", op::translate_bitwise_not},
@ -262,7 +263,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
{"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
{"aten::fill_", op::inplace_op<op::translate_fill_>},
{"aten::flatten", op::translate_flatten},
{"aten::flatten", op::quantizable_op<op::translate_flatten>},
{"aten::flip", op::translate_flip},
{"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
@ -278,11 +279,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::grid_sampler", op::translate_grid_sampler},
{"aten::group_norm", op::translate_group_norm},
{"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
{"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
{"aten::hardtanh", op::translate_hardtanh},
{"aten::hardtanh_", op::inplace_op<op::translate_hardtanh>},
{"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
{"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
{"aten::hardswish_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>>},
{"aten::hardtanh", op::quantizable_op<op::translate_hardtanh>},
{"aten::hardtanh_", op::inplace_op<op::quantizable_op<op::translate_hardtanh>>},
{"aten::im2col", op::translate_im2col},
{"aten::index_put_", op::inplace_op<op::translate_index_put_>},
{"aten::index_select", op::translate_index_select},
@ -314,10 +315,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::max", op::translate_max},
{"aten::max_pool1d", op::translate_max_poolnd},
{"aten::max_pool2d", op::translate_max_poolnd},
{"aten::max_pool3d", op::translate_max_poolnd},
{"aten::mean", op::translate_mean},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
@ -364,7 +365,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
{"aten::scatter", op::translate_scatter},
{"aten::scatter_", op::inplace_op<op::translate_scatter>},
{"aten::select", op::translate_select},
{"aten::select", op::quantizable_op<op::translate_select>},
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},
{"aten::sigmoid", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
@ -377,12 +378,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::sinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>},
{"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sinh>>},
{"aten::size", op::translate_size},
{"aten::slice", op::translate_slice},
{"aten::slice", op::quantizable_op<op::translate_slice>},
{"aten::softmax", op::translate_softmax},
{"aten::sort", op::translate_sort},
{"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
{"aten::square", op::translate_square},
{"aten::squeeze", op::translate_squeeze},
{"aten::squeeze", op::quantizable_op<op::translate_squeeze>},
{"aten::sub", op::translate_sub},
{"aten::sub_", op::inplace_op<op::translate_sub>},
{"aten::sum", op::translate_sum},
@ -395,7 +396,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::tensor", op::translate_as_tensor},
{"aten::to", op::translate_to},
{"aten::topk", op::translate_topk},
{"aten::transpose", op::translate_transpose},
{"aten::transpose", op::quantizable_op<op::translate_transpose>},
{"aten::tril", op::translate_tril},
{"aten::tril_", op::inplace_op<op::translate_tril>},
{"aten::triu", op::translate_triu},
@ -404,8 +405,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
op::translate_1to1_match_2_inputs<opset10::ConvertLike>}, // TODO: overflow semantics is different
{"aten::unflatten", op::translate_unflatten},
{"aten::unfold", op::translate_unfold},
{"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
{"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
{"aten::unsqueeze", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
{"aten::unsqueeze_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>>},
{"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
{"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
{"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
@ -417,7 +418,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::upsample_trilinear3d", op::translate_upsample_trilinear3d},
{"aten::var", op::translate_var},
{"aten::var_mean", op::translate_var_mean},
{"aten::view", op::translate_reshape},
{"aten::view", op::quantizable_op<op::translate_reshape>},
{"aten::where", op::translate_where},
{"aten::zero_", op::inplace_op<op::translate_zeros_like>},
{"aten::zeros", op::translate_zeros},

View File

@ -20,13 +20,13 @@ namespace pytorch {
using namespace ov::op;
ov::Output<ov::Node> quantize(const NodeContext& context,
std::shared_ptr<ov::Node> input,
std::shared_ptr<ov::Node> scale,
std::shared_ptr<ov::Node> zero_point,
std::shared_ptr<ov::Node> axis,
ov::element::Type dtype,
QuantizedPtNodeType quantization_type) {
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
const auto scale_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
@ -63,7 +63,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
zero_point_convert,
dtype));
} else if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) {
FRONT_END_OP_CONVERSION_CHECK(axis, "Axis cannot be null for quantize_per_channel.");
FRONT_END_OP_CONVERSION_CHECK(axis.get_node(), "Axis cannot be null for quantize_per_channel.");
const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
@ -119,46 +119,19 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
}
// ========================================================
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::element::Type dtype,
QuantizedPtNodeType quantization_type) {
return quantize(context,
input.get_node_shared_ptr(),
scale.get_node_shared_ptr(),
zero_point.get_node_shared_ptr(),
nullptr,
dtype,
quantization_type);
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
element::Type dtype,
QuantizedPtNodeType quantization_type) {
return quantize(context, input, scale, zero_point, Output<Node>(), dtype, quantization_type);
}
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::Output<ov::Node> axis,
ov::element::Type dtype,
QuantizedPtNodeType quantization_type) {
return quantize(context,
input.get_node_shared_ptr(),
scale.get_node_shared_ptr(),
zero_point.get_node_shared_ptr(),
axis.get_node_shared_ptr(),
dtype,
quantization_type);
}
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> quantized_node) {
std::shared_ptr<QuantizedPtNode> quantized_pt_node;
if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
Output<Node> quantize(const NodeContext& context, const Output<Node>& input, const Output<Node>& quantized_node) {
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
return quantize(context,
input.get_node_shared_ptr(),
input,
quantized_pt_node->get_scale(),
quantized_pt_node->get_zero_point(),
quantized_pt_node->get_axis(),
@ -168,17 +141,16 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::Output<ov::Node> quantized_node) {
std::shared_ptr<QuantizedPtNode> quantized_pt_node;
if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& quantized_node) {
if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
return quantize(context,
input.get_node_shared_ptr(),
scale.get_node_shared_ptr(),
zero_point.get_node_shared_ptr(),
input,
scale,
zero_point,
quantized_pt_node->get_axis(),
quantized_pt_node->get_dtype(),
quantized_pt_node->get_type());
@ -186,7 +158,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
}
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
if (!quant_node) {
return nullptr;
@ -198,7 +170,7 @@ std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
return quant_node;
}
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node, const std::string& type) {
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type) {
auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
if (!quant_node) {
return nullptr;

View File

@ -15,22 +15,19 @@ enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };
class QuantizedPtNode : public PtFrameworkNode {
public:
OPENVINO_OP("QuantizedPtNode", "util", ::ov::frontend::pytorch::PtFrameworkNode);
OPENVINO_OP("QuantizedPtNode", "util", PtFrameworkNode);
static constexpr const char* quantized_node_type_key = "QuantizedPtTypeName";
static constexpr const char* quantize_per_tensor = "quantize_per_tensor";
static constexpr const char* quantize_per_channel = "quantize_per_channel";
QuantizedPtNode(const QuantizedPtNodeType type,
const NodeContext& context,
const ov::Output<ov::Node> input,
const ov::Output<ov::Node> scale,
const ov::Output<ov::Node> zero_point,
const Output<Node> input,
const Output<Node> scale,
const Output<Node> zero_point,
element::Type& dtype)
: PtFrameworkNode(context.get_decoder(), {input}, 1, false),
type(type),
scale(scale.get_node_shared_ptr()),
zero_point(zero_point.get_node_shared_ptr()),
axis(nullptr) {
: PtFrameworkNode(context.get_decoder(), {input, scale, zero_point}, 1, false),
type(type) {
ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
attrs[quantized_node_type_key] = quantize_per_tensor;
@ -45,16 +42,13 @@ public:
QuantizedPtNode(const QuantizedPtNodeType type,
const NodeContext& context,
const ov::Output<ov::Node> input,
const ov::Output<ov::Node> scale,
const ov::Output<ov::Node> zero_point,
const ov::Output<ov::Node> axis,
const Output<Node> input,
const Output<Node> scale,
const Output<Node> zero_point,
const Output<Node> axis,
element::Type& dtype)
: PtFrameworkNode(context.get_decoder(), {input}, 1, false),
type(type),
scale(scale.get_node_shared_ptr()),
zero_point(zero_point.get_node_shared_ptr()),
axis(axis.get_node_shared_ptr()) {
: PtFrameworkNode(context.get_decoder(), {input, scale, zero_point, axis}, 1, false),
type(type) {
ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
attrs[quantized_node_type_key] = quantize_per_tensor;
@ -67,66 +61,86 @@ public:
this->dtype = dtype;
}
const std::shared_ptr<ov::Node> get_scale() {
return scale;
const Output<Node> get_scale() const {
return input_value(1);
}
const std::shared_ptr<ov::Node> get_zero_point() {
return zero_point;
const Output<Node> get_zero_point() const {
return input_value(2);
}
const std::shared_ptr<ov::Node> get_axis() {
return axis;
const Output<Node> get_axis() const {
if (inputs().size() < 4) {
return Output<Node>();
}
return input_value(3);
}
const QuantizedPtNodeType get_type() {
const QuantizedPtNodeType get_type() const {
return type;
}
const element::Type get_dtype() {
const element::Type get_dtype() const {
return dtype;
}
private:
const QuantizedPtNodeType type;
std::shared_ptr<ov::Node> scale;
std::shared_ptr<ov::Node> zero_point;
std::shared_ptr<ov::Node> axis;
element::Type dtype;
};
/**
* Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode.
*/
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::element::Type dtype,
QuantizedPtNodeType quantization_type);
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::Output<ov::Node> axis,
ov::element::Type dtype,
QuantizedPtNodeType quantization_type);
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
element::Type dtype,
QuantizedPtNodeType quantization_type);
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& axis,
element::Type dtype,
QuantizedPtNodeType quantization_type);
/**
* Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode.
*/
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> quantized_node);
Output<Node> quantize(const NodeContext& context, Output<Node> input, Output<Node> quantized_node);
/**
* Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to
* the new QuantizedPtNode.
*/
ov::Output<ov::Node> quantize(const NodeContext& context,
ov::Output<ov::Node> input,
ov::Output<ov::Node> scale,
ov::Output<ov::Node> zero_point,
ov::Output<ov::Node> quantized_node);
Output<Node> quantize(const NodeContext& context,
const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const Output<Node>& quantized_node);
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node);
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type);
namespace op {
/**
* Modifies conversion function to support quantized case. When input is quantized it is processed as quantized op.
*/
template <OutputVector (*T)(const NodeContext&), size_t in_idx = 0, size_t out_idx = 0>
OutputVector quantizable_op(const NodeContext& context) {
auto translation_res = T(context);
FRONT_END_OP_CONVERSION_CHECK(translation_res.size() > out_idx, "Not enough outputs to apply quantization.");
if (const auto quantized_pt_node = cast_quantized_fw_node(context.get_input(in_idx).get_node_shared_ptr())) {
return {quantize(context,
translation_res[out_idx],
quantized_pt_node->get_scale(),
quantized_pt_node->get_zero_point(),
quantized_pt_node->get_axis(),
quantized_pt_node->get_dtype(),
quantized_pt_node->get_type())};
}
return translation_res;
}
} // namespace op
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node);
std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node, const std::string& type);
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -6,6 +6,7 @@ import torch
import numpy as np
from pytorch_layer_test_class import PytorchLayerTest
class TestQuantizedLinear(PytorchLayerTest):
def _prepare_input(self, input_shape=(2, 2)):
return (np.random.randn(*input_shape).astype(np.float32),)
@ -16,10 +17,12 @@ class TestQuantizedLinear(PytorchLayerTest):
def __init__(self, weight_shape, is_bias, scale, zero_point):
super(aten_quantized_linear, self).__init__()
if is_bias:
self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], True)
self.linear = torch.ao.nn.quantized.Linear(
weight_shape[-1], weight_shape[0], True)
torch.nn.init.normal_(self.linear.bias())
else:
self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], False)
self.linear = torch.ao.nn.quantized.Linear(
weight_shape[-1], weight_shape[0], False)
self.linear.scale = float(scale)
self.linear.zero_point = int(zero_point)
@ -31,6 +34,31 @@ class TestQuantizedLinear(PytorchLayerTest):
return aten_quantized_linear(weight_shape, is_bias, scale, zero_point), ref_net, "quantized::linear"
def create_hardtanh_model(self, weight_shape, is_bias, scale, zero_point, inplace):
class aten_quantized_linear(torch.nn.Module):
def __init__(self, weight_shape, is_bias, scale, zero_point, inplace):
super(aten_quantized_linear, self).__init__()
self.hardtanh = torch.nn.Hardtanh(inplace=inplace)
if is_bias:
self.linear = torch.ao.nn.quantized.Linear(
weight_shape[-1], weight_shape[0], True)
torch.nn.init.normal_(self.linear.bias())
else:
self.linear = torch.ao.nn.quantized.Linear(
weight_shape[-1], weight_shape[0], False)
self.linear.scale = float(scale)
self.linear.zero_point = int(zero_point)
def forward(self, inp):
inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
inp_q = self.hardtanh(inp_q)
return torch.dequantize(self.linear(inp_q))
ref_net = None
return aten_quantized_linear(weight_shape, is_bias, scale, zero_point, inplace), ref_net, ["quantized::linear", "aten::hardtanh_" if inplace else "aten::hardtanh"]
@pytest.mark.parametrize("params", [
{'input_shape': [3, 9], 'weight_shape': [10, 9]},
{'input_shape': [3, 9], 'weight_shape': [9]},
@ -51,3 +79,11 @@ class TestQuantizedLinear(PytorchLayerTest):
bias = params.get("bias", False)
self._test(*self.create_model(weight_shape, bias, scale, zero_point), ie_device, precision, ir_version,
kwargs_to_prepare_input={"input_shape": input_shape}, trace_model=trace, freeze_model=False)
@pytest.mark.parametrize("trace", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_quantized_hardtanh_linear(self, trace, inplace, ie_device, precision, ir_version):
self._test(*self.create_hardtanh_model([10, 9], True, 1, 0.3, inplace), ie_device, precision, ir_version,
kwargs_to_prepare_input={"input_shape": [2, 3, 9]}, trace_model=trace, freeze_model=False)