[PT FE] Add helper for regular ops on quantized values (#18692)

* Add helper for regular ops

* Update op_table.cpp

* Support ops with more than 1 output

* Uncheck ops that return integer/boolean type
Maxim Vafin 2023-07-22 16:47:44 +02:00 committed by GitHub
parent 7fc1fd155d
commit f96f021920
4 changed files with 157 additions and 134 deletions
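
The pattern this PR introduces: `op::quantizable_op<T>` wraps an existing translator `T`. The wrapper runs `T` as usual; if the wrapped op's input was produced by a `QuantizedPtNode`, the output is re-quantized with that node's scale, zero point, axis, and dtype, so regular translators handle quantized values without modification. A minimal sketch of how an entry is registered (the template itself is added in utils_quantize.hpp below; `in_idx`/`out_idx` are its optional template parameters, defaulting to 0):

// op_table.cpp registration: the only change per op is the wrapper.
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
// Equivalent explicit form: take quantization parameters from input 0
// and re-apply them to output 0 of the translated op.
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd, 0, 0>},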

File: op_table.cpp

@@ -6,6 +6,7 @@
 #include "openvino/opsets/opset10.hpp"
 #include "utils.hpp"
+#include "utils_quantize.hpp"
 
 namespace ov {
 namespace frontend {
@@ -190,9 +191,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
         {"aten::acosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
         {"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acosh>>},
-        {"aten::adaptive_avg_pool2d", op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>},
-        {"aten::adaptive_avg_pool3d", op::translate_adaptive_avg_pool3d},
-        {"aten::adaptive_max_pool2d", op::translate_adaptive_max_pool2d},
+        {"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>>},
+        {"aten::adaptive_avg_pool3d", op::quantizable_op<op::translate_adaptive_avg_pool3d>},
+        {"aten::adaptive_max_pool2d", op::quantizable_op<op::translate_adaptive_max_pool2d>},
         {"aten::add", op::translate_add},
         {"aten::add_", op::inplace_op<op::translate_add>},
         {"aten::addcmul", op::translate_addcmul},
@@ -211,9 +212,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
         {"aten::atanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
         {"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
-        {"aten::avg_pool1d", op::translate_avg_poolnd},
-        {"aten::avg_pool2d", op::translate_avg_poolnd},
-        {"aten::avg_pool3d", op::translate_avg_poolnd},
+        {"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
+        {"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
+        {"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
         {"aten::baddbmm", op::translate_addmm},
         {"aten::batch_norm", op::translate_batch_norm},
         {"aten::bitwise_not", op::translate_bitwise_not},
@@ -262,7 +263,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
         {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
         {"aten::fill_", op::inplace_op<op::translate_fill_>},
-        {"aten::flatten", op::translate_flatten},
+        {"aten::flatten", op::quantizable_op<op::translate_flatten>},
         {"aten::flip", op::translate_flip},
         {"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
         {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
@@ -278,11 +279,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::grid_sampler", op::translate_grid_sampler},
         {"aten::group_norm", op::translate_group_norm},
         {"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
-        {"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
-        {"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
-        {"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
-        {"aten::hardtanh", op::translate_hardtanh},
-        {"aten::hardtanh_", op::inplace_op<op::translate_hardtanh>},
+        {"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
+        {"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
+        {"aten::hardswish_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>>},
+        {"aten::hardtanh", op::quantizable_op<op::translate_hardtanh>},
+        {"aten::hardtanh_", op::inplace_op<op::quantizable_op<op::translate_hardtanh>>},
         {"aten::im2col", op::translate_im2col},
         {"aten::index_put_", op::inplace_op<op::translate_index_put_>},
         {"aten::index_select", op::translate_index_select},
@@ -314,10 +315,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
         {"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::max", op::translate_max},
-        {"aten::max_pool1d", op::translate_max_poolnd},
-        {"aten::max_pool2d", op::translate_max_poolnd},
-        {"aten::max_pool3d", op::translate_max_poolnd},
-        {"aten::mean", op::translate_mean},
+        {"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::mean", op::quantizable_op<op::translate_mean>},
         {"aten::meshgrid", op::translate_meshgrid},
         {"aten::min", op::translate_min},
         {"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
@@ -364,7 +365,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
         {"aten::scatter", op::translate_scatter},
         {"aten::scatter_", op::inplace_op<op::translate_scatter>},
-        {"aten::select", op::translate_select},
+        {"aten::select", op::quantizable_op<op::translate_select>},
         {"aten::selu", op::translate_selu},
         {"aten::selu_", op::inplace_op<op::translate_selu>},
         {"aten::sigmoid", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
@@ -377,12 +378,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::sinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>},
         {"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sinh>>},
         {"aten::size", op::translate_size},
-        {"aten::slice", op::translate_slice},
+        {"aten::slice", op::quantizable_op<op::translate_slice>},
         {"aten::softmax", op::translate_softmax},
         {"aten::sort", op::translate_sort},
         {"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
         {"aten::square", op::translate_square},
-        {"aten::squeeze", op::translate_squeeze},
+        {"aten::squeeze", op::quantizable_op<op::translate_squeeze>},
         {"aten::sub", op::translate_sub},
         {"aten::sub_", op::inplace_op<op::translate_sub>},
         {"aten::sum", op::translate_sum},
@@ -395,7 +396,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::tensor", op::translate_as_tensor},
         {"aten::to", op::translate_to},
         {"aten::topk", op::translate_topk},
-        {"aten::transpose", op::translate_transpose},
+        {"aten::transpose", op::quantizable_op<op::translate_transpose>},
         {"aten::tril", op::translate_tril},
         {"aten::tril_", op::inplace_op<op::translate_tril>},
         {"aten::triu", op::translate_triu},
@@ -404,8 +405,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
          op::translate_1to1_match_2_inputs<opset10::ConvertLike>},  // TODO: overflow semantics is different
         {"aten::unflatten", op::translate_unflatten},
         {"aten::unfold", op::translate_unfold},
-        {"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
-        {"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
+        {"aten::unsqueeze", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
+        {"aten::unsqueeze_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>>},
         {"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
         {"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
         {"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
@@ -417,7 +418,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::upsample_trilinear3d", op::translate_upsample_trilinear3d},
         {"aten::var", op::translate_var},
         {"aten::var_mean", op::translate_var_mean},
-        {"aten::view", op::translate_reshape},
+        {"aten::view", op::quantizable_op<op::translate_reshape>},
         {"aten::where", op::translate_where},
         {"aten::zero_", op::inplace_op<op::translate_zeros_like>},
         {"aten::zeros", op::translate_zeros},

File: utils_quantize.cpp

@@ -20,12 +20,12 @@ namespace pytorch {
 using namespace ov::op;
 
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              std::shared_ptr<ov::Node> input,
-                              std::shared_ptr<ov::Node> scale,
-                              std::shared_ptr<ov::Node> zero_point,
-                              std::shared_ptr<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
     if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
@@ -63,7 +63,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
                                                                 zero_point_convert,
                                                                 dtype));
     } else if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) {
-        FRONT_END_OP_CONVERSION_CHECK(axis, "Axis cannot be null for quantize_per_channel.");
+        FRONT_END_OP_CONVERSION_CHECK(axis.get_node(), "Axis cannot be null for quantize_per_channel.");
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
         const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
@@ -119,46 +119,19 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
 }
 
-// ========================================================
-
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    nullptr,
-                    dtype,
-                    quantization_type);
-}
-
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    axis.get_node_shared_ptr(),
-                    dtype,
-                    quantization_type);
-}
-
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
-        return quantize(context,
-                        input.get_node_shared_ptr(),
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
+    return quantize(context, input, scale, zero_point, Output<Node>(), dtype, quantization_type);
+}
+
+Output<Node> quantize(const NodeContext& context, const Output<Node>& input, const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
+        return quantize(context,
+                        input,
                         quantized_pt_node->get_scale(),
                         quantized_pt_node->get_zero_point(),
                         quantized_pt_node->get_axis(),
@@ -168,17 +141,16 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
-        return quantize(context,
-                        input.get_node_shared_ptr(),
-                        scale.get_node_shared_ptr(),
-                        zero_point.get_node_shared_ptr(),
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
+        return quantize(context,
+                        input,
+                        scale,
+                        zero_point,
                         quantized_pt_node->get_axis(),
                         quantized_pt_node->get_dtype(),
                         quantized_pt_node->get_type());
@@ -186,7 +158,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
     auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
     if (!quant_node) {
         return nullptr;
@@ -198,7 +170,7 @@ std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
     return quant_node;
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node, const std::string& type) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type) {
     auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
     if (!quant_node) {
         return nullptr;
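
With the overloads now taking `Output<Node>` throughout, a translator can pass outputs directly instead of converting through `shared_ptr`. A minimal sketch of how the three-argument overload composes, assuming a hypothetical helper `compute_result` that produces the regular (dequantized) math result:

// Sketch only: requantize a freshly computed result like an existing
// quantized value. compute_result is hypothetical; quantize(context, y, x)
// is the overload above that copies scale/zero_point/axis/dtype from x.
OutputVector translate_example(const NodeContext& context) {
    const auto x = context.get_input(0);        // may come from a QuantizedPtNode
    const auto y = compute_result(context, x);  // hypothetical regular computation
    return {quantize(context, y, x)};
}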

File: utils_quantize.hpp

@@ -15,22 +15,19 @@ enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };
 class QuantizedPtNode : public PtFrameworkNode {
 public:
-    OPENVINO_OP("QuantizedPtNode", "util", ::ov::frontend::pytorch::PtFrameworkNode);
+    OPENVINO_OP("QuantizedPtNode", "util", PtFrameworkNode);
     static constexpr const char* quantized_node_type_key = "QuantizedPtTypeName";
     static constexpr const char* quantize_per_tensor = "quantize_per_tensor";
     static constexpr const char* quantize_per_channel = "quantize_per_channel";
 
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<ov::Node> input,
-                    const ov::Output<ov::Node> scale,
-                    const ov::Output<ov::Node> zero_point,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(nullptr) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
@@ -45,16 +42,13 @@ public:
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<ov::Node> input,
-                    const ov::Output<ov::Node> scale,
-                    const ov::Output<ov::Node> zero_point,
-                    const ov::Output<ov::Node> axis,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
+                    const Output<Node> axis,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(axis.get_node_shared_ptr()) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point, axis}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
@@ -67,66 +61,86 @@ public:
         this->dtype = dtype;
     }
 
-    const std::shared_ptr<ov::Node> get_scale() {
-        return scale;
+    const Output<Node> get_scale() const {
+        return input_value(1);
     }
-    const std::shared_ptr<ov::Node> get_zero_point() {
-        return zero_point;
+    const Output<Node> get_zero_point() const {
+        return input_value(2);
     }
-    const std::shared_ptr<ov::Node> get_axis() {
-        return axis;
+    const Output<Node> get_axis() const {
+        if (inputs().size() < 4) {
+            return Output<Node>();
+        }
+        return input_value(3);
     }
-    const QuantizedPtNodeType get_type() {
+    const QuantizedPtNodeType get_type() const {
         return type;
     }
-    const element::Type get_dtype() {
+    const element::Type get_dtype() const {
         return dtype;
     }
 
 private:
     const QuantizedPtNodeType type;
-    std::shared_ptr<ov::Node> scale;
-    std::shared_ptr<ov::Node> zero_point;
-    std::shared_ptr<ov::Node> axis;
     element::Type dtype;
 };
 
 /**
  * Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode.
  */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type);
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
 
 /**
  * Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode.
  */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> quantized_node);
+Output<Node> quantize(const NodeContext& context, Output<Node> input, Output<Node> quantized_node);
 
 /**
  * Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to
  * the new QuantizedPtNode.
  */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> quantized_node);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node);
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node);
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node, const std::string& type);
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node);
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type);
+
+namespace op {
+/**
+ * Modifies conversion function to support quantized case. When input is quantized it is processed as quantized op.
+ */
+template <OutputVector (*T)(const NodeContext&), size_t in_idx = 0, size_t out_idx = 0>
+OutputVector quantizable_op(const NodeContext& context) {
+    auto translation_res = T(context);
+    FRONT_END_OP_CONVERSION_CHECK(translation_res.size() > out_idx, "Not enough outputs to apply quantization.");
+    if (const auto quantized_pt_node = cast_quantized_fw_node(context.get_input(in_idx).get_node_shared_ptr())) {
+        return {quantize(context,
+                         translation_res[out_idx],
+                         quantized_pt_node->get_scale(),
+                         quantized_pt_node->get_zero_point(),
+                         quantized_pt_node->get_axis(),
+                         quantized_pt_node->get_dtype(),
+                         quantized_pt_node->get_type())};
+    }
+    return translation_res;
+}
+}  // namespace op
 
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov
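
Since scale, zero point, and axis are now real node inputs rather than stored `shared_ptr` members, "no axis" (the per-tensor case) is modeled by a default-constructed `Output<Node>`, which is why the per-channel path in utils_quantize.cpp checks `axis.get_node()` before use. A small illustration of that convention:

// A default-constructed Output<Node> wraps no node, so get_node() is null.
ov::Output<ov::Node> axis;  // "absent" axis: the per-tensor case
if (axis.get_node()) {
    // per-channel: axis carries a real node (input 3 of QuantizedPtNode)
} else {
    // per-tensor: QuantizedPtNode has only input/scale/zero_point inputs
}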

File: test_quantized_linear.py

@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from pytorch_layer_test_class import PytorchLayerTest
 
+
 class TestQuantizedLinear(PytorchLayerTest):
     def _prepare_input(self, input_shape=(2, 2)):
         return (np.random.randn(*input_shape).astype(np.float32),)
@@ -16,10 +17,12 @@ class TestQuantizedLinear(PytorchLayerTest):
             def __init__(self, weight_shape, is_bias, scale, zero_point):
                 super(aten_quantized_linear, self).__init__()
                 if is_bias:
-                    self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], True)
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], True)
                     torch.nn.init.normal_(self.linear.bias())
                 else:
-                    self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], False)
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], False)
                 self.linear.scale = float(scale)
                 self.linear.zero_point = int(zero_point)
@@ -31,6 +34,31 @@ class TestQuantizedLinear(PytorchLayerTest):
         return aten_quantized_linear(weight_shape, is_bias, scale, zero_point), ref_net, "quantized::linear"
 
+    def create_hardtanh_model(self, weight_shape, is_bias, scale, zero_point, inplace):
+        class aten_quantized_linear(torch.nn.Module):
+            def __init__(self, weight_shape, is_bias, scale, zero_point, inplace):
+                super(aten_quantized_linear, self).__init__()
+                self.hardtanh = torch.nn.Hardtanh(inplace=inplace)
+                if is_bias:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], True)
+                    torch.nn.init.normal_(self.linear.bias())
+                else:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], False)
+                self.linear.scale = float(scale)
+                self.linear.zero_point = int(zero_point)
+
+            def forward(self, inp):
+                inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
+                inp_q = self.hardtanh(inp_q)
+                return torch.dequantize(self.linear(inp_q))
+
+        ref_net = None
+
+        return aten_quantized_linear(weight_shape, is_bias, scale, zero_point, inplace), ref_net, ["quantized::linear", "aten::hardtanh_" if inplace else "aten::hardtanh"]
+
     @pytest.mark.parametrize("params", [
         {'input_shape': [3, 9], 'weight_shape': [10, 9]},
         {'input_shape': [3, 9], 'weight_shape': [9]},
@@ -51,3 +79,11 @@ class TestQuantizedLinear(PytorchLayerTest):
         bias = params.get("bias", False)
         self._test(*self.create_model(weight_shape, bias, scale, zero_point), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"input_shape": input_shape}, trace_model=trace, freeze_model=False)
+
+    @pytest.mark.parametrize("trace", [True, False])
+    @pytest.mark.parametrize("inplace", [True, False])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_hardtanh_linear(self, trace, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_hardtanh_model([10, 9], True, 1, 0.3, inplace), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": [2, 3, 9]}, trace_model=trace, freeze_model=False)