diff --git a/src/frontends/pytorch/src/utils_quantize.cpp b/src/frontends/pytorch/src/utils_quantize.cpp
index 025a0b9f1ea..014388d6562 100644
--- a/src/frontends/pytorch/src/utils_quantize.cpp
+++ b/src/frontends/pytorch/src/utils_quantize.cpp
@@ -67,9 +67,9 @@ ov::Output<Node> quantize(const NodeContext& context,
         const auto input_convert = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
         const auto scales_convert = context.mark_node(std::make_shared<v0::Convert>(scale, element::f32));
         const auto zero_points_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::f32));
-        const auto axis_convert = context.mark_node(std::make_shared<v0::Convert>(zero_point, element::i32));
+        const auto axis_convert = context.mark_node(std::make_shared<v0::Convert>(axis, element::i32));
 
-        const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+        const auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
         const auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
         const auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
 
@@ -88,8 +88,9 @@ ov::Output<Node> quantize(const NodeContext& context,
         const auto out_low = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_low_i64}));
         const auto out_high = context.mark_node(v0::Constant::create(element::f32, Shape{}, {out_high_i64}));
 
-        const auto rank = std::get<1>(get_shape_rank(context, input_convert));
+        const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
         const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
+
        const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
         const auto new_shape =
             context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));
@@ -105,13 +106,14 @@ ov::Output<Node> quantize(const NodeContext& context,
         const auto bound_high = context.mark_node(std::make_shared<v1::Multiply>(scale_bc, out_high_normalized));
 
         const auto quantized_input = context.mark_node(
-            std::make_shared<v0::FakeQuantize>(input_convert, out_low, out_high, bound_low, bound_high, levels));
+            std::make_shared<v0::FakeQuantize>(input_convert, bound_low, bound_high, bound_low, bound_high, levels));
 
         return context.mark_node(std::make_shared<QuantizedPtNode>(quantization_type,
                                                                    context,
                                                                    quantized_input,
                                                                    scale_bc,
                                                                    zero_point_bc,
+                                                                   axis_convert,
                                                                    dtype));
     }
     FRONT_END_OP_CONVERSION_CHECK(false, "Got unknown quantization method in quantize.");
diff --git a/tests/layer_tests/pytorch_tests/test_quantize.py b/tests/layer_tests/pytorch_tests/test_quantize.py
index e9ab802280d..ecd06792925 100644
--- a/tests/layer_tests/pytorch_tests/test_quantize.py
+++ b/tests/layer_tests/pytorch_tests/test_quantize.py
@@ -58,17 +58,27 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
     def _prepare_input(self):
         return (np.array(5.00 * np.random.rand(5, 6, 7, 8) + 5.00, dtype=np.float32),)
 
-    @pytest.mark.parametrize("scales", [
-        np.array([1.0, 0.21, 0.62, 0.5], dtype=np.float32),
-        np.array([0.21, 0.62, 0.5, 1.0], dtype=np.float32),
-        np.array([0.62, 0.5, 1.0, 0.21], dtype=np.float32),
-        np.array([0.5, 1.0, 0.21, 0.62], dtype=np.float32),
-    ])
-    @pytest.mark.parametrize("zero_points", [
-        np.array([0, 4, 2, 1], dtype=np.int32),
-        np.array([0, 1, 2, 3], dtype=np.int32),
-        np.array([0, 0, 0, 0], dtype=np.int32),
-        np.array([-1, 0, -4, 5], dtype=np.int32),
+    @pytest.mark.parametrize("scale, zero_point, axis", [
+        [
+            np.array([1.0, 0.21, 0.62, 0.5, 0.74], dtype=np.float32),
+            np.array([0, -1, 2, -3, 4], dtype=np.int32),
+            0
+        ],
+        [
+            np.array([1.0, 0.62, 0.74, 0.11, 0.89, 0.32], dtype=np.float32),
+            np.array([0, 2, 4, -5, 6, -7], dtype=np.int32),
+            1
+        ],
+        pytest.param(
+            np.array([1.0, 0.21, 0.62, 0.5, 0.11, 0.89, 0.32], dtype=np.float32),
+            np.array([0, -1, 2, -3, 4, -5, -7], dtype=np.int32),
+            2,
+            marks=pytest.mark.skip(reason="Axis = 2 not supported in FakeQuantize.")),
+        [
+            np.array([1.0, 0.21, 0.62, 0.5, 0.74, 0.11, 0.89, 0.32], dtype=np.float32),
+            np.array([0, -1, 2, -3, 4, -5, 6, -7], dtype=np.int32),
+            3
+        ],
     ])
     @pytest.mark.parametrize("dtype", [
         torch.quint8,
@@ -76,12 +86,10 @@ class TestQuantizePerChannelDequantize(PytorchLayerTest):
         pytest.param(torch.qint32, marks=pytest.mark.skip(
             reason="Not supported with FakeQuantize."))
     ])
-    @pytest.mark.parametrize("axis", [
-        0, 1, 2, 3
-    ])
     @pytest.mark.nightly
-    # @pytest.mark.precommit - conversion issue
-    def test_quantize_per_channel_dequantize(self, scales, zero_points, dtype, axis, ie_device, precision, ir_version):
-        if dtype == torch.quint8: zero_points = abs(zero_points)
-        self._test(aten_quantize_per_channel_aten_dequantize(scales, zero_points, dtype, axis), None, ["aten::quantize_per_channel", "aten::dequantize"],
+    # @pytest.mark.precommit - sporadic issue
+    def test_quantize_per_channel_dequantize(self, scale, zero_point, axis, dtype, ie_device, precision, ir_version):
+        np.random.shuffle(scale), np.random.shuffle(zero_point)
+        if dtype == torch.quint8: zero_point = abs(zero_point)
+        self._test(aten_quantize_per_channel_aten_dequantize(scale, zero_point, axis, dtype), None, ["aten::quantize_per_channel", "aten::dequantize"],
                    ie_device, precision, ir_version, )
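Note, not part of the patch: the FakeQuantize decomposition above clamps the real-valued input to the per-channel bounds scale * (out_low - zero_point) .. scale * (out_high - zero_point) and snaps it to the quantization grid, which is what aten::quantize_per_channel followed by aten::dequantize computes; using these bounds as both the input and output ranges of FakeQuantize is the key change in this diff. A minimal NumPy sketch of that math, assuming quint8 (out_low = 0, out_high = 255, levels = 256); the helper name per_channel_fq_ref and the sample data are illustrative, not from the repository:

```python
import numpy as np

def per_channel_fq_ref(x, scale, zero_point, axis, out_low=0, out_high=255):
    # Broadcast scale/zero_point along `axis`, mirroring the
    # ones/ScatterElementsUpdate new_shape construction in the C++ change.
    shape = [1] * x.ndim
    shape[axis] = -1
    scale = scale.astype(np.float32).reshape(shape)
    zero_point = zero_point.astype(np.float32).reshape(shape)
    # Per-channel bounds used as both input and output ranges of FakeQuantize.
    bound_low = scale * (out_low - zero_point)
    bound_high = scale * (out_high - zero_point)
    x = np.clip(x, bound_low, bound_high)
    # Snap to the per-channel grid: for integer zero_point this equals
    # dequantize(quantize_per_channel(x)), i.e. (round(x/scale) + zp - zp) * scale.
    return (np.round(x / scale + zero_point) - zero_point) * scale

x = np.float32(5.0 * np.random.rand(5, 6, 7, 8) + 5.0)  # same shape as _prepare_input
scale = np.array([1.0, 0.21, 0.62, 0.5, 0.74], dtype=np.float32)
zero_point = np.array([0, 1, 2, 3, 4], dtype=np.int32)   # quint8: non-negative
print(per_channel_fq_ref(x, scale, zero_point, axis=0)[0, 0, 0, :4])
```

The previous code passed out_low/out_high (integer-domain limits) as the FakeQuantize input range, so the float input was clamped as if it were already quantized; with bound_low/bound_high on both sides the op is a pure per-channel clamp-and-round in real space.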