[PT FE] Fix conversion for aten::quantize_per_channel (#18646)
* Support GetAttr with packed params * Apply suggestions from code review * [PT FE] Add quantized types as normal types to decoder * [PT FE] Add decoder dequantize, add dtypes to quantize * [PT FE] Add dequantize example * [PT FE] Implement replacer for quantized nodes * [PT FE] Register replacer for quantize/dequantize * [PT FE] Remove unwanted junk from previous version * [PT FE] Fix building mistakes for frontend * [PT FE] Clang fix * [PT FE] Ease of use upgrade to quantize funcs * [PT FE] Clang format * [PT FE] Introduce new version of quantize/dequantize * [PT FE] Remove unwanted files from new version * [PT FE] Fix style * [PT FE] Add QuantizedPtNode replacer, fix accuracy error * [PT FE] Add improved version of quantize/dequantize with shared_ptrs * [PT FE] Fix utils shared ptr reference error * [PT FE] Quantize now takes correct input for operations * [PT FE] Upgrade quantize method * [PT FE] Add BFS for dequantize, add quantize_per_channel * [PT FE] Add missing replacer to frontend, improve tests * [PT FE] Rename replacer -> remover, remove unwanted header files * [PT FE] Change function declarations to return ov::Output instead of shared ptr * [PT FE] Add missing context mark node * [PT FE] Remove unknown modifications to ie_c_api * [PT FE] Remove fp16 support, turn off int32 tests * [PT FE] Clang format * [PT FE] Fix quantize_per_tensor * [PT FE] Minor fixes from review * [PT FE] Remove dequantize, remove helpers, replacer now removes nodes instead * [PT FE] Rename Replacer to Remover for dequantize nodes * [PT FE] Clang format * [PT FE] Move comments to header files, minor import fixes * [PT FE] Fix clang format * [PT FE] Fix dtype issue * [PT FE] Fix quantize_per_channel tests * Apply suggestions from code review Removing sporadic tests from precommit * Apply suggestions from code review * [PT FE] Fix conversion errors for aten::quantize_per_channel * [PT FE] Mark axis 2 as xfail --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
60a8c2bc7a
commit
6d8dcb059d
@ -67,9 +67,9 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
|
||||
const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
|
||||
const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
|
||||
const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
|
||||
const auto axis_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::i32));
|
||||
const auto axis_convert = context.mark_node(std::make_shared<v0::Convert>(axis, element::i32));
|
||||
|
||||
const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
|
||||
@ -88,8 +88,9 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
|
||||
const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64}));
|
||||
const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64}));
|
||||
|
||||
const auto rank = std::get<1>(get_shape_rank(context, input_convert));
|
||||
const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
|
||||
const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
|
||||
|
||||
const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
|
||||
const auto new_shape =
|
||||
context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));
|
||||
@ -105,13 +106,14 @@ ov::Output<ov::Node> quantize(const NodeContext& context,
|
||||
const auto bound_high = context.mark_node(std::make_shared<v1::Multiply>(scale_bc, out_high_normalized));
|
||||
|
||||
const auto quantized_input = context.mark_node(
|
||||
std::make_shared<v0::FakeQuantize>(input_convert, out_low, out_high, bound_low, bound_high, levels));
|
||||
std::make_shared<v0::FakeQuantize>(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
|
||||
|
||||
return context.mark_node(std::make_shared<QuantizedPtNode>(quantization_type,
|
||||
context,
|
||||
quantized_input,
|
||||
scale_bc,
|
||||
zero_point_bc,
|
||||
axis_convert,
|
||||
dtype));
|
||||
}
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
|
||||
|
@ -58,17 +58,27 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.array(5.00 * np.random.rand(5, 6, 7, 8) + 5.00, dtype=np.float32),)
|
||||
|
||||
@pytest.mark.parametrize("scales", [
|
||||
np.array([1.0, 0.21, 0.62, 0.5], dtype=np.float32),
|
||||
np.array([0.21, 0.62, 0.5, 1.0], dtype=np.float32),
|
||||
np.array([0.62, 0.5, 1.0, 0.21], dtype=np.float32),
|
||||
np.array([0.5, 1.0, 0.21, 0.62], dtype=np.float32),
|
||||
])
|
||||
@pytest.mark.parametrize("zero_points", [
|
||||
np.array([0, 4, 2, 1], dtype=np.int32),
|
||||
np.array([0, 1, 2, 3], dtype=np.int32),
|
||||
np.array([0, 0, 0, 0], dtype=np.int32),
|
||||
np.array([-1, 0, -4, 5], dtype=np.int32),
|
||||
@pytest.mark.parametrize("scale, zero_point, axis", [
|
||||
[
|
||||
np.array([1.0, 0.21, 0.62, 0.5, 0.74], dtype=np.float32),
|
||||
np.array([0, -1, 2, -3, 4], dtype=np.int32),
|
||||
0
|
||||
],
|
||||
[
|
||||
np.array([1.0, 0.62, 0.74, 0.11, 0.89, 0.32], dtype=np.float32),
|
||||
np.array([0, 2, 4, -5, 6, -7], dtype=np.int32),
|
||||
1
|
||||
],
|
||||
pytest.param(
|
||||
np.array([1.0, 0.21, 0.62, 0.5, 0.11, 0.89, 0.32], dtype=np.float32),
|
||||
np.array([0, -1, 2, -3, 4, -5, -7], dtype=np.int32),
|
||||
2,
|
||||
marks=pytest.mark.skip(reason="Axis = 2 not supported in FakeQuantize.")),
|
||||
[
|
||||
np.array([1.0, 0.21, 0.62, 0.5, 0.74, 0.11, 0.89, 0.32], dtype=np.float32),
|
||||
np.array([0, -1, 2, -3, 4, -5, 6, -7], dtype=np.int32),
|
||||
3
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", [
|
||||
torch.quint8,
|
||||
@ -76,12 +86,10 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
|
||||
pytest.param(torch.qint32, marks=pytest.mark.skip(
|
||||
reason="Not supported with FakeQuantize."))
|
||||
])
|
||||
@pytest.mark.parametrize("axis", [
|
||||
0, 1, 2, 3
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
# @pytest.mark.precommit - conversion issue
|
||||
def test_quantize_per_channel_dequantize(self, scales, zero_points, dtype, axis, ie_device, precision, ir_version):
|
||||
if dtype == torch.quint8: zero_points = abs(zero_points)
|
||||
self._test(aten_quantize_per_channel_aten_dequantize(scales, zero_points, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"],
|
||||
# @pytest.mark.precommit - sporadic issue
|
||||
def test_quantize_per_channel_dequantize(self, scale, zero_point, axis, dtype, ie_device, precision, ir_version):
|
||||
np.random.shuffle(scale), np.random.shuffle(zero_point)
|
||||
if dtype == torch.quint8: zero_point = abs(zero_point)
|
||||
self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, axis, dtype), None, ["aten::quantize_per_channel", "aten::dequantize"],
|
||||
ie_device, precision, ir_version, )
|
||||
|
Loading…
Reference in New Issue
Block a user