[PT FE] Add helper for regular ops on quantized values (#18692)
* Add helper for regular ops
* Update op_table.cpp
* Support ops with more than 1 output
* Uncheck ops that return integer/boolean type
Parent: 7fc1fd155d
Commit: f96f021920
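The idea of the change: many regular ops (pooling, flatten/reshape-like views, clamps) can run on quantized tensors, and their result only needs to be re-quantized with the quantization parameters of the input. The diff implements this as a `quantizable_op` wrapper around existing translators (see the `utils_quantize.hpp` hunk below). A minimal self-contained sketch of the pattern, with illustrative types that stand in for the real `NodeContext`/`OutputVector` machinery:

#include <cstddef>
#include <optional>
#include <vector>

struct QuantParams {
    float scale;
    int zero_point;
};

struct Value {
    float data;
    std::optional<QuantParams> quant;  // set when the value is quantized
};
using Values = std::vector<Value>;

// Decorator over a translator function: run the regular conversion first,
// then copy the input's quantization parameters onto the chosen output.
template <Values (*Translator)(const Values&), std::size_t in_idx = 0, std::size_t out_idx = 0>
Values quantizable(const Values& inputs) {
    Values outputs = Translator(inputs);
    if (inputs.at(in_idx).quant) {
        outputs.at(out_idx).quant = inputs.at(in_idx).quant;
    }
    return outputs;
}

// Usage (illustrative): given `Values relu_translate(const Values&);`,
// register `quantizable<relu_translate>` instead of `relu_translate`.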
op_table.cpp
@@ -6,6 +6,7 @@
 
 #include "openvino/opsets/opset10.hpp"
 #include "utils.hpp"
+#include "utils_quantize.hpp"
 
 namespace ov {
 namespace frontend {
@@ -190,9 +191,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
         {"aten::acosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
         {"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acosh>>},
-        {"aten::adaptive_avg_pool2d", op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>},
-        {"aten::adaptive_avg_pool3d", op::translate_adaptive_avg_pool3d},
-        {"aten::adaptive_max_pool2d", op::translate_adaptive_max_pool2d},
+        {"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>>},
+        {"aten::adaptive_avg_pool3d", op::quantizable_op<op::translate_adaptive_avg_pool3d>},
+        {"aten::adaptive_max_pool2d", op::quantizable_op<op::translate_adaptive_max_pool2d>},
         {"aten::add", op::translate_add},
         {"aten::add_", op::inplace_op<op::translate_add>},
         {"aten::addcmul", op::translate_addcmul},
@@ -211,9 +212,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
         {"aten::atanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Atanh>},
         {"aten::atanh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atanh>>},
-        {"aten::avg_pool1d", op::translate_avg_poolnd},
-        {"aten::avg_pool2d", op::translate_avg_poolnd},
-        {"aten::avg_pool3d", op::translate_avg_poolnd},
+        {"aten::avg_pool1d", op::quantizable_op<op::translate_avg_poolnd>},
+        {"aten::avg_pool2d", op::quantizable_op<op::translate_avg_poolnd>},
+        {"aten::avg_pool3d", op::quantizable_op<op::translate_avg_poolnd>},
         {"aten::baddbmm", op::translate_addmm},
         {"aten::batch_norm", op::translate_batch_norm},
         {"aten::bitwise_not", op::translate_bitwise_not},
@@ -262,7 +263,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::fake_quantize_per_channel_affine", op::translate_fake_quantize_per_channel_affine},
         {"aten::fake_quantize_per_tensor_affine", op::translate_fake_quantize_per_tensor_affine},
         {"aten::fill_", op::inplace_op<op::translate_fill_>},
-        {"aten::flatten", op::translate_flatten},
+        {"aten::flatten", op::quantizable_op<op::translate_flatten>},
         {"aten::flip", op::translate_flip},
         {"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
         {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
@@ -278,11 +279,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::grid_sampler", op::translate_grid_sampler},
         {"aten::group_norm", op::translate_group_norm},
         {"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
-        {"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
-        {"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
-        {"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
-        {"aten::hardtanh", op::translate_hardtanh},
-        {"aten::hardtanh_", op::inplace_op<op::translate_hardtanh>},
+        {"aten::hardsigmoid", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSigmoid>>},
+        {"aten::hardswish", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
+        {"aten::hardswish_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>>},
+        {"aten::hardtanh", op::quantizable_op<op::translate_hardtanh>},
+        {"aten::hardtanh_", op::inplace_op<op::quantizable_op<op::translate_hardtanh>>},
         {"aten::im2col", op::translate_im2col},
         {"aten::index_put_", op::inplace_op<op::translate_index_put_>},
         {"aten::index_select", op::translate_index_select},
@@ -314,10 +315,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
         {"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten::max", op::translate_max},
-        {"aten::max_pool1d", op::translate_max_poolnd},
-        {"aten::max_pool2d", op::translate_max_poolnd},
-        {"aten::max_pool3d", op::translate_max_poolnd},
-        {"aten::mean", op::translate_mean},
+        {"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
+        {"aten::mean", op::quantizable_op<op::translate_mean>},
         {"aten::meshgrid", op::translate_meshgrid},
         {"aten::min", op::translate_min},
         {"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
@@ -364,7 +365,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
         {"aten::scatter", op::translate_scatter},
         {"aten::scatter_", op::inplace_op<op::translate_scatter>},
-        {"aten::select", op::translate_select},
+        {"aten::select", op::quantizable_op<op::translate_select>},
         {"aten::selu", op::translate_selu},
         {"aten::selu_", op::inplace_op<op::translate_selu>},
         {"aten::sigmoid", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sigmoid>},
@@ -377,12 +378,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::sinh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sinh>},
         {"aten::sinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sinh>>},
         {"aten::size", op::translate_size},
-        {"aten::slice", op::translate_slice},
+        {"aten::slice", op::quantizable_op<op::translate_slice>},
         {"aten::softmax", op::translate_softmax},
         {"aten::sort", op::translate_sort},
         {"aten::sqrt", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Sqrt>},
         {"aten::square", op::translate_square},
-        {"aten::squeeze", op::translate_squeeze},
+        {"aten::squeeze", op::quantizable_op<op::translate_squeeze>},
         {"aten::sub", op::translate_sub},
         {"aten::sub_", op::inplace_op<op::translate_sub>},
         {"aten::sum", op::translate_sum},
@@ -395,7 +396,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::tensor", op::translate_as_tensor},
         {"aten::to", op::translate_to},
         {"aten::topk", op::translate_topk},
-        {"aten::transpose", op::translate_transpose},
+        {"aten::transpose", op::quantizable_op<op::translate_transpose>},
         {"aten::tril", op::translate_tril},
         {"aten::tril_", op::inplace_op<op::translate_tril>},
         {"aten::triu", op::translate_triu},
@@ -404,8 +405,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
          op::translate_1to1_match_2_inputs<opset10::ConvertLike>},  // TODO: overflow semantics is different
         {"aten::unflatten", op::translate_unflatten},
         {"aten::unfold", op::translate_unfold},
-        {"aten::unsqueeze", op::translate_1to1_match_2_inputs<opset10::Unsqueeze>},
-        {"aten::unsqueeze_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
+        {"aten::unsqueeze", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>},
+        {"aten::unsqueeze_", op::quantizable_op<op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Unsqueeze>>>},
         {"aten::upsample_bicubic2d", op::translate_upsample_bicubic2d},
         {"aten::upsample_bilinear2d", op::translate_upsample_bilinear2d},
         {"aten::_upsample_bicubic2d_aa", op::translate_upsample_bicubic2d_aa},
@@ -417,7 +418,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"aten::upsample_trilinear3d", op::translate_upsample_trilinear3d},
         {"aten::var", op::translate_var},
         {"aten::var_mean", op::translate_var_mean},
-        {"aten::view", op::translate_reshape},
+        {"aten::view", op::quantizable_op<op::translate_reshape>},
         {"aten::where", op::translate_where},
         {"aten::zero_", op::inplace_op<op::translate_zeros_like>},
         {"aten::zeros", op::translate_zeros},
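After these hunks, making another table entry quantization-aware is a one-line change. A hypothetical entry for illustration (aten::relu and its translator are placeholders, not part of this diff):

// Hypothetical op_table.cpp entry: wrapping the existing translator in
// op::quantizable_op re-quantizes the result when the input is quantized.
{"aten::relu", op::quantizable_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
// The in_idx/out_idx template arguments (both default 0) pick which input
// supplies the quantization parameters and which output receives them,
// which is what allows translators with more than one output.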
utils_quantize.cpp
@@ -20,13 +20,13 @@ namespace pytorch {
 
 using namespace ov::op;
 
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              std::shared_ptr<ov::Node> input,
-                              std::shared_ptr<ov::Node> scale,
-                              std::shared_ptr<ov::Node> zero_point,
-                              std::shared_ptr<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
     if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scale_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
@@ -63,7 +63,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
                                   zero_point_convert,
                                   dtype));
     } else if (quantization_type == QuantizedPtNodeType::QUANTIZE_PER_CHANNEL) {
-        FRONT_END_OP_CONVERSION_CHECK(axis, "Axis cannot be null for quantize_per_channel.");
+        FRONT_END_OP_CONVERSION_CHECK(axis.get_node(), "Axis cannot be null for quantize_per_channel.");
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
         const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
@@ -119,46 +119,19 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
 }
 
-// ========================================================
-
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    nullptr,
-                    dtype,
-                    quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type) {
+    return quantize(context, input, scale, zero_point, Output<Node>(), dtype, quantization_type);
 }
 
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type) {
-    return quantize(context,
-                    input.get_node_shared_ptr(),
-                    scale.get_node_shared_ptr(),
-                    zero_point.get_node_shared_ptr(),
-                    axis.get_node_shared_ptr(),
-                    dtype,
-                    quantization_type);
-}
-
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
+Output<Node> quantize(const NodeContext& context, const Output<Node>& input, const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
         return quantize(context,
-                        input.get_node_shared_ptr(),
+                        input,
                         quantized_pt_node->get_scale(),
                         quantized_pt_node->get_zero_point(),
                         quantized_pt_node->get_axis(),
@@ -168,17 +141,16 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> quantized_node) {
-    std::shared_ptr<QuantizedPtNode> quantized_pt_node;
-    if ((quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr()))) {
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node) {
+    if (const auto quantized_pt_node = cast_quantized_fw_node(quantized_node.get_node_shared_ptr())) {
         return quantize(context,
-                        input.get_node_shared_ptr(),
-                        scale.get_node_shared_ptr(),
-                        zero_point.get_node_shared_ptr(),
+                        input,
+                        scale,
+                        zero_point,
                         quantized_pt_node->get_axis(),
                         quantized_pt_node->get_dtype(),
                         quantized_pt_node->get_type());
@@ -186,7 +158,7 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
     FRONT_END_OP_CONVERSION_CHECK(false, "Failed to convert a node to QuantizedPtNode");
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node) {
     auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
     if (!quant_node) {
         return nullptr;
@@ -198,7 +170,7 @@ std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node) {
     return quant_node;
 }
 
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<Node> node, const std::string& type) {
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type) {
     auto quant_node = std::dynamic_pointer_cast<QuantizedPtNode>(node.get_node_shared_ptr());
     if (!quant_node) {
         return nullptr;
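With the const-reference overloads above, a translator that produces a new value from a quantized input can re-quantize it in one call. A hypothetical call site (the variable names are illustrative; only the overloads shown in this diff are used):

// 'context' is the NodeContext of the op being translated. Input 0 is assumed
// to come from a QuantizedPtNode, and 'new_output' is the freshly translated
// float result of the regular conversion.
Output<Node> quantized_input = context.get_input(0);
Output<Node> requantized = quantize(context, new_output, quantized_input);
// 'requantized' wraps new_output in a QuantizedPtNode that reuses the scale,
// zero point, axis, and dtype recorded on 'quantized_input'.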
utils_quantize.hpp
@@ -15,22 +15,19 @@ enum QuantizedPtNodeType { QUANTIZE_PER_TENSOR, QUANTIZE_PER_CHANNEL };
 
 class QuantizedPtNode : public PtFrameworkNode {
 public:
-    OPENVINO_OP("QuantizedPtNode", "util", ::ov::frontend::pytorch::PtFrameworkNode);
+    OPENVINO_OP("QuantizedPtNode", "util", PtFrameworkNode);
     static constexpr const char* quantized_node_type_key = "QuantizedPtTypeName";
    static constexpr const char* quantize_per_tensor = "quantize_per_tensor";
     static constexpr const char* quantize_per_channel = "quantize_per_channel";
 
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<ov::Node> input,
-                    const ov::Output<ov::Node> scale,
-                    const ov::Output<ov::Node> zero_point,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(nullptr) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
@@ -45,16 +42,13 @@ public:
 
     QuantizedPtNode(const QuantizedPtNodeType type,
                     const NodeContext& context,
-                    const ov::Output<ov::Node> input,
-                    const ov::Output<ov::Node> scale,
-                    const ov::Output<ov::Node> zero_point,
-                    const ov::Output<ov::Node> axis,
+                    const Output<Node> input,
+                    const Output<Node> scale,
+                    const Output<Node> zero_point,
+                    const Output<Node> axis,
                     element::Type& dtype)
-        : PtFrameworkNode(context.get_decoder(), {input}, 1, false),
-          type(type),
-          scale(scale.get_node_shared_ptr()),
-          zero_point(zero_point.get_node_shared_ptr()),
-          axis(axis.get_node_shared_ptr()) {
+        : PtFrameworkNode(context.get_decoder(), {input, scale, zero_point, axis}, 1, false),
+          type(type) {
         ov::op::util::FrameworkNodeAttrs attrs = get_attrs();
         if (type == QuantizedPtNodeType::QUANTIZE_PER_TENSOR) {
             attrs[quantized_node_type_key] = quantize_per_tensor;
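The constructor change above is the key design move: scale, zero_point, and (for per-channel) axis stop being detached shared_ptr members and become inputs of the QuantizedPtNode itself, i.e. real graph edges. A comment-style sketch of the resulting input layout (the indices match the getters in the next hunk):

// QuantizedPtNode input layout after this change (sketch, not library code):
//   input_value(0) -> the quantized data value
//   input_value(1) -> scale
//   input_value(2) -> zero_point
//   input_value(3) -> axis (only when the per-channel constructor was used)
// Because the parameters are graph edges rather than side pointers, they are
// traversed, cloned, and validated together with the rest of the graph.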
@@ -67,66 +61,86 @@ public:
         this->dtype = dtype;
     }
 
-    const std::shared_ptr<ov::Node> get_scale() {
-        return scale;
+    const Output<Node> get_scale() const {
+        return input_value(1);
     }
-    const std::shared_ptr<ov::Node> get_zero_point() {
-        return zero_point;
+    const Output<Node> get_zero_point() const {
+        return input_value(2);
     }
-    const std::shared_ptr<ov::Node> get_axis() {
-        return axis;
+    const Output<Node> get_axis() const {
+        if (inputs().size() < 4) {
+            return Output<Node>();
+        }
+        return input_value(3);
     }
-    const QuantizedPtNodeType get_type() {
+    const QuantizedPtNodeType get_type() const {
         return type;
     }
-    const element::Type get_dtype() {
+    const element::Type get_dtype() const {
         return dtype;
     }
 
 private:
     const QuantizedPtNodeType type;
-    std::shared_ptr<ov::Node> scale;
-    std::shared_ptr<ov::Node> zero_point;
-    std::shared_ptr<ov::Node> axis;
     element::Type dtype;
 };
 
 /**
 * Quantizes input node with the given parameters. Returns a shared pointer to the new QuantizedPtNode.
 */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type);
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> axis,
-                              ov::element::Type dtype,
-                              QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& axis,
+                      element::Type dtype,
+                      QuantizedPtNodeType quantization_type);
 
 /**
 * Quantizes input node like the quantized node. Returns a shared pointer to the new QuantizedPtNode.
 */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> quantized_node);
+Output<Node> quantize(const NodeContext& context, Output<Node> input, Output<Node> quantized_node);
 
 /**
 * Quantizes input node like the quantized node, with new scale and zero_point parameters. Returns a shared pointer to
 * the new QuantizedPtNode.
 */
-ov::Output<ov::Node> quantize(const NodeContext& context,
-                              ov::Output<ov::Node> input,
-                              ov::Output<ov::Node> scale,
-                              ov::Output<ov::Node> zero_point,
-                              ov::Output<ov::Node> quantized_node);
+Output<Node> quantize(const NodeContext& context,
+                      const Output<Node>& input,
+                      const Output<Node>& scale,
+                      const Output<Node>& zero_point,
+                      const Output<Node>& quantized_node);
 
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node);
+std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(Output<Node> node, const std::string& type);
+
+namespace op {
+/**
+ * Modifies conversion function to support quantized case. When input is quantized it is processed as quantized op.
+ */
+template <OutputVector (*T)(const NodeContext&), size_t in_idx = 0, size_t out_idx = 0>
+OutputVector quantizable_op(const NodeContext& context) {
+    auto translation_res = T(context);
+    FRONT_END_OP_CONVERSION_CHECK(translation_res.size() > out_idx, "Not enough outputs to apply quantization.");
+    if (const auto quantized_pt_node = cast_quantized_fw_node(context.get_input(in_idx).get_node_shared_ptr())) {
+        return {quantize(context,
+                         translation_res[out_idx],
+                         quantized_pt_node->get_scale(),
+                         quantized_pt_node->get_zero_point(),
+                         quantized_pt_node->get_axis(),
+                         quantized_pt_node->get_dtype(),
+                         quantized_pt_node->get_type())};
+    }
+    return translation_res;
+}
+}  // namespace op
+
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node);
-std::shared_ptr<QuantizedPtNode> cast_quantized_fw_node(ov::Output<ov::Node> node, const std::string& type);
 }  // namespace pytorch
 }  // namespace frontend
 }  // namespace ov
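Because get_axis() now returns an empty Output<Node> when no axis input was attached, callers distinguish per-tensor from per-channel quantization by checking the node pointer, exactly as the per-channel branch in utils_quantize.cpp does. A minimal sketch, assuming quantized_pt_node was obtained from cast_quantized_fw_node:

// Dispatch on whether the quantized node carried an axis input.
const auto axis = quantized_pt_node->get_axis();
if (axis.get_node()) {
    // per-channel quantization: axis is a live graph edge (input index 3)
} else {
    // per-tensor quantization: get_axis() returned a default-constructed Output
}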
quantized linear layer test (Python)
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from pytorch_layer_test_class import PytorchLayerTest
 
+
 class TestQuantizedLinear(PytorchLayerTest):
     def _prepare_input(self, input_shape=(2, 2)):
         return (np.random.randn(*input_shape).astype(np.float32),)
@@ -16,10 +17,12 @@ class TestQuantizedLinear(PytorchLayerTest):
             def __init__(self, weight_shape, is_bias, scale, zero_point):
                 super(aten_quantized_linear, self).__init__()
                 if is_bias:
-                    self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], True)
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], True)
                     torch.nn.init.normal_(self.linear.bias())
                 else:
-                    self.linear = torch.ao.nn.quantized.Linear(weight_shape[-1], weight_shape[0], False)
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], False)
                 self.linear.scale = float(scale)
                 self.linear.zero_point = int(zero_point)
@@ -31,6 +34,31 @@ class TestQuantizedLinear(PytorchLayerTest):
 
         return aten_quantized_linear(weight_shape, is_bias, scale, zero_point), ref_net, "quantized::linear"
 
+    def create_hardtanh_model(self, weight_shape, is_bias, scale, zero_point, inplace):
+
+        class aten_quantized_linear(torch.nn.Module):
+            def __init__(self, weight_shape, is_bias, scale, zero_point, inplace):
+                super(aten_quantized_linear, self).__init__()
+                self.hardtanh = torch.nn.Hardtanh(inplace=inplace)
+                if is_bias:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], True)
+                    torch.nn.init.normal_(self.linear.bias())
+                else:
+                    self.linear = torch.ao.nn.quantized.Linear(
+                        weight_shape[-1], weight_shape[0], False)
+                self.linear.scale = float(scale)
+                self.linear.zero_point = int(zero_point)
+
+            def forward(self, inp):
+                inp_q = torch.quantize_per_tensor(inp, 1., 0, torch.quint8)
+                inp_q = self.hardtanh(inp_q)
+                return torch.dequantize(self.linear(inp_q))
+
+        ref_net = None
+
+        return aten_quantized_linear(weight_shape, is_bias, scale, zero_point, inplace), ref_net, ["quantized::linear", "aten::hardtanh_" if inplace else "aten::hardtanh"]
+
     @pytest.mark.parametrize("params", [
         {'input_shape': [3, 9], 'weight_shape': [10, 9]},
         {'input_shape': [3, 9], 'weight_shape': [9]},
@@ -51,3 +79,11 @@ class TestQuantizedLinear(PytorchLayerTest):
         bias = params.get("bias", False)
         self._test(*self.create_model(weight_shape, bias, scale, zero_point), ie_device, precision, ir_version,
                    kwargs_to_prepare_input={"input_shape": input_shape}, trace_model=trace, freeze_model=False)
+
+    @pytest.mark.parametrize("trace", [True, False])
+    @pytest.mark.parametrize("inplace", [True, False])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_quantized_hardtanh_linear(self, trace, inplace, ie_device, precision, ir_version):
+        self._test(*self.create_hardtanh_model([10, 9], True, 1, 0.3, inplace), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": [2, 3, 9]}, trace_model=trace, freeze_model=False)