diff --git a/src/bindings/python/tests/__init__.py b/src/bindings/python/tests/__init__.py index 2dead80f444..0ebb43f241c 100644 --- a/src/bindings/python/tests/__init__.py +++ b/src/bindings/python/tests/__init__.py @@ -114,7 +114,6 @@ xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadca xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations") xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140") -xfail_issue_54630 = xfail_test(reason="Gather with negative indices is not yet implemented on CPU") xfail_issue_63033 = xfail_test(reason="BatchNormalization: Training mode is not supported") xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") @@ -128,3 +127,5 @@ xfail_issue_63137 = xfail_test(reason="Unsupported operations: OptionalHasElemen xfail_issue_63138 = xfail_test(reason="Missing ONNX Shape-15 support") xfail_issue_63643 = xfail_test(reason="RuntimeError: Unsupported operation of type: Convolution name") xfail_issue_68212 = xfail_test(reason="Unsupported reading model with bytes streams") + +xfail_issue_77668 = xfail_test(reason="Accuracy issue related to Gather-8.") diff --git a/src/bindings/python/tests/test_ngraph/test_gather.py b/src/bindings/python/tests/test_ngraph/test_gather.py index b33eacc9e57..8d9e66bae88 100644 --- a/src/bindings/python/tests/test_ngraph/test_gather.py +++ b/src/bindings/python/tests/test_ngraph/test_gather.py @@ -4,7 +4,6 @@ import openvino.runtime.opset8 as ov import numpy as np -from tests import xfail_issue_54630 from tests.test_ngraph.util import run_op_node @@ -55,7 +54,6 @@ def test_gather_batch_dims_1(): assert np.allclose(result, expected) -@xfail_issue_54630 def test_gather_negative_indices(): input_data = np.array( [1.0, 1.1, 1.2, 2.0, 2.1, 2.2, 3.0, 3.1, 3.2], np.float32 @@ -71,7 +69,6 @@ def test_gather_negative_indices(): assert np.allclose(result, expected) -@xfail_issue_54630 def test_gather_batch_dims_1_negative_indices(): input_data = np.array([[1, 2, 3, 4, 5], diff --git a/src/bindings/python/tests/test_onnx/test_backend.py b/src/bindings/python/tests/test_onnx/test_backend.py index 94c2b9506bd..d1babd0c996 100644 --- a/src/bindings/python/tests/test_onnx/test_backend.py +++ b/src/bindings/python/tests/test_onnx/test_backend.py @@ -114,11 +114,9 @@ tests_expected_to_fail = [ ( xfail_issue_39662, "OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu", - "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu", ), ( xfail_issue_38091, - "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu", "OnnxBackendNodeModelTest.test_dynamicquantizelinear_cpu", "OnnxBackendNodeModelTest.test_dynamicquantizelinear_expanded_cpu", ), diff --git a/src/bindings/python/tests/test_onnx/test_zoo_models.py b/src/bindings/python/tests/test_onnx/test_zoo_models.py index 380f732669d..3eb8eae6df9 100644 --- a/src/bindings/python/tests/test_onnx/test_zoo_models.py +++ b/src/bindings/python/tests/test_onnx/test_zoo_models.py @@ -22,7 +22,8 @@ from tests import ( xfail_issue_48190, xfail_issue_58676, xfail_issue_63643, - xfail_issue_onnx_models_140) + xfail_issue_onnx_models_140, + xfail_issue_77668) MODELS_ROOT_DIR = tests.MODEL_ZOO_DIR @@ -179,6 +180,8 @@ if len(zoo_models) > 0: (xfail_issue_48190, 
"test_onnx_model_zoo_text_machine_comprehension_roberta_model_roberta_base_11_roberta_base_11_roberta_base_11_cpu"), (xfail_issue_onnx_models_140, "test_onnx_model_zoo_vision_object_detection_segmentation_duc_model_ResNet101_DUC_7_ResNet101_DUC_HDC_ResNet101_DUC_HDC_cpu"), (xfail_issue_63643, "test_onnx_model_zoo_vision_object_detection_segmentation_ssd_mobilenetv1_model_ssd_mobilenet_v1_10_ssd_mobilenet_v1_ssd_mobilenet_v1_cpu"), + (xfail_issue_77668, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"), + (xfail_issue_77668, "test_onnx_model_zoo_vision_object_detection_segmentation_mask_rcnn_model_MaskRCNN_10_mask_rcnn_R_50_FPN_1x_cpu"), # Model MSFT (xfail_issue_37973, "test_MSFT_opset7_tf_inception_v2_model_cpu"), @@ -193,6 +196,9 @@ if len(zoo_models) > 0: (xfail_issue_39669, "test_MSFT_opset9_cgan_cgan_cpu"), (xfail_issue_47495, "test_MSFT_opset10_BERT_Squad_bertsquad10_cpu"), (xfail_issue_63643, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"), + + (xfail_issue_77668, "test_MSFT_opset10_faster_rcnn_faster_rcnn_R_50_FPN_1x_cpu"), + (xfail_issue_77668, "test_MSFT_opset10_mask_rcnn_mask_rcnn_R_50_FPN_1x_cpu"), ] for test_case in import_xfail_list + execution_xfail_list: xfail, test_name = test_case diff --git a/src/bindings/python/tests_compatibility/__init__.py b/src/bindings/python/tests_compatibility/__init__.py index 072362a86dd..4b6162b67e9 100644 --- a/src/bindings/python/tests_compatibility/__init__.py +++ b/src/bindings/python/tests_compatibility/__init__.py @@ -124,7 +124,6 @@ xfail_issue_52463 = xfail_test(reason="test_operator_add_size1_singleton_broadca xfail_issue_58033 = xfail_test(reason="Einsum operation misses support for complex ellipsis equations") xfail_issue_58676 = xfail_test(reason="AssertionError: Not equal to tolerance rtol=0.001, atol=1e-07") xfail_issue_onnx_models_140 = xfail_test(reason="https://github.com/onnx/models/issues/140") -xfail_issue_54630 = xfail_test(reason="Gather with negative indices is not yet implemented on CPU") xfail_issue_63033 = xfail_test(reason="BatchNormalization: Training mode is not supported") xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") @@ -137,3 +136,5 @@ xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike") xfail_issue_63137 = xfail_test(reason="Unsupported operations: OptionalHasElement, OptionalGetElement") xfail_issue_63138 = xfail_test(reason="Missing ONNX Shape-15 support") xfail_issue_63643 = xfail_test(reason="RuntimeError: Unsupported operation of type: Convolution name") + +xfail_issue_77668 = xfail_test(reason="Accuracy issue related to Gather-8.") diff --git a/src/bindings/python/tests_compatibility/test_ngraph/test_gather.py b/src/bindings/python/tests_compatibility/test_ngraph/test_gather.py index 2b3e3687810..4cfbacd0d28 100644 --- a/src/bindings/python/tests_compatibility/test_ngraph/test_gather.py +++ b/src/bindings/python/tests_compatibility/test_ngraph/test_gather.py @@ -4,7 +4,6 @@ import ngraph as ng import numpy as np -from tests_compatibility import xfail_issue_54630 from tests_compatibility.test_ngraph.util import run_op_node @@ -55,7 +54,6 @@ def test_gather_batch_dims_1(): assert np.allclose(result, expected) -@xfail_issue_54630 def test_gather_negative_indices(): input_data = np.array( [1.0, 1.1, 1.2, 2.0, 2.1, 2.2, 3.0, 3.1, 3.2], np.float32 @@ -71,7 +69,6 @@ def test_gather_negative_indices(): assert np.allclose(result, expected) -@xfail_issue_54630 def 
test_gather_batch_dims_1_negative_indices(): input_data = np.array([[1, 2, 3, 4, 5], diff --git a/src/bindings/python/tests_compatibility/test_onnx/test_zoo_models.py b/src/bindings/python/tests_compatibility/test_onnx/test_zoo_models.py index b6da41403ec..5375bc4e3f3 100644 --- a/src/bindings/python/tests_compatibility/test_onnx/test_zoo_models.py +++ b/src/bindings/python/tests_compatibility/test_onnx/test_zoo_models.py @@ -23,7 +23,8 @@ from tests_compatibility import ( xfail_issue_48190, xfail_issue_58676, xfail_issue_63643, - xfail_issue_onnx_models_140) + xfail_issue_onnx_models_140, + xfail_issue_77668) MODELS_ROOT_DIR = tests_compatibility.MODEL_ZOO_DIR @@ -167,6 +168,7 @@ if len(zoo_models) > 0: (xfail_issue_48190, "test_onnx_model_zoo_text_machine_comprehension_roberta_model_roberta_base_11_roberta_base_11_roberta_base_11_cpu"), (xfail_issue_onnx_models_140, "test_onnx_model_zoo_vision_object_detection_segmentation_duc_model_ResNet101_DUC_7_ResNet101_DUC_HDC_ResNet101_DUC_HDC_cpu"), (xfail_issue_63643, "test_onnx_model_zoo_vision_object_detection_segmentation_ssd_mobilenetv1_model_ssd_mobilenet_v1_10_ssd_mobilenet_v1_ssd_mobilenet_v1_cpu"), + (xfail_issue_77668, "test_onnx_model_zoo_vision_object_detection_segmentation_faster_rcnn_model_FasterRCNN_10_faster_rcnn_R_50_FPN_1x_cpu"), # Model MSFT (xfail_issue_37973, "test_MSFT_opset7_tf_inception_v2_model_cpu"), @@ -183,6 +185,8 @@ if len(zoo_models) > 0: (xfail_issue_39669, "test_MSFT_opset9_cgan_cgan_cpu"), (xfail_issue_47495, "test_MSFT_opset10_BERT_Squad_bertsquad10_cpu"), (xfail_issue_63643, "test_MSFT_opset10_mlperf_ssd_mobilenet_300_ssd_mobilenet_v1_coco_2018_01_28_cpu"), + + (xfail_issue_77668, "test_MSFT_opset10_faster_rcnn_faster_rcnn_R_50_FPN_1x_cpu"), ] for test_case in import_xfail_list + execution_xfail_list: xfail, test_name = test_case diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp index c3c11aacc8a..b45c9f185ab 100644 --- a/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/convert_nms_gather_path_to_unsigned.cpp @@ -77,6 +77,7 @@ class UpdateConvertGather: public pass::MatcherPass { auto indices = gather->input_value(1); if (!ov::has_nms_selected_indices(indices.get_node())) return false; + gather->get_rt_info()["dontReverseIndices"] = true; auto out_type = (indices.get_element_type() == element::i64 ? 
element::u64 : element::u32); auto existing_convert = dynamic_pointer_cast(indices.get_node_shared_ptr()); if (existing_convert && indices.get_target_inputs().size() == 1) { diff --git a/src/plugins/intel_cpu/src/mkldnn_plugin.cpp b/src/plugins/intel_cpu/src/mkldnn_plugin.cpp index 81e80212fe0..9645c02e6ce 100644 --- a/src/plugins/intel_cpu/src/mkldnn_plugin.cpp +++ b/src/plugins/intel_cpu/src/mkldnn_plugin.cpp @@ -377,6 +377,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr pass_config->disable(); pass_config->disable(); pass_config->disable(); + pass_config->disable(); pass_config->disable(); pass_config->disable(); pass_config->disable(); @@ -388,7 +389,6 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr pass_config->enable(); pass_config->enable(); pass_config->enable(); - pass_config->enable(); pass_config->enable(); if (useLpt) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.cpp new file mode 100644 index 00000000000..9847db55a03 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.cpp @@ -0,0 +1,1028 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_uni_kernel.hpp" +#include + +using namespace dnnl::impl::cpu; + +namespace MKLDNNPlugin { + +const unsigned jitGatherKernelBase::shufMask8bitUni[16] = {0x0C080400, 0x80808080, 0x80808080, 0x80808080, 0x0C080400, 0x80808080, 0x80808080, 0x80808080, + 0x0C080400, 0x80808080, 0x80808080, 0x80808080, 0x0C080400, 0x80808080, 0x80808080, 0x80808080}; +const unsigned jitGatherKernelBase::permMask8bitA2[8] = {0, 4, 1, 5, 2, 6, 3, 7}; +const unsigned jitGatherKernelBase::permMask8bitA5[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + +const unsigned jitGatherKernelBase::shufMask16bitUni[16] = {0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, + 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080}; +const unsigned jitGatherKernelBase::permMask16bitA2[8] = {0, 1, 4, 5, 2, 3, 6, 7}; +const unsigned jitGatherKernelBase::permMask16bitA5[16] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15}; + +const unsigned jitGatherKernelBase::incVec[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + +#define GET_OFF(field) offsetof(gatherJitExecArgs, field) + +template +jitUniGatherKernel::jitUniGatherKernel(const jGatherConfParams& jcp) : + jitGatherKernelBase(jcp), x64::jit_generator() { + vlen = x64::cpu_isa_traits::vlen; + dataElPerVec = vlen / jcp.dataTypeSize; + idxElPerVec = vlen / indicesTypeSize; + if (jcp.dataTypeSize == 2) + dataTypeShift = 1; + else if (jcp.dataTypeSize == 4) + dataTypeShift = 2; + + if (isa == x64::avx2) { + permMask8bitUni = permMask8bitA2; + permMask16bitUni = permMask16bitA2; + } else if (isa == x64::avx512_common) { + permMask8bitUni = permMask8bitA5; + permMask16bitUni = permMask16bitA5; + } +} + +template +void jitUniGatherKernel::create_ker() { + auto code = x64::jit_generator::create_kernel(); + if (code != dnnl::impl::status::success) + IE_THROW() << "Could not create Gather kernel. 
Error code: " << std::to_string(code); + ker_ = (decltype(ker_))jit_ker(); +} + +template +void jitUniGatherKernel::generate() { + this->preamble(); + + mov(regSrc, ptr[regParams + GET_OFF(src)]); + mov(regDst, ptr[regParams + GET_OFF(dst)]); + mov(regIndices, ptr[regParams + GET_OFF(indices)]); + + mov(regWorkAmount, ptr[regParams + GET_OFF(workAmount)]); + + auto& vAux0 = vmmAuxContainer[0]; + auto& vAux1 = vmmAuxContainer[1]; + auto& xAux0 = xmmAuxContainer[0]; + auto& xAux1 = xmmAuxContainer[1]; + + uni_vpxor(vmmZeros, vmmZeros, vmmZeros); + mov(regAux1, ptr[regParams + GET_OFF(axisDim)]); + uni_vpbroadcastd(vmmAxisDim, ptr[regAux1]); + + if (!jcp.dynamicShapes) { + mov(regAux1, ptr[regParams + GET_OFF(specIndicesSize)]); + uni_vpbroadcastd(vmmSpecIdxSizeB, ptr[regAux1]); + uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. + if (jcp.afterAxisSize == 1lu) { // Elementwise case. + uni_vmovd(reg32SpecIdxSizeB, xmmSpecIdxSizeB); + mov(regAux1, ptr[regParams + GET_OFF(axisAndAfterAxisSizeB)]); + uni_vpbroadcastd(vmmAxisAndAfterAxisSizeB, ptr[regAux1]); + + mov(regAux1, ptr[regParams + GET_OFF(specIdxB)]); + uni_vmovups(vmmSpecIdxB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(idxBatchSumB)]); + uni_vmovups(vmmIdxBatchSumB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(dataBeforeAxisSumB)]); + uni_vmovups(vmmSrcBeforeAxisSumB, ptr[regAux1]); + + mov(regAux1, ptr[regParams + GET_OFF(betweenBatchAndAxisSize)]); + mov(regBetweenBatchAndAxisSize, ptr[regAux1]); + mov(regBetweenBatchAndAxisIter, ptr[regParams + GET_OFF(betweenBatchAndAxisIter)]); + + if (jcp.specIdxSize < idxElPerVec) { // Short case. + mov(regAux1, ptr[regParams + GET_OFF(permIdxMask)]); + uni_vmovups(vmmPermIdxMask, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); + uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); + if (jcp.dataTypeSize != 1) + uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size + if (jcp.batchDims > 0lu) { + mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); + uni_vpbroadcastd(vmmSrcAfterBatchSizeB, ptr[regAux1]); + } + + process(true, false); + } else { // Long case. + uni_vmovd(reg32IdxIter, xmmSpecIdxB); + fillVlenVector(); + + process(false, false); + } + } else { // Blocked case. + if (jcp.afterAxisSize <= idxElPerVec) { // Short case. 
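+                // Short blocked case: the after-axis block fits into a single index vector
+                // (afterAxisSize <= idxElPerVec), so the next iteration's offsets are obtained by
+                // permuting the current vectors (see calcSrcShiftShortBlock) rather than stepping
+                // through the indices tensor as in the long case.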
+ mov(regAux1, ptr[regParams + GET_OFF(afterAxIdxB)]); + uni_vmovups(vmmAfterAxisIdxB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(afterAxisPermMask)]); + uni_vmovups(vmmAfterAxisPermMask, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(specIdxDiff)]); + uni_vmovups(vmmSpecIdxDiff, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(specIdxB)]); + uni_vmovups(vmmSpecIdxB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); + uni_vpbroadcastd(vmmSrcAfterBatchSizeB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(afterAxisSize)]); + uni_vpbroadcastd(vmmAfterAxisSize, ptr[regAux1]); + + if (jcp.beforeAxisSize != 1lu) { + mov(regAux1, ptr[regParams + GET_OFF(dataBeforeAxisSumB)]); + uni_vmovups(vmmSrcBeforeAxisSumB, ptr[regAux1]); + mov(rSpecIdxAndAfterAxIterB, ptr[regParams + GET_OFF(specIdxAndAfterAxIterB)]); + mov(rSpecIdxAndAfterAxSizeB, ptr[regParams + GET_OFF(specIdxAndAfterAxSizeB)]); + if (jcp.specIdxSize * jcp.afterAxisSize < idxElPerVec) { + mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); + uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); + } else { + mov(regAux1, ptr[regParams + GET_OFF(axisAndAfterAxisSizeB)]); + uni_vpbroadcastd(vmmAxisAndAfterAxisSizeB, ptr[regAux1]); + } + mov(regAux1, ptr[regParams + GET_OFF(beforeAxisPermMask)]); + uni_vmovups(vmmBeforeAxPermMask, ptr[regAux1]); + } + + process(true, true); + } else { // Long case. + IE_THROW() << "Gather kernel does not support static shape with after axis size greater than elements in vector."; + } + } + } else { // Dynamic shapes. + mov(regAux1, ptr[regParams + GET_OFF(start)]); + uni_vpbroadcastd(vmmSpecIdxB, ptr[regAux1]); + mov(regAux1, reinterpret_cast(incVec)); + uni_vpaddd(vmmSpecIdxB, vmmSpecIdxB, ptr[regAux1]); + vcvtdq2ps(vmmSpecIdxB, vmmSpecIdxB); + + // Formula: specIndices = (start % specIndicesSize) * idxTypeSize + mov(regAux1, ptr[regParams + GET_OFF(specIndicesSize)]); + uni_vpbroadcastd(vmmSpecIdxSizeB, ptr[regAux1]); + uni_vcvtdq2ps(vAux1, vmmSpecIdxSizeB); + uni_vdivps(vmmSrcBeforeAxisSumB, vmmSpecIdxB, vAux1); + uni_vroundps(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0x1); + uni_vfnmadd231ps(vmmSpecIdxB, vmmSrcBeforeAxisSumB, vAux1); + uni_vcvtps2dq(vmmSpecIdxB, vmmSpecIdxB); + uni_vpslld(vmmSpecIdxB, vmmSpecIdxB, idxTypeShift); // multiply by indices type size. + uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. 
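+        // The cvtdq2ps/div/round/fnmadd sequence above emulates a per-lane integer modulo
+        // (there is no packed integer division instruction), roughly:
+        //   specIdxB[i] = ((start + i) % specIndicesSize) * idxTypeSize;  // scalar sketch
+        // where rounding mode 0x1 takes the floor of the fp32 quotient.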
+ uni_vmovd(reg32SpecIdxSizeB, xmmSpecIdxSizeB); + + mov(regAux1, ptr[regParams + GET_OFF(betweenBatchAndAxisSize)]); + uni_vpbroadcastd(vAux1, ptr[regAux1]); + uni_vmovd(reg32BetweenBatchAndAxisSize, xAux1); + uni_vcvtdq2ps(vAux1, vAux1); + uni_vdivps(vmmIdxBatchSumB, vmmSrcBeforeAxisSumB, vAux1); + uni_vroundps(vmmIdxBatchSumB, vmmIdxBatchSumB, 0x1); + uni_vfnmadd231ps(vmmSrcBeforeAxisSumB, vmmIdxBatchSumB, vAux1); + uni_vcvtps2dq(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB); + uni_vmovd(reg32BetweenBatchAndAxisIter, xmmSrcBeforeAxisSum); + uni_vcvtps2dq(vmmIdxBatchSumB, vmmIdxBatchSumB); + + mov(regAux1, ptr[regParams + GET_OFF(axisAndAfterAxisSizeB)]); + uni_vpbroadcastd(vmmAxisAndAfterAxisSizeB, ptr[regAux1]); + uni_vpmulld(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); + uni_vpbroadcastd(vAux0, ptr[regAux1]); + uni_vpmulld(vAux0, vAux0, vmmIdxBatchSumB); + // Formula: srcBeforeAxisSum = ((start / specIndicesSize) % betweenBatchAndAxis) * axisAndAfterAxisSize + srcAfterBatchSize * idxBatchSum + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vAux0); + + // Formula: idxBatchSum = specIdxSize * (start / afterBatchSize) + uni_vpmulld(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); + + Xbyak::Label lBlock, lEnd; + mov(regAux2, ptr[regParams + GET_OFF(afterAxSize)]); + cmp(regAux2, 1); + jg(lBlock, T_NEAR); + { + Xbyak::Label lLessThanVector1, lTail1, lTail2, lE1; + + cmp(regSpecIdxSizeB, vlen); + jl(lLessThanVector1, T_NEAR); + uni_vmovd(reg32IdxIter, xmmSpecIdxB); + fillVlenVector(); + + process(false, false); + jmp(lE1, T_NEAR); + L(lLessThanVector1); + mov(regAux1, ptr[regParams + GET_OFF(permIdxMask)]); + uni_vmovups(vmmPermIdxMask, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); + uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); + if (jcp.dataTypeSize != 1) + uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size + mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); + uni_vpbroadcastd(vmmSrcAfterBatchSizeB, ptr[regAux1]); + + process(true, false); + L(lE1); + jmp(lEnd, T_NEAR); + } + L(lBlock); { + mov(regAux1, ptr[regParams + GET_OFF(start)]); + uni_vpbroadcastd(vmmAfterAxisIdxB, ptr[regAux1]); + mov(regAux1, reinterpret_cast(incVec)); + uni_vpaddd(vmmAfterAxisIdxB, vmmAfterAxisIdxB, ptr[regAux1]); + uni_vcvtdq2ps(vmmAfterAxisIdxB, vmmAfterAxisIdxB); + + // afterAxIdxB = (start % afterAxSize) * idxTypeSize + movd(xAux0, reg32Aux1); + uni_vpbroadcastd(vAux1, xAux0); + uni_vcvtdq2ps(vAux1, vAux1); + uni_vdivps(vmmSrcBeforeAxisSumB, vmmAfterAxisIdxB, vAux1); + uni_vroundps(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0x1); + uni_vfnmadd231ps(vmmAfterAxisIdxB, vmmSrcBeforeAxisSumB, vAux1); + uni_vcvtps2dq(vmmAfterAxisIdxB, vmmAfterAxisIdxB); + uni_vpslld(vmmAfterAxisIdxB, vmmAfterAxisIdxB, idxTypeShift); // multiply by indices type size. 
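+            // Same fp32 division trick, here splitting the flat start offset into the after-axis
+            // coordinate: afterAxisIdxB[i] = ((start + i) % afterAxisSize) * idxTypeSize (see the
+            // formula comment above).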
+ + Xbyak::Label lLessThanVector2, lTail3, lTail4, lE2; + + cmp(regAux2, dataElPerVec); + jl(lLessThanVector2, T_NEAR); + uni_vmovd(reg32IdxIter, xmmSpecIdxB); + fillVlenVector(); + +// process(false, true); + jmp(lE2, T_NEAR); + L(lLessThanVector2); + auto& vAux2 = vmmAuxContainer[2]; + // Calculate permute mask + uni_vmovd(xAux0, reg32Aux2); + uni_vpbroadcastd(vAux1, xAux0); + mov(regAux1, reinterpret_cast(&idxElPerVec)); + uni_vpbroadcastd(vAux0, ptr[regAux1]); + uni_vpsubd(vmmAfterAxisPermMask, vAux0, vAux1); + mov(regAux1, reinterpret_cast(incVec)); + uni_vpaddd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, ptr[regAux1]); + for (int i = 0; i < 6; i++) { + if (isa == x64::avx512_common) { + Xbyak::Opmask kMask2 = Xbyak::Opmask(vAux2.getIdx()); + vpcmpgtd(kMask2, vAux0, vmmAfterAxisPermMask); + uni_vpsubd(vmmAfterAxisPermMask | kMask2, vmmAfterAxisPermMask, vAux1); + } else { + vpcmpgtd(vAux2, vAux0, vmmAfterAxisPermMask); + vpandn(vAux2, vAux2, vAux1); + uni_vpsubd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, vAux2); + } + } + + process(true, true); + L(lE2); + } + L(lEnd); + } + + this->postamble(); +} + +template <> +void jitUniGatherKernel::uniVpGatherDd(Vmm& vDst, const Xbyak::Address& srcAddr, Vmask& kMask) { + vpgatherdd(vDst, srcAddr, kMask); +} +template <> +void jitUniGatherKernel::uniVpGatherDd(Vmm& vDst, const Xbyak::Address& srcAddr, Vmask& kMask) { + vpgatherdd(vDst | kMask, srcAddr); +} + +template <> +void jitUniGatherKernel::normalizeRawIndices(Vmm& vRawIndices, Vmask& kDstMask, Vmask& kAuxMask) { + // Compensate negative indices. + if (jcp.reverseIndexing) { + vpcmpgtd(kAuxMask, vmmZeros, vRawIndices); + vpand(kAuxMask, kAuxMask, vmmAxisDim); + uni_vpaddd(vRawIndices, vRawIndices, kAuxMask); + } + // Check boundaries. + vpcmpgtd(kDstMask, vmmAxisDim, vRawIndices); + vpcmpgtd(kAuxMask, vmmZeros, vRawIndices); + vpandn(kDstMask, kAuxMask, kDstMask); + // Multiply by type size. + if (jcp.dataTypeSize > 1) + uni_vpslld(vRawIndices, vRawIndices, dataTypeShift); +} + +template <> +void jitUniGatherKernel::normalizeRawIndices(Vmm& vRawIndices, Vmask& kDstMask, Vmask& kAuxMask) { + // Compensate negative indices. + if (jcp.reverseIndexing) { + vpcmpgtd(kAuxMask, vmmZeros, vRawIndices); + uni_vpaddd(vRawIndices | kAuxMask, vRawIndices, vmmAxisDim); + } + // Check boundaries. + vpcmpgtd(kAuxMask, vmmAxisDim, vRawIndices); + vpcmpd(kDstMask | kAuxMask, vmmZeros, vRawIndices, 2); // 2 - LE + // Multiply by type size. + if (jcp.dataTypeSize > 1) + uni_vpslld(vRawIndices, vRawIndices, dataTypeShift); +} + +template <> +void jitUniGatherKernel::normWithUpperBound(Vmm& vTarget, Vmm& vMax, Vmask& kAuxMask) { + vpcmpgtd(kAuxMask, vMax, vTarget); + vpandn(kAuxMask, kAuxMask, vMax); + uni_vpsubd(vTarget, vTarget, kAuxMask); +} + +template <> +void jitUniGatherKernel::normWithUpperBound(Vmm& vTarget, Vmm& vMax, Vmask& kAuxMask) { + vpcmpd(kAuxMask, vMax, vTarget, 2); // 2 -> LE + uni_vpsubd(vTarget | kAuxMask, vTarget, vMax); +} + +// Requires vAuxPool length 4. +// Returns calculated shifts in vAuxPool[0] and mask in vAuxPool[1]. 
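+// Long path: the index offsets in vmmSpecIdxB advance by the vector length between
+// iterations; while regIdxIter stays below regSpecIdxSizeB the raw indices are read
+// with a contiguous load, and on wrap-around they are fetched with a gather instead.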
+template <> +void jitUniGatherKernel::calcSrcShiftLong(Vmm* vAuxPool, bool shiftFirst) { + auto& vDstShifts = vAuxPool[0]; + auto& kDstMask = masksContainer[vAuxPool[1].getIdx()]; + auto& vAux0 = vAuxPool[2]; + auto& vAux1 = vAuxPool[3]; + auto& kAuxMask0 = masksContainer[vAux1.getIdx()]; + + Xbyak::Label lIdxStride, lExit; + if (shiftFirst) + uni_vpaddd(vmmSpecIdxB, vmmSpecIdxB, vmmVecLenB); + + add(regIdxIter, vlen); + cmp(regIdxIter, regSpecIdxSizeB); + jge(lIdxStride, T_NEAR); + if (jcp.batchDims > 0lu) { + uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + } else { + uni_vmovd(reg32Aux1, xmmSpecIdxB); + } + vmovdqu(vDstShifts, ptr[regIndices + regAux1]); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + jmp(lExit, T_NEAR); + L(lIdxStride); + sub(regIdxIter, regSpecIdxSizeB); + vpcmpeqd(kDstMask, vAux0, vAux0); + if (shiftFirst) { + vpcmpgtd(vAux0, vmmSpecIdxSizeB, vmmSpecIdxB); + vpandn(vAux1, vAux0, vmmSpecIdxSizeB); + uni_vpsubd(vAux1, vmmSpecIdxB, vAux1); + if (jcp.batchDims > 0lu) + uni_vpaddd(vAux1, vmmIdxBatchSumB, vAux1); + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxSizeB); + } else { + if (jcp.batchDims > 0lu) { + uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); + } else { + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); + } + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + + uni_vpbroadcastd(vAux0, xmmSpecIdxB); + vpcmpgtd(vAux1, vAux0, vmmSpecIdxB); + vpandn(vAux0, vAux1, vmmSpecIdxSizeB); + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vAux0); + + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + vpandn(vAux0, vAux1, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vAux0); + } + } + + if (jcp.batchDims > 0lu) { + Xbyak::Label l1; + inc(regBetweenBatchAndAxisIter); + cmp(regBetweenBatchAndAxisIter, regBetweenBatchAndAxisSize); + jl(l1, T_NEAR); + mov(regBetweenBatchAndAxisIter, 0); + if (shiftFirst) { + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); + vpandn(vDstShifts, vAux0, vmmSpecIdxSizeB); + uni_vpaddd(vAux1, vAux1, vDstShifts); + } else { + vpandn(vAux0, vAux1, vmmSpecIdxSizeB); + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vAux0); + } + L(l1); + } + + if (shiftFirst) { + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + + if (jcp.beforeAxisSize != 1lu) { + vpandn(vAux0, vAux0, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vAux0, vAux0, vmmSrcBeforeAxisSumB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + + uni_vpaddd(vDstShifts, vDstShifts, vAux0); + } + } + L(lExit); +} + +// Requires vAuxPool length 4. +// Returns calculated shifts in vAuxPool[0] and mask in vAuxPool[1]. 
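+// AVX512 variant of the above: the wrap and bounds conditions are kept in opmask
+// registers and applied through masked vpaddd/vpsubd instead of the vpandn emulation.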
+template <> +void jitUniGatherKernel::calcSrcShiftLong(Vmm* vAuxPool, bool shiftFirst) { + auto& vDstShifts = vAuxPool[0]; + auto& kDstMask = masksContainer[vAuxPool[1].getIdx()]; + auto& vAux0 = vAuxPool[2]; + auto& vAux1 = vAuxPool[3]; + auto& kAuxMask0 = masksContainer[vAux1.getIdx()]; + auto& kAuxMask1 = masksContainer[vAux1.getIdx() + 1]; + + Xbyak::Label lIdxStride, lExit; + if (shiftFirst) + uni_vpaddd(vmmSpecIdxB, vmmSpecIdxB, vmmVecLenB); + + add(regIdxIter, vlen); + cmp(regIdxIter, regSpecIdxSizeB); + jge(lIdxStride, T_NEAR); + if (jcp.batchDims > 0lu) { + uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + } else { + uni_vmovd(reg32Aux1, xmmSpecIdxB); + } + vmovdqu64(vDstShifts, ptr[regIndices + regAux1]); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + jmp(lExit, T_NEAR); + L(lIdxStride); + sub(regIdxIter, regSpecIdxSizeB); + vpcmpeqd(kDstMask, vDstShifts, vDstShifts); + if (shiftFirst) { + vpcmpd(kAuxMask1, vmmSpecIdxSizeB, vmmSpecIdxB, 2); // 2 -> LE + if (jcp.batchDims > 0lu) { + uni_vpaddd(vAux1, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vpsubd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); + } else { + uni_vmovups(vAux1, vmmSpecIdxB); + uni_vpsubd(vAux1 | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); + } + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxSizeB); + } else { + if (jcp.batchDims > 0lu) { + uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); + } else { + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); + } + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + + uni_vpbroadcastd(vAux0, xmmSpecIdxB); + vpcmpd(kAuxMask1, vAux0, vmmSpecIdxB, 2); // 2 -> LE + uni_vpsubd(vmmSpecIdxB | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); + + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + uni_vpaddd(vmmSrcBeforeAxisSumB | kAuxMask1, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + } + } + + if (jcp.batchDims > 0lu) { + Xbyak::Label l1; + inc(regBetweenBatchAndAxisIter); + cmp(regBetweenBatchAndAxisIter, regBetweenBatchAndAxisSize); + jl(l1, T_NEAR); + mov(regBetweenBatchAndAxisIter, 0); + if (shiftFirst) { + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); + uni_vpaddd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); + } else { + uni_vpaddd(vmmIdxBatchSumB | kAuxMask1, vmmIdxBatchSumB, vmmSpecIdxSizeB); + } + L(l1); + } + + if (shiftFirst) { + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + uni_vpaddd(vDstShifts | kAuxMask1, vDstShifts, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + } + } + L(lExit); +} + +template +void jitUniGatherKernel::calcSrcShiftLongBlock(Vmm* vAuxPool, bool shiftFirst) { + // Most likely there will no significant performance gain vs memcpy in reference implementation on big blocks after axis, + // therefore no time was invested to this case yet. + IE_THROW() << "Unsupported case."; +} + +// Requires vAuxPool length 3. +// Returns calculated shifts in vAuxPool[0] and mask in vAuxPool[1]. 
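+// Short path: all specIdxSize index offsets stay resident in vmmSpecIdxB and are
+// rotated with vpermd between iterations, so only the gather of the raw indices
+// themselves touches memory.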
+template +void jitUniGatherKernel::calcSrcShiftShort(Vmm* vAuxPool, bool shiftFirst) { + auto& vDstShifts = vAuxPool[0]; + auto& kDstMask = masksContainer[vAuxPool[1].getIdx()]; + auto& vAux0 = vAuxPool[2]; + + if (shiftFirst) { + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmBeforeAxDiffB); + // No sense to permute if specIdxSize is one of {1, 2, 4, 8}. + if (jcp.specIdxSize != 1 && jcp.specIdxSize != 2 && jcp.specIdxSize != 4 && jcp.specIdxSize != 8) { + vpermd(vmmSpecIdxB, vmmPermIdxMask, vmmSpecIdxB); + if (jcp.beforeAxisSize != 1lu) + vpermd(vmmBeforeAxDiffB, vmmPermIdxMask, vmmBeforeAxDiffB); + } + } + + vpcmpeqd(kDstMask, vAux0, vAux0); + if (jcp.batchDims > 0lu) { + // Calculate indices batch sum. + uni_vcvtdq2ps(vAux0, vmmSrcBeforeAxisSumB); + uni_vcvtdq2ps(vDstShifts, vmmSrcAfterBatchSizeB); + uni_vdivps(vAux0, vAux0, vDstShifts); + uni_vroundps(vAux0, vAux0, 0x1); + uni_vcvtps2dq(vAux0, vAux0); + + uni_vpmulld(vAux0, vAux0, vmmSpecIdxSizeB); + uni_vpaddd(vAux0, vAux0, vmmSpecIdxB); + + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); + } else { + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); + } + + auto& kAuxMask0 = masksContainer[vAux0.getIdx()]; + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); +} + +// Requires vAuxPool length 4. +// Returns calculated shifts in vAuxPool[0] and mask in vAuxPool[1]. +template +void jitUniGatherKernel::calcSrcShiftShortBlock(Vmm* vAuxPool, bool shiftFirst) { + auto& vDstShifts = vAuxPool[0]; + auto& kDstMask = masksContainer[vAuxPool[1].getIdx()]; + auto& vAux0 = vAuxPool[2]; + auto& vAux1 = vAuxPool[3]; + auto& kAuxMask0 = masksContainer[vAux0.getIdx()]; + const uint64_t specIdxAndAfterAxisSize = jcp.specIdxSize * jcp.afterAxisSize; + + if (shiftFirst) { + if (jcp.specIdxSize != 1) { + uni_vpaddd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxDiff); + normWithUpperBound(vmmSpecIdxB, vmmSpecIdxSizeB, kAuxMask0); + } + // No sense to permute if afterAxisSize is one of {1, 2, 4, 8}. 
+ if (jcp.afterAxisSize != 1 && jcp.afterAxisSize != 2 && jcp.afterAxisSize != 4 && jcp.afterAxisSize % 8 != 0) { + vpermd(vmmAfterAxisIdxB, vmmAfterAxisPermMask, vmmAfterAxisIdxB); + if (jcp.specIdxSize != 1) + vpermd(vmmSpecIdxDiff, vmmAfterAxisPermMask, vmmSpecIdxDiff); + } + + if (jcp.beforeAxisSize != 1lu) { + if (!jcp.dynamicShapes) { + if (specIdxAndAfterAxisSize > 0lu && specIdxAndAfterAxisSize <= idxElPerVec) { + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmBeforeAxDiffB); + uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); + if (specIdxAndAfterAxisSize != 1 && specIdxAndAfterAxisSize != 2 && specIdxAndAfterAxisSize != 4 && + specIdxAndAfterAxisSize % 8 != 0) + vpermd(vmmBeforeAxDiffB, vmmBeforeAxPermMask, vmmBeforeAxDiffB); + } else { + Xbyak::Label lBeforeAxStep, lBeforeAxStepEnd; + add(rSpecIdxAndAfterAxIterB, idxElPerVec * jcp.dataTypeSize); + cmp(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + jl(lBeforeAxStep, T_NEAR); + sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + + vpmulld(vAux0, vmmSpecIdxB, vmmAfterAxisSize); + uni_vpaddd(vAux0, vAux0, vmmAfterAxisIdxB); + Xbyak::Xmm& xAux0 = xmmAuxContainer[vAux0.getIdx()]; + uni_vpbroadcastd(vAux1, xAux0); + if (isa == x64::avx512_common) { + Xbyak::Opmask kMask0 = Xbyak::Opmask(kAuxMask0.getIdx()); + vpcmpgtd(kMask0, vAux1, vAux0); + uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); + uni_vpaddd(vAux1 | kMask0, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + } else { + vpcmpgtd(vAux1, vAux1, vAux0); + vpand(vAux1, vAux1, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vAux1, vmmSrcBeforeAxisSumB, vAux1); + } + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + jmp(lBeforeAxStepEnd); + L(lBeforeAxStep); + uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); + L(lBeforeAxStepEnd); + } + } else { + } + } + } else { + if (jcp.beforeAxisSize != 1lu) { + uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); + if (specIdxAndAfterAxisSize > idxElPerVec) { + // Broadcast the last element. + if (isa == x64::avx512_common) { + vshuff64x2(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0xFF); + } else { + vpermq(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0xFF); + } + vpshufd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0xFF); + + Xbyak::Label lBeforeAxStepEnd1; + add(rSpecIdxAndAfterAxIterB, idxElPerVec * jcp.dataTypeSize); + cmp(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + jl(lBeforeAxStepEnd1, T_NEAR); + sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + cmp(rSpecIdxAndAfterAxIterB, 0); + jne(lBeforeAxStepEnd1, T_NEAR); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + L(lBeforeAxStepEnd1); + } + } + } + + vpcmpeqd(kDstMask, vAux0, vAux0); + if (jcp.batchDims > 0lu) { + // Calculate indices batch sum. 
+ uni_vcvtdq2ps(vAux0, vAux1); + uni_vcvtdq2ps(vDstShifts, vmmSrcAfterBatchSizeB); + uni_vdivps(vAux0, vAux0, vDstShifts); + uni_vroundps(vAux0, vAux0, 0x1); + uni_vcvtps2dq(vAux0, vAux0); + + uni_vpmulld(vAux0, vAux0, vmmSpecIdxSizeB); + uni_vpaddd(vAux0, vAux0, vmmSpecIdxB); + + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); + } else { + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); + } + + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + + if (jcp.afterAxisSize != 1lu) { + vpmulld(vDstShifts, vDstShifts, vmmAfterAxisSize); + uni_vpaddd(vDstShifts, vDstShifts, vmmAfterAxisIdxB); + } + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vAux1); +} + +template +void jitUniGatherKernel::process(bool isShortIdx, bool blocked) { + Xbyak::Label lTailProc, lEndProc; + cmp(regWorkAmount, dataElPerVec); + jl(lTailProc, T_NEAR); + if (jcp.dataTypeSize == 4) + process32b(isShortIdx, blocked); + else if (jcp.dataTypeSize == 2) + process16b(isShortIdx, blocked); + else if (jcp.dataTypeSize == 1) + process8b(isShortIdx, blocked); + jmp(lEndProc, T_NEAR); + L(lTailProc); + tail(isShortIdx, false, blocked); + L(lEndProc); +} + +template +void jitUniGatherKernel::process32b(bool isShortIdx, bool blocked) { + Xbyak::Label lDstIdxLoop, lTail; + + // First iteration + shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked); + uni_vmovups(ptr[regDst], vmmAuxContainer[2]); + + // Main loop + L(lDstIdxLoop); + { + add(regDst, vlen); + sub(regWorkAmount, dataElPerVec); + cmp(regWorkAmount, dataElPerVec); + jl(lTail, T_NEAR); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + uni_vmovups(ptr[regDst], vmmAuxContainer[2]); + + jmp(lDstIdxLoop, T_NEAR); + } + + L(lTail); + tail(isShortIdx, true, blocked); +} + +template +void jitUniGatherKernel::process16b(bool isShortIdx, bool blocked) { + Xbyak::Label lDstIdxLoop1, lTail; + + Vmm vShufMask, vPermMask, vBuff0; + if (isa == x64::avx512_common) { + vPermMask = vmmAuxContainer[7]; + vShufMask = vmmAuxContainer[8]; + vBuff0 = vmmAuxContainer[9]; + } else { + vPermMask = vmmAuxContainer[1]; + vShufMask = vmmAuxContainer[4]; + vBuff0 = vmmAuxContainer[5]; + } + + mov(regAux1, reinterpret_cast(shufMask16bitUni)); + uni_vmovups(vShufMask, ptr[regAux1]); + mov(regAux1, reinterpret_cast(permMask16bitUni)); + uni_vmovups(vPermMask, ptr[regAux1]); + + // First iteration + shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked); + vpshufb(vBuff0, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44); + vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]); + + uni_vmovups(ptr[regDst], vmmAuxContainer[0]); + + // Main loop. 
+ L(lDstIdxLoop1); + { + add(regDst, vlen); + sub(regWorkAmount, dataElPerVec); + cmp(regWorkAmount, dataElPerVec); + jl(lTail, T_NEAR); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vBuff0, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44); + vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]); + + uni_vmovups(ptr[regDst], vmmAuxContainer[0]); + + jmp(lDstIdxLoop1, T_NEAR); + } + + L(lTail); + tail(isShortIdx, true, blocked); +} + +template +void jitUniGatherKernel::process8b(bool isShortIdx, bool blocked) { + Xbyak::Label lDstIdxLoop1, lTail; + + Vmm vShufMask, vPermMask, vBuff0, vBuff1; + if (isa == x64::avx512_common) { + vPermMask = vmmAuxContainer[7]; + vShufMask = vmmAuxContainer[8]; + vBuff0 = vmmAuxContainer[9]; + vBuff1 = vmmAuxContainer[10]; + } else { + vPermMask = vmmAuxContainer[1]; + vShufMask = vmmAuxContainer[4]; + vBuff0 = vmmAuxContainer[5]; + vBuff1 = vmmAuxContainer[6]; + } + mov(regAux1, reinterpret_cast(shufMask8bitUni)); + uni_vmovups(vShufMask, ptr[regAux1]); + + // First iteration + shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked); + vpshufb(vBuff0, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vBuff0, vBuff0, vmmAuxContainer[0], 0x0); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vBuff1, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vBuff1, vBuff1, vmmAuxContainer[0], 0x0); + vshufps(vmmAuxContainer[0], vBuff0, vBuff1, 0x88); + + mov(regAux1, reinterpret_cast(permMask8bitUni)); + uni_vmovups(vPermMask, ptr[regAux1]); + + vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]); + + uni_vmovups(ptr[regDst], vmmAuxContainer[0]); + + // Main loop. + L(lDstIdxLoop1); + { + add(regDst, vlen); + sub(regWorkAmount, dataElPerVec); + cmp(regWorkAmount, dataElPerVec); + jl(lTail, T_NEAR); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vBuff0, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vBuff0, vBuff0, vmmAuxContainer[0], 0x0); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vBuff1, vmmAuxContainer[2], vShufMask); + + shiftIdxAndGather(vmmAuxContainer, isShortIdx, true, blocked); + vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask); + + vshufps(vmmAuxContainer[0], vBuff1, vmmAuxContainer[0], 0x0); + vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x88); + + if (isa == x64::avx2) { + // Register vPermMask is invalidated by shiftIdxAndGather and must be initialized again. + mov(regAux1, reinterpret_cast(permMask8bitUni)); + uni_vmovups(vPermMask, ptr[regAux1]); + } + vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]); + + uni_vmovups(ptr[regDst], vmmAuxContainer[0]); + + jmp(lDstIdxLoop1, T_NEAR); + } + + L(lTail); + tail(isShortIdx, true, blocked); +} + +// Requires vAuxPool length 4. +// Returns gathered data in vAuxPool[2]. 
+template +void jitUniGatherKernel::shiftIdxAndGather(Vmm* vAuxPool, bool isShortIdx, bool shiftFirst, bool blocked) { + if (blocked) { + if (isShortIdx) { + calcSrcShiftShortBlock(vAuxPool, shiftFirst); + } else { + calcSrcShiftLongBlock(vAuxPool, shiftFirst); + } + } else { + if (isShortIdx) { + calcSrcShiftShort(vAuxPool, shiftFirst); + } else { + calcSrcShiftLong(vAuxPool, shiftFirst); + } + } + auto& kGatherMask = masksContainer[vAuxPool[1].getIdx()]; + uni_vmovups(vAuxPool[2], vmmZeros); + uniVpGatherDd(vAuxPool[2], ptr[regSrc + vAuxPool[0]], kGatherMask); +} + +template +void jitUniGatherKernel::tail(bool isShortIdx, bool shiftFirst, bool blocked) { + auto& vSrcShift = vmmAuxContainer[0]; + auto& kGatherMask = masksContainer[vmmAuxContainer[1].getIdx()]; + auto& vAux0 = vmmAuxContainer[2]; + auto& vAux1 = vmmAuxContainer[3]; + auto& kAuxMask1 = masksContainer[vAux1.getIdx()]; + Xbyak::Label lEnd; + + const int secondStepCycles = 4 / jcp.dataTypeSize; + for (int p = 0; p < secondStepCycles; p++) { + cmp(regWorkAmount, 0); + jle(lEnd, T_NEAR); + + if (isShortIdx) { + if (blocked) { + calcSrcShiftShortBlock(vmmAuxContainer, p > 0 || shiftFirst); + } else { + calcSrcShiftShort(vmmAuxContainer, p > 0 || shiftFirst); + } + } else { + if (blocked) { + calcSrcShiftLongBlock(vmmAuxContainer, p > 0 || shiftFirst); + } else { + calcSrcShiftLong(vmmAuxContainer, p > 0 || shiftFirst); + } + } + + fillRestWorkMask(kAuxMask1, vAux0, regWorkAmount, regAux1, rdx); + + // Combining masks. + if (isa == x64::avx512_common) { + auto kMask1 = Xbyak::Opmask(kAuxMask1.getIdx()); + auto kMaskG = Xbyak::Opmask(kGatherMask.getIdx()); + kandd(kMaskG, kMaskG, kMask1); + } else if (isa == x64::avx2) { + auto& vGatherMask = vmmAuxContainer[kGatherMask.getIdx()]; + vpand(vGatherMask, vGatherMask, vAux1); + } + + uni_vmovups(vAux0, vmmZeros); + uniVpGatherDd(vAux0, ptr[regSrc + vSrcShift], kGatherMask); + if (jcp.dataTypeSize == 4) { + uni_vmovups_tail(ptr[regDst], kAuxMask1, vAux0); + sub(regWorkAmount, dataElPerVec); + } else { + storeVectorPart(regDst, regWorkAmount, vAux0, vAux1); + } + } + L(lEnd); +} + +template <> +void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, Vmm& vmmAux, const Xbyak::Reg64& rWorkRest, + const Xbyak::Reg64& rAux0, const Xbyak::Reg64& rAux1) { + Xbyak::Label lKmov; + Xbyak::Reg32 rOnes(rAux1.getIdx()); + mov(rOnes, 0x0000FFFF); + cmp(rWorkRest, idxElPerVec); + jge(lKmov); + Xbyak::Reg8 rShift(Xbyak::Operand::CL); + mov(rShift, idxElPerVec); + sub(rShift, rWorkRest); + shr(rOnes, rShift); + L(lKmov); + kmovw(kDstMask, rOnes); +} + +template <> +void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, Vmm& vAux, const Xbyak::Reg64& rWorkRest, + const Xbyak::Reg64& rAux0, const Xbyak::Reg64& rAux1) { + Xbyak::Label lEnd; + mov(rAux0, rWorkRest); + Xbyak::Reg32 rOnes(rAux1.getIdx()); + mov(rOnes, 0xFFFFFFFF); + Xbyak::Xmm xmmAux(vAux.getIdx()); + uni_vmovups(kDstMask, vmmZeros); + for (uint8_t i = 0; i < idxElPerVec; i++) { + cmp(rAux0, 0); + je(lEnd, T_NEAR); + + if (i % 4 == 0) + uni_vmovups(xmmAux, xmmZeros); + + vpinsrd(xmmAux, xmmAux, rOnes, i % 4); + vinserti128(kDstMask, kDstMask, xmmAux, i / 4); + sub(rAux0, 1); + } + L(lEnd); +} + +template +void jitUniGatherKernel::storeVectorPart(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rToStoreCounter, Vmm& vmmSrc, Vmm& vAux) { + Xbyak::Label lEnd; + Xbyak::Xmm xAux(vAux.getIdx()); + for (int j = 0; j < vlen / vlenXmm; j++) { + if (isa == x64::avx2) + vextracti128(xAux, vmmSrc, j); + else if (isa == x64::avx512_common) + 
vextracti64x2(xAux, vmmSrc, j); + + for (int k = 0; k < 4; k++) { + cmp(rToStoreCounter, 0); + jle(lEnd, T_NEAR); + + if (jcp.dataTypeSize == 4) + uni_vpextrd(ptr[rDst], xAux, k); + else if (jcp.dataTypeSize == 2) + uni_vpextrw(ptr[rDst], xAux, k * 2); + else if (jcp.dataTypeSize == 1) + uni_vpextrb(ptr[rDst], xAux, k * 4); + + add(rDst, jcp.dataTypeSize); + sub(rToStoreCounter, 1); + } + } + L(lEnd); +} + +template <> +void jitUniGatherKernel::fillVlenVector() { + mov(reg32Aux1, vlen); + vpbroadcastd(vmmVecLenB, reg32Aux1); +} +template <> +void jitUniGatherKernel::fillVlenVector() { + vpcmpeqd(vmmVecLenB, vmmVecLenB, vmmVecLenB); + vpsrld(vmmVecLenB, vmmVecLenB, 31); // Right shift to 1. + uni_vpslld(vmmVecLenB, vmmVecLenB, 5); // Left shift to 32. +} + +template +bool jitUniGatherKernel::isSupportedConfiguration(uint64_t afterAxisSize) { + if (!jcp.dynamicShapes && afterAxisSize <= idxElPerVec) { + if (afterAxisSize > 1 && isa == x64::avx2 && (jcp.dataTypeSize == 1 || jcp.dataTypeSize == 2)) + // There are no enough registers for these cases. + return false; + + return true; + } + if (jcp.dynamicShapes && afterAxisSize == 1) { + return true; + } + return false; +} + +template struct jitUniGatherKernel; +template struct jitUniGatherKernel; + +} // namespace MKLDNNPlugin diff --git a/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.hpp new file mode 100644 index 00000000000..5cb637906c9 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/gather_uni_kernel.hpp @@ -0,0 +1,209 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +// Gather kernel implements two approaches for indices calculation: "Short" and "Long". +// 1. Short approach is applicable for cases when the number of elements less or equal to vector register length. +// It just uses permutation of current indices vector to retrieve the next. +// 2. Long approach is applicable for cases when the number of elements is greater than vector register length. +// It increases indices in vector on vector length and normalizes upper bound of indices. +// +// SUPPORTED CASES +//-------------------------------------------------------------- +// After axis | AVX512 | AVX2 | +// (block) size| 32bit | 16bit | 8bit | 32bit | 16bit | 8bit | +// STATIC SHAPES +// 1 | X | X | X | X | X | X | +// >1 & <=vlen | X | X | X | X | | | +// DYNAMIC SHAPES +// 1 | X | X | X | X | X | X | +//-------------------------------------------------------------- + + +#pragma once + +#include "cpu/x64/jit_generator.hpp" +#include + +namespace MKLDNNPlugin { + +struct jGatherConfParams { + uint64_t dataTypeSize = 1lu; + bool reverseIndexing = true; + bool dynamicShapes = false; + uint64_t batchDims = 0lu; + uint64_t beforeAxisSize = 0lu; + uint64_t specIdxSize = 0lu; + uint64_t afterAxisSize = 0lu; +}; + +struct gatherJitExecArgs { + const void* src; + const void* indices; + void* dst; + const int* axisDim; + const uint64_t* start; + const uint64_t* specIndicesSize; + const uint64_t* betweenBatchAndAxisSize; + const uint64_t* axisAndAfterAxisSizeB; + const uint64_t* srcAfterBatchSizeB; + const int* permIdxMask; + const int* beforeAxisDiff; + + const int* beforeAxisPermMask; + const int* afterAxIdxB; + const int* afterAxisPermMask; + const uint64_t* afterAxisSize; + const int* specIdxDiff; + + uint64_t workAmount = 0lu; + uint64_t afterAxSize = 1lu; + // Blocked short. 
+ uint64_t specIdxAndAfterAxIterB; + uint64_t specIdxAndAfterAxSizeB; + // Only static + const int* specIdxB; + const int* idxBatchSumB; + const int* dataBeforeAxisSumB; + uint64_t betweenBatchAndAxisIter; +}; + +struct jitGatherKernelBase { + void (*ker_)(const gatherJitExecArgs *); + void operator()(const gatherJitExecArgs *args) { + assert(ker_); + ker_(args); + } + explicit jitGatherKernelBase(const jGatherConfParams& jcp) : ker_(nullptr), jcp(jcp) {} + virtual ~jitGatherKernelBase() {} + + virtual void create_ker() = 0; + uint64_t getVecLen() { + return vlen; + } + uint64_t getDataElPerVec() { + return dataElPerVec; + } + uint64_t getIdxElPerVec() { + return idxElPerVec; + } + virtual bool isSupportedConfiguration(uint64_t afterAxisSize) = 0; + +protected: + jGatherConfParams jcp; + uint64_t vlen; + uint64_t dataElPerVec; + uint64_t idxElPerVec; + static const unsigned shufMask8bitUni[16]; + static const unsigned permMask8bitA2[8]; + static const unsigned permMask8bitA5[16]; + static const unsigned shufMask16bitUni[16]; + static const unsigned permMask16bitA2[8]; + static const unsigned permMask16bitA5[16]; + static const unsigned incVec[16]; + + int shortPermIdx[16]; + int shortBeforeAxisDiff[16]; +}; + +template +struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu::x64::jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jitUniGatherKernel) + + explicit jitUniGatherKernel(const jGatherConfParams& jcp); + + void create_ker() override; + void generate() override; + + bool isSupportedConfiguration(uint64_t afterAxisSize) override; + +protected: + using Vmm = typename dnnl::impl::utils::conditional::type; + using Vmask = typename dnnl::impl::utils::conditional::type; + static const uint32_t vlenXmm = dnnl::impl::cpu::x64::cpu_isa_traits::vlen; + static const uint32_t indicesTypeSize = sizeof(uint32_t); + static const uint8_t idxTypeShift = 2; + uint8_t dataTypeShift = 0; + + // Suffix B means "In Bytes". + // 64b registers. + const Xbyak::Reg64& regSrc = r8; + const Xbyak::Reg64& regDst = r9; + const Xbyak::Reg64& regIndices = r10; + const Xbyak::Reg64& regIdxIter = r11; + const Xbyak::Reg64& regWorkAmount = r12; + const Xbyak::Reg64& regSpecIdxSizeB = r13; + const Xbyak::Reg64& regAux1 = r14; + const Xbyak::Reg64& regAux2 = rsi; + const Xbyak::Reg64& regBetweenBatchAndAxisIter = r15; + const Xbyak::Reg64& regBetweenBatchAndAxisSize = rbx; + const Xbyak::Reg64& rSpecIdxAndAfterAxIterB = regIdxIter; + const Xbyak::Reg64& rSpecIdxAndAfterAxSizeB = regSpecIdxSizeB; + + const Xbyak::Reg64& regParams = dnnl::impl::cpu::x64::abi_param1; + + // 32b registers. + Xbyak::Reg32 reg32IdxIter = Xbyak::Reg32(regIdxIter.getIdx()); + Xbyak::Reg32 reg32SpecIdxSizeB = Xbyak::Reg32(regSpecIdxSizeB.getIdx()); + Xbyak::Reg32 reg32BetweenBatchAndAxisSize = Xbyak::Reg32(regBetweenBatchAndAxisSize.getIdx()); + Xbyak::Reg32 reg32BetweenBatchAndAxisIter = Xbyak::Reg32(regBetweenBatchAndAxisIter.getIdx()); + Xbyak::Reg32 reg32Aux1 = Xbyak::Reg32(regAux1.getIdx()); + Xbyak::Reg32 reg32Aux2 = Xbyak::Reg32(regAux2.getIdx()); + + // Masks pool. Do not use k0 with gather instruction! + Vmask masksContainer[8] = {Vmask(0), Vmask(1), Vmask(2), Vmask(3), Vmask(4), Vmask(5), Vmask(6), Vmask(7)}; + // Auxiliary pool. + Vmm vmmAuxContainer[12] = {Vmm(0), Vmm(1), Vmm(2), Vmm(3), Vmm(4), Vmm(5), Vmm(6), /*AVX5*/ Vmm(16), Vmm(17), Vmm(18), Vmm(19), Vmm(20)}; + // Common. 
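+    // State registers vmm7..vmm15 below stay live across the whole kernel; the aux pool
+    // above (vmm0..vmm6, plus vmm16..vmm20 on AVX512) is free scratch.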
+ Vmm vmmZeros = Vmm(7); + Vmm vmmSrcBeforeAxisSumB = Vmm(8); + Vmm vmmSpecIdxB = Vmm(9); + Vmm vmmSpecIdxSizeB = Vmm(10); + Vmm vmmAxisDim = Vmm(11); + Vmm vmmAxisAndAfterAxisSizeB = Vmm(12); + + // Only short. + Vmm vmmSrcAfterBatchSizeB = Vmm(13); + Vmm vmmPermIdxMask = Vmm(14); + Vmm& vmmBeforeAxDiffB = vmmAxisAndAfterAxisSizeB; + // Blocked short. + Vmm& vmmSpecIdxDiff = vmmAuxContainer[4]; + Vmm& vmmAfterAxisSize = vmmAuxContainer[5]; + Vmm vmmAfterAxisIdxB = Vmm(15); + Vmm& vmmAfterAxisPermMask = vmmPermIdxMask; + Vmm& vmmBeforeAxPermMask = vmmAuxContainer[6]; + // Only long. + Vmm vmmVecLenB = Vmm(13); + Vmm vmmIdxBatchSumB = Vmm(14); + + // XMM + Xbyak::Xmm xmmAuxContainer[6] = {Xbyak::Xmm(0), Xbyak::Xmm(1), Xbyak::Xmm(2), Xbyak::Xmm(3), Xbyak::Xmm(4), Xbyak::Xmm(16)}; + Xbyak::Xmm xmmZeros = Xbyak::Xmm(vmmZeros.getIdx()); + Xbyak::Xmm xmmSrcBeforeAxisSum = Xbyak::Xmm(vmmSrcBeforeAxisSumB.getIdx()); + Xbyak::Xmm xmmSpecIdxSizeB = Xbyak::Xmm(vmmSpecIdxSizeB.getIdx()); + Xbyak::Xmm xmmSpecIdxB = Xbyak::Xmm(vmmSpecIdxB.getIdx()); + + + void calcSrcShiftLong(Vmm* vAuxPool, bool shiftFirst = true); + void calcSrcShiftLongBlock(Vmm* vAuxPool, bool shiftFirst = true); + void calcSrcShiftShort(Vmm* vAuxPool, bool shiftFirst = true); + void calcSrcShiftShortBlock(Vmm* vAuxPool, bool shiftFirst); + void process(bool isShortIdx, bool blocked); + void process32b(bool isShortIdx, bool blocked); + void process16b(bool isShortIdx, bool blocked); + void process8b(bool isShortIdx, bool blocked); + void shiftIdxAndGather(Vmm* vAuxPool, bool isShortIdx, bool shiftFirst, bool blocked); + void tail(bool isShortIdx, bool shiftFirst = true, bool blocked = false); + // Aux functions. + void normalizeRawIndices(Vmm& rawIndices, Vmask& dstMask, Vmask& aux); + void normWithUpperBound(Vmm& vTarget, Vmm& vMax, Vmask& kAuxMask); + void fillRestWorkMask(Vmask& kMask, Vmm& vAux, const Xbyak::Reg64& rWorkRest, const Xbyak::Reg64& rAux0, const Xbyak::Reg64& rAux1); + void storeVectorPart(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rToStoreCounter, Vmm& vmmSrc, Vmm& vAux); + void uniVpGatherDd(Vmm& vDst, const Xbyak::Address& srcAddr, Vmask& vMask); + void fillVlenVector(); + + const unsigned* permMask8bitUni; + const unsigned* permMask16bitUni; +}; + +} // namespace MKLDNNPlugin diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.cpp index d7b783fef8a..aa812259832 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.cpp @@ -9,21 +9,26 @@ #include "mkldnn_gather_node.h" #include #include "common/cpu_memcpy.h" +#include +#include "kernels/gather_uni_kernel.hpp" using namespace MKLDNNPlugin; using namespace InferenceEngine; +using namespace mkldnn::impl::cpu; + +#define THROW_ERROR IE_THROW() << getTypeStr() << " node with name '" << getName() << "' " bool MKLDNNGatherNode::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { if (!one_of(op->get_type_info(), - ov::op::v7::Gather::get_type_info_static())) { - errorMessage = "Not supported Gather operation version. CPU plug-in supports only 7 version."; + ov::op::v7::Gather::get_type_info_static(), + ov::op::v8::Gather::get_type_info_static())) { + errorMessage = "Not supported Gather operation version. 
CPU plug-in supports only 7 and 8 versions."; return false; } - if (op->get_input_node_shared_ptr(GATHER_AXIS)->get_type_info() != ov::op::v0::Constant::get_type_info_static()) { - // TODO: Support parameterized Axis input for dynamic shapes. - errorMessage = "Only Constant operation on 'axis' input is supported."; + if (!isDynamicNgraphNode(op) && !ov::is_type(op->get_input_node_ptr(GATHER_AXIS))) { + errorMessage = "Only Constant operation on 'axis' input is supported for static node."; return false; } } catch (...) { @@ -34,79 +39,163 @@ bool MKLDNNGatherNode::isSupportedOperation(const std::shared_ptr& op, const mkldnn::engine& eng, - MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) { + MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache), batchDims(0) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { IE_THROW(NotImplemented) << errorMessage; } - errorPrefix = std::string("Layer Gather with name '") + op->get_friendly_name() + "' "; if (op->get_input_size() != 3 || op->get_output_size() != 1) - IE_THROW() << errorPrefix << "has incorrect number of input/output edges!"; + THROW_ERROR << "has incorrect number of input/output edges!"; - dataSrcRank = inputShapes[GATHER_DATA].getRank(); - const auto idxRank = inputShapes[GATHER_INDEXES].getRank(); - if (dataSrcRank == 0 || idxRank == 0) - IE_THROW() << errorPrefix << "has incorrect input parameters ranks."; + const auto& dataShape = getInputShapeAtPort(GATHER_DATA); + isDataShapeStat = dataShape.isStatic(); + dataSrcRank = dataShape.getRank(); + + const auto& idxShape = getInputShapeAtPort(GATHER_INDICES); + isIdxShapeStat = idxShape.isStatic(); + const auto indicesRank = idxShape.getRank(); + if (dataSrcRank == 0lu || indicesRank == 0lu) + THROW_ERROR << "has incorrect input parameters ranks."; + + if (ov::is_type(op)) { + batchDims = static_cast(ov::as_type_ptr(op)->get_batch_dims()); + // WA for NMS->Gather construction. NMS fills part of the output blob by the -1 if these values + // must not be taken into account. There is appropriate pass that looks for such subgraphs + // and sets the dontReverseIndices flag. 
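+        // (The flag is written into the Gather's rt_info by the NMS->Gather path
+        // transformation updated in convert_nms_gather_path_to_unsigned.cpp.)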
+ const auto& rti = op->get_rt_info(); + const auto& reverse = rti.find("dontReverseIndices"); + if (reverse == rti.end()) + reverseIndexing = true; + else + reverseIndexing = false; + } else if (ov::is_type(op)) { + batchDims = static_cast(ov::as_type_ptr(op)->get_batch_dims()); + reverseIndexing = false; + } - batchDims = static_cast(ov::as_type_ptr(op)->get_batch_dims()); if (batchDims < 0) - batchDims += idxRank; - if (batchDims < 0 || batchDims >= std::min(static_cast(dataSrcRank), static_cast(idxRank))) - IE_THROW() << errorPrefix << "has incorrect batch_dims " << batchDims << "!"; + batchDims += indicesRank; + if (batchDims < 0 || batchDims >= std::min(static_cast(dataSrcRank), static_cast(indicesRank))) + THROW_ERROR << "has incorrect batch_dims " << batchDims << "!"; - if (op->get_input_node_shared_ptr(GATHER_AXIS)->get_type_info() == ov::op::v0::Constant::get_type_info_static()) { + if (ov::is_type(op->get_input_node_ptr(GATHER_AXIS))) { isAxisInputConst = true; axis = ov::as_type(op->get_input_node_ptr(GATHER_AXIS))->cast_vector()[0]; if (axis < 0) axis += dataSrcRank; if (axis < 0 || axis >= dataSrcRank || batchDims > axis) - IE_THROW() << errorPrefix << "has incorrect input parameter axis value: " << axis; + THROW_ERROR << "has incorrect input parameter axis value: " << axis; } - dataSize = getOriginalInputPrecisionAtPort(GATHER_DATA).size(); } void MKLDNNGatherNode::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; + dataTypeSize = getOriginalInputPrecisionAtPort(GATHER_DATA).size(); + + const auto& dataDims = getInputShapeAtPort(GATHER_DATA).getDims(); + if (isAxisInputConst && isDataShapeStat) { + axisDim = dataDims[axis]; + beforeAxisSize = std::accumulate(dataDims.begin(), dataDims.begin() + axis, 1lu, std::multiplies()); + betweenBatchAndAxisSize = std::accumulate(dataDims.begin() + batchDims, dataDims.begin() + axis, 1lu, std::multiplies()); + afterAxisSize = std::accumulate(dataDims.begin() + axis + 1, dataDims.end(), 1lu, std::multiplies()); + + afterAxisSizeInBytes = afterAxisSize * dataTypeSize; + axisAndAfterAxisSizeInBytes = axisDim * afterAxisSizeInBytes; + srcAfterBatchSizeInBytes = betweenBatchAndAxisSize * axisAndAfterAxisSizeInBytes; + } + if (isDataShapeStat) { + beforeBatchSize = std::accumulate(dataDims.begin(), dataDims.begin() + batchDims, 1lu, std::multiplies()); + } + if (isIdxShapeStat) { + const auto& idxDims = getInputShapeAtPort(GATHER_INDICES).getDims(); + specIndicesSize = std::accumulate(idxDims.begin() + batchDims, idxDims.end(), 1lu, std::multiplies()); + + if (isDataShapeStat) { + specIdxAndAfterAxSizeB = specIndicesSize * afterAxisSizeInBytes; + totalWork = beforeBatchSize * betweenBatchAndAxisSize * specIndicesSize * afterAxisSize; + } + } + + // Implementation desc type will be redefined in the fn prepareParams if a kernel will be created. 
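+ // ref_any is set here as a safe default; prepareParams() switches the descriptor to jit_avx2/jit_avx512 when the created kernel reports the configuration as supported.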
Precision dataPrecision = getOriginalInputPrecisionAtPort(GATHER_DATA); addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision}, {LayoutType::ncsp, Precision::I32}, {LayoutType::ncsp, Precision::I32, isAxisInputConst}}, {{LayoutType::ncsp, dataPrecision}}, - impl_desc_type::ref_any); + ref_any, + isDynamicNode()); } -void MKLDNNGatherNode::prepareParams() { - auto& srcMemPtr = getParentEdgeAt(GATHER_DATA)->getMemoryPtr(); - if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr()) - IE_THROW() << errorPrefix << " has not allocated input memory."; - if (getSelectedPrimitiveDescriptor() == nullptr) - IE_THROW() << errorPrefix << " has unidentified preferable primitive descriptor."; +void MKLDNNGatherNode::createPrimitive() { + uint64_t idxElPerVec = 1; + if (!isDynamicNode()) { + idxElPerVec = x64::mayiuse(x64::avx512_common) ? x64::cpu_isa_traits::vlen / idxTypeSize : + x64::mayiuse(x64::avx2) ? x64::cpu_isa_traits::vlen / idxTypeSize : 1; + } + // Gather instruction is not supported by SSE. + if ((x64::mayiuse(x64::avx512_common) || x64::mayiuse(x64::avx2)) && + (isDynamicNode() || afterAxisSize == 1 || (afterAxisSize <= idxElPerVec && + (x64::mayiuse(x64::avx512_common) || (x64::mayiuse(x64::avx2) && dataTypeSize == 4))))) { + jGatherConfParams jcp; + jcp.dataTypeSize = dataTypeSize; + jcp.reverseIndexing = reverseIndexing; + jcp.dynamicShapes = isDynamicNode(); + jcp.batchDims = batchDims; + if (!jcp.dynamicShapes) { + jcp.beforeAxisSize = beforeAxisSize; + jcp.specIdxSize = specIndicesSize; + jcp.afterAxisSize = afterAxisSize; + } else { + if (isDataShapeStat && isAxisInputConst) { + jcp.beforeAxisSize = beforeAxisSize; + jcp.afterAxisSize = afterAxisSize; + } + if (isIdxShapeStat) { + jcp.specIdxSize = specIndicesSize; + } + } - const auto& srcDims = srcMemPtr->getStaticDims(); - const auto& idxDims = getParentEdgeAt(GATHER_INDEXES)->getMemory().getStaticDims(); - const auto& dstDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims(); + if (x64::mayiuse(x64::avx512_common)) { + jitKernel.reset(new jitUniGatherKernel(jcp)); + } else if (x64::mayiuse(x64::avx2)) { + jitKernel.reset(new jitUniGatherKernel(jcp)); + } + if (jitKernel) { + jitKernel->create_ker(); - if (!isAxisInputConst) { - axis = (reinterpret_cast(getParentEdgeAt(GATHER_AXIS)->getMemoryPtr()->GetPtr()))[0]; - if (axis < 0) - axis += dataSrcRank; - if (axis < 0 || axis >= dataSrcRank || batchDims > axis) - IE_THROW() << errorPrefix << "has incorrect input parameter axis value: " << axis; + if (!isDynamicNode()) { + const uint64_t dataElPerVec = jitKernel->getDataElPerVec(); + const uint64_t nthr = parallel_get_max_threads(); + const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec; + execParamsPerThread.resize(nthr); + + parallel_nt(nthr, [&](const int ithr, const int nthr) { + const uint64_t dstStart = std::min(wpt * ithr, totalWork); + const uint64_t dstEnd = std::min(wpt * (ithr + 1), totalWork); + + auto& p = execParamsPerThread[ithr]; + p.workAmount = dstEnd - dstStart; + p.dstStart = dstStart; + p.specIdxInBytes.resize(dataElPerVec); + p.idxBatchSumInBytes.resize(dataElPerVec); + p.dataBeforeAxisSumInBytes.resize(dataElPerVec); + p.betweenBatchAndAxisIter = (dstStart / specIndicesSize) % betweenBatchAndAxisSize; + for (uint64_t j = 0lu; j < dataElPerVec; j++) { + p.specIdxInBytes[j] = (((dstStart + j) / afterAxisSize) % specIndicesSize) * idxTypeSize; + p.idxBatchSumInBytes[j] = ((dstStart + j) / (betweenBatchAndAxisSize * specIndicesSize * afterAxisSize)) * + specIndicesSize * idxTypeSize; + 
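+ // Byte offset in the source data produced by the dimensions preceding the axis for this lane.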
p.dataBeforeAxisSumInBytes[j] = ((dstStart + j) / (specIndicesSize * afterAxisSize)) * axisAndAfterAxisSizeInBytes; + } + initShortParams(p, dstStart); + }); + } + } } - indexRange = srcDims[axis]; - batchSize = std::accumulate(srcDims.begin(), srcDims.begin() + batchDims, 1, std::multiplies()); - outerSize = std::accumulate(srcDims.begin() + batchDims, srcDims.begin() + axis, 1, std::multiplies()); - dataLength = std::accumulate(srcDims.begin() + axis + 1, srcDims.end(), 1, std::multiplies()); - srcBatchStride = std::accumulate(srcDims.begin() + batchDims, srcDims.end(), 1, std::multiplies()); - idxBatchStride = std::accumulate(idxDims.begin() + batchDims, idxDims.end(), 1, std::multiplies()); - dstBatchStride = std::accumulate(dstDims.begin() + batchDims, dstDims.end(), 1, std::multiplies()); - len = dataLength * dataSize; - if (dataLength == 0) - IE_THROW() << errorPrefix << "had incorrect input parameters dimension!"; + MKLDNNNode::createPrimitive(); } bool MKLDNNGatherNode::needPrepareParams() const { @@ -116,32 +205,275 @@ bool MKLDNNGatherNode::needPrepareParams() const { return result; } +void MKLDNNGatherNode::prepareParams() { + auto& dataMemPtr = getParentEdgeAt(GATHER_DATA)->getMemoryPtr(); + if (!dataMemPtr || !dataMemPtr->GetPrimitivePtr()) + THROW_ERROR << " has not allocated input data memory."; + auto& idxMemPtr = getParentEdgeAt(GATHER_INDICES)->getMemoryPtr(); + if (!idxMemPtr || !idxMemPtr->GetPrimitivePtr()) + THROW_ERROR << " has not allocated input indices memory."; + if (getSelectedPrimitiveDescriptor() == nullptr) + THROW_ERROR << " has unidentified preferable primitive descriptor."; + + if (!isAxisInputConst) { + axis = (reinterpret_cast(getParentEdgeAt(GATHER_AXIS)->getMemoryPtr()->GetPtr()))[0]; + if (axis < 0) + axis += dataSrcRank; + if (axis < 0 || axis >= dataSrcRank || batchDims > axis) + THROW_ERROR << "has incorrect input parameter axis value: " << axis; + } + + if (!isDataShapeStat || !isAxisInputConst) { + const auto& dataDims = dataMemPtr->getStaticDims(); + axisDim = dataDims[axis]; + beforeBatchSize = std::accumulate(dataDims.begin(), dataDims.begin() + batchDims, 1lu, std::multiplies()); + betweenBatchAndAxisSize = std::accumulate(dataDims.begin() + batchDims, dataDims.begin() + axis, 1lu, std::multiplies()); + afterAxisSize = std::accumulate(dataDims.begin() + axis + 1, dataDims.end(), 1lu, std::multiplies()); + + afterAxisSizeInBytes = afterAxisSize * dataTypeSize; + axisAndAfterAxisSizeInBytes = axisDim * afterAxisSizeInBytes; + srcAfterBatchSizeInBytes = betweenBatchAndAxisSize * axisAndAfterAxisSizeInBytes; + + if (isIdxShapeStat) { + specIdxAndAfterAxSizeB = specIndicesSize * afterAxisSizeInBytes; + totalWork = beforeBatchSize * betweenBatchAndAxisSize * specIndicesSize * afterAxisSize; + } + } + + if (!isIdxShapeStat) { + const auto& idxDims = idxMemPtr->getStaticDims(); + specIndicesSize = std::accumulate(idxDims.begin() + batchDims, idxDims.end(), 1lu, std::multiplies()); + + specIdxAndAfterAxSizeB = specIndicesSize * afterAxisSizeInBytes; + totalWork = beforeBatchSize * betweenBatchAndAxisSize * specIndicesSize * afterAxisSize; + } + + const auto& selectedPD = getSelectedPrimitiveDescriptor(); + if (jitKernel && jitKernel->isSupportedConfiguration(afterAxisSize)) { + if (x64::mayiuse(x64::avx512_common)) { + selectedPD->setImplementationType(jit_avx512); + } else if (x64::mayiuse(x64::avx2)) { + selectedPD->setImplementationType(jit_avx2); + } + } else { + selectedPD->setImplementationType(ref_any); + } +} + void 
MKLDNNGatherNode::execute(mkldnn::stream strm) { - const int32_t* srcIndexes = reinterpret_cast(getParentEdgeAt(GATHER_INDEXES)->getMemoryPtr()->GetPtr()); + if (jitKernel && jitKernel->isSupportedConfiguration(afterAxisSize)) { + const void* srcIndices = getParentEdgeAt(GATHER_INDICES)->getMemoryPtr()->GetPtr(); + const void* srcData = getParentEdgeAt(GATHER_DATA)->getMemoryPtr()->GetPtr(); + uint8_t* dstData = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); + + const uint64_t dataElPerVec = jitKernel->getDataElPerVec(); + + auto threadBody = [&](const int ithr, const int nthr) { + auto& p = execParamsPerThread[ithr]; + auto arg = gatherJitExecArgs(); + + arg.src = srcData; + arg.dst = dstData + p.dstStart * dataTypeSize; + arg.indices = srcIndices; + arg.start = &p.dstStart; + arg.axisDim = &axisDim; + arg.afterAxSize = afterAxisSize; + arg.axisAndAfterAxisSizeB = &axisAndAfterAxisSizeInBytes; + arg.srcAfterBatchSizeB = &srcAfterBatchSizeInBytes; + arg.betweenBatchAndAxisSize = &betweenBatchAndAxisSize; + arg.specIndicesSize = &specIndicesSize; + arg.workAmount = p.workAmount; + arg.specIdxB = p.specIdxInBytes.data(); + arg.idxBatchSumB = p.idxBatchSumInBytes.data(); + arg.dataBeforeAxisSumB = p.dataBeforeAxisSumInBytes.data(); + arg.betweenBatchAndAxisIter = p.betweenBatchAndAxisIter; + + const uint64_t idxElPerVec = jitKernel->getIdxElPerVec(); + + if (afterAxisSize == 1 && specIndicesSize < idxElPerVec) { // Elementwise short case. + arg.permIdxMask = p.permIdxMask.data(); + arg.beforeAxisDiff = p.srcBeforeAxisDiff.data(); + } else if (afterAxisSize > 1 && afterAxisSize <= dataElPerVec) { // Blocked short case. + arg.afterAxIdxB = p.afterAxIdxInBytes.data(); + arg.specIdxDiff = p.specIdxDiff.data(); + arg.beforeAxisDiff = p.srcBeforeAxisDiff.data(); + arg.beforeAxisPermMask = p.beforeAxPermMask.data(); + arg.afterAxisPermMask = p.afterAxPermMask.data(); + arg.afterAxisSize = &afterAxisSize; + arg.specIdxAndAfterAxIterB = p.specIdxAndAfterAxIterB; + arg.specIdxAndAfterAxSizeB = specIdxAndAfterAxSizeB; + } + + (*jitKernel)(&arg); + }; + + parallel_nt(0, threadBody); + } else { + execReference(); + } +} + +void MKLDNNGatherNode::executeDynamicImpl(mkldnn::stream strm) { + if (jitKernel && jitKernel->isSupportedConfiguration(afterAxisSize)) { + const void* srcIndices = getParentEdgeAt(GATHER_INDICES)->getMemoryPtr()->GetPtr(); + const void* srcData = getParentEdgeAt(GATHER_DATA)->getMemoryPtr()->GetPtr(); + uint8_t* dstData = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); + + const uint64_t dataElPerVec = jitKernel->getDataElPerVec(); + + auto threadBody = [&](const int ithr, const int nthr) { + const uint64_t wpt = ((totalWork / dataElPerVec) / nthr + 1) * dataElPerVec; + const uint64_t start = std::min(wpt * ithr, totalWork); + const uint64_t end = std::min(wpt * (ithr + 1), totalWork); + const uint64_t workAmount = end - start; + + auto arg = gatherJitExecArgs(); + + arg.src = srcData; + arg.dst = dstData + afterAxisSizeInBytes * start; + arg.indices = srcIndices; + arg.start = &start; + arg.axisDim = &axisDim; + arg.afterAxSize = afterAxisSize; + arg.axisAndAfterAxisSizeB = &axisAndAfterAxisSizeInBytes; + arg.srcAfterBatchSizeB = &srcAfterBatchSizeInBytes; + arg.betweenBatchAndAxisSize = &betweenBatchAndAxisSize; + arg.specIndicesSize = &specIndicesSize; + arg.workAmount = workAmount; + + const uint64_t idxElPerVec = jitKernel->getIdxElPerVec(); + int permIdxMask[16]; + int beforeAxisDiff[16]; + if (afterAxisSize == 1 && specIndicesSize < idxElPerVec) { 
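+ // Elementwise short case: all specified indices fit into a single vector register, so the permutation mask and per-lane beforeAxis increments are built on the fly here (the static-shape path precomputes them in initShortParams). E.g. with idxElPerVec = 8 and specIndicesSize = 3 the mask is {5, 6, 7, 5, 6, 7, 5, 6}.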
+ permIdxMask[0] = idxElPerVec - specIndicesSize; + int div = idxElPerVec / specIndicesSize; + int remainder = idxElPerVec % specIndicesSize; + for (int i = 1; i < idxElPerVec; i++) { + permIdxMask[i] = permIdxMask[i - 1] + 1; + if (permIdxMask[i] == idxElPerVec) + permIdxMask[i] = idxElPerVec - specIndicesSize; + } + for (int i = 0; i < idxElPerVec; i++) { + if (((start + i) % specIndicesSize) < (specIndicesSize - remainder)) + beforeAxisDiff[i] = axisDim * div; + else + beforeAxisDiff[i] = axisDim * (div + 1); + } + arg.permIdxMask = permIdxMask; + arg.beforeAxisDiff = beforeAxisDiff; + } + + (*jitKernel)(&arg); + }; + + parallel_nt(0, threadBody); + } else { + execReference(); + } +} + +void MKLDNNGatherNode::initShortParams(threadExecParams& p, const uint64_t start) { + if (!jitKernel) + THROW_ERROR << "has uninitialized kernel in function initShortParams."; + const uint64_t idxElPerVec = jitKernel->getIdxElPerVec(); + + if (afterAxisSize == 1) { // Elementwise gather. + if (specIndicesSize >= idxElPerVec) + return; // Is not a short case. + + p.permIdxMask.resize(idxElPerVec); + p.srcBeforeAxisDiff.resize(idxElPerVec); + + p.permIdxMask[0] = idxElPerVec - specIndicesSize; + for (int i = 1; i < idxElPerVec; i++) { + p.permIdxMask[i] = p.permIdxMask[i - 1] + 1; + if (p.permIdxMask[i] == idxElPerVec) + p.permIdxMask[i] = idxElPerVec - specIndicesSize; + } + + const int div = idxElPerVec / specIndicesSize; + const int remainder = idxElPerVec % specIndicesSize; + for (uint64_t i = 0; i < idxElPerVec; i++) { + if (((start + i) % specIndicesSize) < (specIndicesSize - remainder)) { + p.srcBeforeAxisDiff[i] = axisDim * div; + } else { + p.srcBeforeAxisDiff[i] = axisDim * (div + 1); + } + } + } else { // Blocked gather. + if (afterAxisSize > idxElPerVec) + return; // Is not a short case. 
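+ // Blocked short case: afterAxisSize does not exceed the vector length, so per-lane offsets inside the after-axis block and the corresponding permutation masks are precomputed below.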
+ + p.afterAxIdxInBytes.resize(idxElPerVec); + p.afterAxPermMask.resize(idxElPerVec); + p.beforeAxPermMask.resize(idxElPerVec); + p.specIdxDiff.resize(idxElPerVec); + p.srcBeforeAxisDiff.resize(idxElPerVec); + + int secondStart = start + idxElPerVec; + for (int i = 0; i < idxElPerVec; i++) { + p.afterAxIdxInBytes[i] = (start + i) % afterAxisSize; + p.specIdxDiff[i] = (((secondStart + i) / afterAxisSize) % specIndicesSize) * idxTypeSize - p.specIdxInBytes[i]; + if (p.specIdxDiff[i] < 0) + p.specIdxDiff[i] += specIndicesSize * idxTypeSize; + p.srcBeforeAxisDiff[i] = ((start + i + idxElPerVec) / (specIndicesSize * afterAxisSize)) * axisAndAfterAxisSizeInBytes - + ((start + i) / (specIndicesSize * afterAxisSize)) * axisAndAfterAxisSizeInBytes; + + p.afterAxIdxInBytes[i] *= dataTypeSize; + p.afterAxPermMask[i] = idxElPerVec - afterAxisSize + i; + for (size_t j = 0lu; j < 6lu; j++) { + if (p.afterAxPermMask[i] >= idxElPerVec) + p.afterAxPermMask[i] -= afterAxisSize; + } + } + if (specIndicesSize * afterAxisSize < idxElPerVec) { + p.beforeAxPermMask[0] = idxElPerVec - specIndicesSize * afterAxisSize; + for (int i = 1; i < idxElPerVec; i++) { + p.beforeAxPermMask[i] = p.beforeAxPermMask[i - 1] + 1; + if (p.beforeAxPermMask[i] == idxElPerVec) + p.beforeAxPermMask[i] = idxElPerVec - specIndicesSize * afterAxisSize; + } + } + + p.specIdxAndAfterAxIterB = (start * dataTypeSize) % specIdxAndAfterAxSizeB; + } +} + +void MKLDNNGatherNode::execReference() { + const int32_t* srcIndices = reinterpret_cast(getParentEdgeAt(GATHER_INDICES)->getMemoryPtr()->GetPtr()); const uint8_t* srcData = reinterpret_cast(getParentEdgeAt(GATHER_DATA)->getMemoryPtr()->GetPtr()); uint8_t* dstData = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->GetPtr()); - parallel_for2d(batchSize, idxBatchStride, [&](const size_t i, const size_t j) { - const unsigned int idx = static_cast(srcIndexes[i * idxBatchStride + j]); + const size_t dstIdxAndAfterAxisSize = afterAxisSizeInBytes * specIndicesSize; + const size_t dstAfterBatchSize = betweenBatchAndAxisSize * dstIdxAndAfterAxisSize; + parallel_for2d(beforeBatchSize, specIndicesSize, [&](const size_t b, const size_t j) { + int ii = srcIndices[b * specIndicesSize + j]; + if (ii < 0) { + if (reverseIndexing) + ii += axisDim; + else + ii = axisDim; + } + const size_t idx = ii; + const size_t c2 = dstAfterBatchSize * b + afterAxisSizeInBytes * j; + if (idx < axisDim) { + size_t c1 = srcAfterBatchSizeInBytes * b + afterAxisSizeInBytes * idx; + for (size_t i = 0; i < betweenBatchAndAxisSize; i++) { + size_t srcIdx = c1 + axisAndAfterAxisSizeInBytes * i; + size_t dstIdx = c2 + dstIdxAndAfterAxisSize * i; - // while negative indices are not supported, should set zero - if (idx < indexRange) { - for (size_t k = 0; k < outerSize; ++k) { - const size_t srcStride = (i * srcBatchStride + k * dataLength * indexRange) * dataSize; - const size_t dstStride = (i * dstBatchStride + k * dataLength * idxBatchStride) * dataSize; - - cpu_memcpy(&dstData[dstStride + j * len], &srcData[srcStride + idx * len], len); + cpu_memcpy(&dstData[dstIdx], &srcData[srcIdx], afterAxisSizeInBytes); } } else { - for (size_t k = 0; k < outerSize; ++k) { - memset(&dstData[(i * dstBatchStride + k * dataLength * idxBatchStride) * dataSize + j * len], 0, len); + for (size_t i = 0; i < betweenBatchAndAxisSize; i++) { + memset(&dstData[c2 + dstIdxAndAfterAxisSize * i], 0, afterAxisSizeInBytes); } } }); } -void MKLDNNGatherNode::executeDynamicImpl(mkldnn::stream strm) { - execute(strm); +std::vector 
MKLDNNGatherNode::shapeInfer() const { + return MKLDNNNode::shapeInferGeneric(PortMask(1, 2, 3)); } bool MKLDNNGatherNode::created() const { diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.h b/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.h index 831e679a28e..09b5b686c79 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.h +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_gather_node.h @@ -5,6 +5,7 @@ #pragma once #include +#include "kernels/gather_uni_kernel.hpp" #include #include @@ -18,37 +19,71 @@ public: void getSupportedDescriptors() override {}; void initSupportedPrimitiveDescriptors() override; + void createPrimitive() override; void execute(mkldnn::stream strm) override; bool created() const override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + struct threadExecParams { + std::vector specIdxInBytes; + std::vector permIdxMask; + std::vector srcBeforeAxisDiff; + std::vector idxBatchSumInBytes; + std::vector dataBeforeAxisSumInBytes; + + std::vector afterAxIdxInBytes; + std::vector specIdxDiff; + std::vector beforeAxPermMask; + std::vector afterAxPermMask; + int betweenBatchAndAxisIter = 0; + int specIdxAndAfterAxIterB = 0; + + uint64_t workAmount = 0; + uint64_t dstStart = 0; + }; + protected: void executeDynamicImpl(mkldnn::stream strm) override; bool needPrepareParams() const override; void prepareParams() override; + std::vector shapeInfer() const override; private: - int axis = 0; - int batchDims = 0; + void initShortParams(threadExecParams& p, uint64_t start); + void execReference(); - size_t indexRange = 0; - size_t batchSize = 1; - size_t outerSize = 1; - size_t dataLength = 1; - size_t srcBatchStride = 1; - size_t idxBatchStride = 1; - size_t dstBatchStride = 1; - size_t dataSize = 1; - size_t len = 1; - int dataSrcRank = 1; + bool isDataShapeStat = false; + bool isIdxShapeStat = false; bool isAxisInputConst = false; + bool reverseIndexing = false; + + uint64_t dataTypeSize = 1lu; + static constexpr uint64_t idxTypeSize = sizeof(int); + + int axis = 0; + int axisDim; + int batchDims = 0; + int dataSrcRank = 1; + uint64_t specIndicesSize; + uint64_t beforeBatchSize; + uint64_t beforeAxisSize; + uint64_t betweenBatchAndAxisSize; + uint64_t afterAxisSize = 0lu; + uint64_t afterAxisSizeInBytes = 0lu; + uint64_t axisAndAfterAxisSizeInBytes = 0lu; + uint64_t srcAfterBatchSizeInBytes = 0lu; + uint64_t specIdxAndAfterAxSizeB = 0lu; + uint64_t totalWork; + + std::vector execParamsPerThread; + static constexpr size_t GATHER_DATA = 0; - static constexpr size_t GATHER_INDEXES = 1; + static constexpr size_t GATHER_INDICES = 1; static constexpr size_t GATHER_AXIS = 2; - std::string errorPrefix; + std::shared_ptr jitKernel; }; } // namespace MKLDNNPlugin diff --git a/src/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather.cpp b/src/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather.cpp index 236a05bce80..e6b971516e3 100644 --- a/src/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather.cpp +++ b/src/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/gather.cpp @@ -100,6 +100,8 @@ const std::vector> inputShapes4D = { const std::vector> indicesShapes_BD0 = { std::vector{4}, std::vector{2, 2}, + std::vector{3, 3}, + std::vector{5, 2}, std::vector{3, 2, 4}, }; @@ -122,7 +124,8 @@ const auto gather7ParamsSubset_BD0 = testing::Combine( testing::Values(CommonTestUtils::DEVICE_CPU) ); 
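+// Gather-8 is instantiated on the same BD0 parameter subset as Gather-7.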
-INSTANTIATE_TEST_SUITE_P(smoke_Gather7_BD0, Gather7LayerTest, gather7ParamsSubset_BD0, Gather7LayerTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_BD0, Gather7LayerTest, gather7ParamsSubset_BD0, Gather7LayerTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_BD0, Gather8LayerTest, gather7ParamsSubset_BD0, Gather8LayerTest::getTestCaseName); const std::vector> indicesShapes_BD1 = { std::vector{4, 2}, @@ -205,4 +208,78 @@ const auto gather7ParamsSubset_NegativeBD = testing::Combine( INSTANTIATE_TEST_SUITE_P(smoke_Gather7_NegativeBD, Gather7LayerTest, gather7ParamsSubset_NegativeBD, Gather7LayerTest::getTestCaseName); + +///// GATHER-8 ///// + +const std::vector> dataShapes4DGather8 = { + {10, 3, 1, 2}, + {10, 3, 3, 1}, + {10, 2, 2, 7}, + {10, 2, 2, 2}, + {10, 3, 4, 4}, + {10, 2, 3, 17} +}; +const std::vector> idxShapes4DGather8 = { + {10, 1, 1}, + {10, 1, 2}, + {10, 1, 3}, + {10, 2, 2}, + {10, 1, 7}, + {10, 2, 4}, + {10, 3, 3}, + {10, 3, 5}, + {10, 7, 3}, + {10, 8, 7} +}; +const std::vector> axesBatches4DGather8 = { + {3, 0}, + {-1, -2}, + {2, -3}, + {2, 1}, + {1, 0}, + {1, 1}, + {0, 0} +}; + +INSTANTIATE_TEST_CASE_P(smoke_static_4D, Gather8LayerTest, + testing::Combine( + testing::ValuesIn(dataShapes4DGather8), + testing::ValuesIn(idxShapes4DGather8), + testing::ValuesIn(axesBatches4DGather8), + testing::ValuesIn(netPrecisions), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(CommonTestUtils::DEVICE_CPU)), + Gather8LayerTest::getTestCaseName); + +const auto gatherParamsVec2 = testing::Combine( + testing::ValuesIn(std::vector>({{5, 4}, {11, 4}, {23, 4}, {35, 4}, {51, 4}, {71, 4}})), + testing::ValuesIn(std::vector>({{1}})), + testing::ValuesIn(std::vector>{std::tuple{1, 0}}), + testing::ValuesIn(netPrecisions), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_Vec2, Gather8LayerTest, gatherParamsVec2, Gather8LayerTest::getTestCaseName); + +const auto gatherParamsVec3 = testing::Combine( + testing::ValuesIn(std::vector>({{4, 4}})), + testing::ValuesIn(std::vector>({{5}, {11}, {21}, {35}, {55}, {70}})), + testing::ValuesIn(std::vector>{std::tuple{1, 0}}), + testing::ValuesIn(netPrecisions), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Precision::UNSPECIFIED), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(InferenceEngine::Layout::ANY), + testing::Values(CommonTestUtils::DEVICE_CPU) +); + +INSTANTIATE_TEST_CASE_P(smoke_Vec3, Gather8LayerTest, gatherParamsVec3, Gather8LayerTest::getTestCaseName); + } // namespace diff --git a/src/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp b/src/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp index e4336236b83..d7451c9b656 100644 --- a/src/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp +++ b/src/tests/functional/plugin/cpu/shared_tests_instances/skip_tests_config.cpp @@ -69,8 +69,6 @@ std::vector disabledTestPatterns() { // TODO: 57562 No dynamic output shape support R"(.*NonZeroLayerTest.*)", - // TODO: 69084 Not constant Axis input produces dynamic output shape. 
- R"(.*GatherLayerTestCPU.*constAx=False.*)", // TODO: 74961. Enforce precision via inType and outType does not work properly. R"(.*(RNN|GRU|LSTM).*ENFORCE_BF16=YES.*)", // Not expected behavior diff --git a/src/tests/functional/plugin/cpu/single_layer_tests/gather.cpp b/src/tests/functional/plugin/cpu/single_layer_tests/gather.cpp index 8a6b6f1f064..280e76e3305 100644 --- a/src/tests/functional/plugin/cpu/single_layer_tests/gather.cpp +++ b/src/tests/functional/plugin/cpu/single_layer_tests/gather.cpp @@ -1,345 +1,548 @@ -//// Copyright (C) 2018-2022 Intel Corporation -//// SPDX-License-Identifier: Apache-2.0 -//// +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 // -//#include -//#include "ngraph_functions/builders.hpp" -//#include "test_utils/cpu_test_utils.hpp" -// -//using namespace InferenceEngine; -//using namespace CPUTestUtils; -// -//namespace CPULayerTestsDefinitions { -// -//using inputShapesPair = std::pair, std::vector>>; -// -//typedef std::tuple< -// inputShapesPair, // Input shapes -// int64_t, // Axis -// int64_t, // Batch dims -// InferenceEngine::Precision, // Network precision -// bool, // Is axis input constant -// std::string, // Device name -// CPUSpecificParams // CPU specific params -//> GatherLayerTestCPUParams; -// -//class GatherLayerTestCPU : public testing::WithParamInterface, -// virtual public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase { -//public: -// static std::string getTestCaseName(testing::TestParamInfo obj) { -// inputShapesPair inputShapes; -// int axis, batchDims; -// Precision netPrecision; -// std::string targetDevice; -// bool isAxisConstant; -// CPUSpecificParams cpuParams; -// std::tie(inputShapes, axis, batchDims, netPrecision, isAxisConstant, targetDevice, cpuParams) = obj.param; -// -// std::ostringstream result; -// result << "DynShapes=" << CommonTestUtils::partialShape2str(inputShapes.first) << "_"; -// result << "StatShapes=" << CommonTestUtils::vec2str(inputShapes.second) << "_"; -// result << "axis=" << axis << "_"; -// result << "batchDims=" << batchDims << "_"; -// result << "netPrc=" << netPrecision.name() << "_"; -// result << "constAx=" << (isAxisConstant ? "True" : "False") << "_"; -// result << "trgDev=" << targetDevice; -// result << CPUTestsBase::getTestCaseName(cpuParams); -// -// return result.str(); -// } -// -//protected: -// void SetUp() override { -// inputShapesPair inputShapes; -// int64_t batchDims; -// Precision netPrecision; -// CPUSpecificParams cpuParams; -// bool isAxisConstant = true; -// std::tie(inputShapes, axis, batchDims, netPrecision, isAxisConstant, targetDevice, cpuParams) = this->GetParam(); -// -// std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; -// -// selectedType = std::string("ref_any_") + netPrecision.name(); -// -// targetStaticShapes.reserve(inputShapes.second.size()); -// inputDynamicShapes.reserve(inputShapes.first.size()); -// for (int i = 0; i < (isAxisConstant ? 
2 : 3); i++) { -// if (inputShapes.second.size() > i) -// targetStaticShapes.push_back({inputShapes.second[i]}); -// if (inputShapes.first.size() > i) -// inputDynamicShapes.push_back(inputShapes.first[i]); -// } -// const ov::Shape& inputDataShape = targetStaticShapes.front().front(), indicesShape = targetStaticShapes.front()[1]; -// dataSrcRank = inputDataShape.size(); -// -// const auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); -// ov::ParameterVector functionParams { -// ngraph::builder::makeParams(ngPrc, { {"data", inputDataShape} })[0], -// ngraph::builder::makeParams(ov::element::i32, { {"indices", indicesShape} })[0] -// }; -// if (!isAxisConstant) { -// functionParams.push_back(ngraph::builder::makeParams(ov::element::i32, { {"axis", {1}} })[0]); -// } -// auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(functionParams)); -// std::shared_ptr gatherNode; -// if (isAxisConstant) { -// gatherNode = std::make_shared(paramOuts[0], paramOuts[1], -// ov::op::v0::Constant::create(ov::element::i64, ov::Shape({}), { axis }), batchDims); -// } else { -// gatherNode = std::make_shared(paramOuts[0], paramOuts[1], paramOuts[2], batchDims); -// } -// -// ov::ResultVector results{ std::make_shared(gatherNode) }; -// function = std::make_shared(results, functionParams, "Gather"); -// } -// -// InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &inputInfo) const override { -// if (inputInfo.name() == "indices") { -// const auto& td = inputInfo.getTensorDesc(); -// size_t normAxis = axis < 0 ? axis + dataSrcRank : axis; -// const auto axDim = targetStaticShapes[index][0][normAxis]; -// if (axDim == 1) { -// // Random generator cannot generate values in range [0; 0] -// int values[1] = { 0 }; -// return FuncTestUtils::createAndFillBlobWithFloatArray(td, values, 1); -// } else { -// return FuncTestUtils::createAndFillBlob(td, axDim - 1, 0); -// } -// } else if (inputInfo.name() == "axis") { -// int values[1] = { static_cast(axis) }; -// return FuncTestUtils::createAndFillBlobWithFloatArray(inputInfo.getTensorDesc(), values, 1); -// } else { -// return LayerTestsCommon::GenerateInput(inputInfo); -// } -// } -// -// int64_t axis = 0; -// int64_t dataSrcRank = 0; -//}; -// -//TEST_P(GatherLayerTestCPU, CompareWithRefs) { -// SKIP_IF_CURRENT_TEST_IS_DISABLED() -// -// Run(); -// CheckPluginRelatedResults(executableNetwork, "Gather"); -//} -// -//namespace { -//const std::vector netPrecisions = { -// InferenceEngine::Precision::FP32, -// InferenceEngine::Precision::BF16, -// InferenceEngine::Precision::I8 -//}; -// -//// 1D -//const std::vector staticInputShapes1D = { -// { -// {}, -// { // Static shapes -// {{4}, {2, 3, 4}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{4}, {1}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{4}, {9}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{5}, {5}} -// } -// } -//}; -//const std::vector dynamicInputShapes1D = { -// { -// { // Origin dynamic shapes -// {ov::Dimension(4, 6)}, {ov::Dimension(1, 10)}, {ov::Dimension(1, 2)} -// }, -// { // Dynamic shapes instances -// {{4}, {1}, {1}}, -// {{4}, {9}, {1}}, -// {{5}, {5}, {1}} -// } -// } -//}; -// -//INSTANTIATE_TEST_SUITE_P(smoke_StaticShape1D, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(staticInputShapes1D), -// ::testing::Values(0), -// ::testing::Values(0), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), 
-// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape1D, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(dynamicInputShapes1D), -// ::testing::Values(0), -// ::testing::Values(0), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true, false), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//// 2D -//const std::vector staticInputShapes2D = { -// { -// {}, -// { // Static shapes -// {{4, 7}, {4, 55}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{4, 17}, {4, 17}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{4, 55}, {4, 7}} -// } -// } -//}; -//const std::vector dynamicInputShapes2D = { -// { -// { // Origin dynamic shapes -// {4, ov::Dimension(3, 99)}, -// {4, ov::Dimension(3, 99)}, -// {1} -// }, -// { // Dynamic shapes instances -// {{4, 7}, {4, 55}, {1}}, -// {{4, 55}, {4, 7}, {1}}, -// {{4, 17}, {4, 17}, {1}} -// } -// } -//}; -//const std::vector dynamicInputShapes2Dv2 = { -// { -// { // Origin dynamic shapes -// {ov::Dimension(3, 99), ov::Dimension(3, 99)}, -// {-1, ov::Dimension(3, 99)}, -// {1} -// }, -// { // Dynamic shapes instances -// {{4, 7}, {4, 55}, {1}}, -// {{8, 55}, {5, 7}, {1}} -// } -// } -//}; -// -//INSTANTIATE_TEST_SUITE_P(smoke_StaticShape2D, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(staticInputShapes2D), -// ::testing::Values(1), -// ::testing::ValuesIn(std::vector{0, 1}), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape2D, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(dynamicInputShapes2D), -// ::testing::Values(1), -// ::testing::ValuesIn(std::vector{0, 1}), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true, false), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape2Dv2, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(dynamicInputShapes2Dv2), -// ::testing::Values(0), -// ::testing::Values(0), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true, false), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//// 4D -//const std::vector staticInputShapes4D = { -// { -// {}, -// { // Static shapes -// {{4, 5, 6, 7}, {2, 5, 1}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{10, 5, 6, 7}, {2, 5, 2}} -// } -// }, -// { -// {}, -// { // Static shapes -// {{16, 5, 6, 7}, {3, 5, 3}} -// } -// } -//}; -//const std::vector dynamicInputShapes4D = { -// { -// { // Origin dynamic shapes -// {ov::Dimension(4, 20), 5, 6, 7}, -// {ov::Dimension(2, 4), 5, ov::Dimension(1, 4)}, -// {1} -// }, -// { // Dynamic shapes instances -// {{4, 5, 6, 7}, {2, 5, 1}, {1}}, -// {{10, 5, 6, 7}, {2, 5, 2}, {1}}, -// {{16, 5, 6, 7}, {3, 5, 3}, {1}} -// } -// }, -// { -// { // Origin dynamic shapes -// {-1, -1, -1, -1}, {-1, -1, -1}, {1} -// }, -// { // Dynamic shapes instances -// {{4, 5, 6, 4}, {2, 5, 16}, {1}}, -// {{10, 5, 6, 8}, {2, 5, 24}, {1}} -// } -// } -//}; -// -//INSTANTIATE_TEST_SUITE_P(smoke_StaticShape4D, 
GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(staticInputShapes4D), -// ::testing::ValuesIn(std::vector{0, 1, 2, -1}), -// ::testing::Values(0), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -// -//INSTANTIATE_TEST_SUITE_P(smoke_DynamicShape4D, GatherLayerTestCPU, -// ::testing::Combine( -// ::testing::ValuesIn(dynamicInputShapes4D), -// ::testing::ValuesIn(std::vector{0, 1, 2, -1}), -// ::testing::Values(0), -// ::testing::ValuesIn(netPrecisions), -// ::testing::Values(true, false), -// ::testing::Values(CommonTestUtils::DEVICE_CPU), -// ::testing::Values(CPUSpecificParams{})), -// GatherLayerTestCPU::getTestCaseName); -//} // namespace -//} // namespace CPULayerTestsDefinitions + +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "ngraph_functions/builders.hpp" +#include "test_utils/cpu_test_utils.hpp" +#include "functional_test_utils/ov_tensor_utils.hpp" + +using namespace CPUTestUtils; +using namespace ov::test; + +namespace CPULayerTestsDefinitions { + +typedef std::tuple< + std::vector, // Input shapes + std::tuple, // Axis and Batch dim + ElementType, // Network precision + bool, // Is const Axis + CPUSpecificParams, // CPU specific params + std::map // Additional config +> GatherLayerTestCPUParams; + +class GatherLayerTestCPU : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest, public CPUTestsBase { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + std::vector inputShapes; + std::tuple axisAndBatchDims; + ElementType netPrecision; + bool isAxisConstant; + CPUSpecificParams cpuParams; + std::map additionalConfig; + + std::tie(inputShapes, axisAndBatchDims, netPrecision, isAxisConstant, cpuParams, additionalConfig) = obj.param; + + std::ostringstream result; + result << "IS=("; + for (size_t i = 0lu; i < inputShapes.size(); i++) { + result << CommonTestUtils::partialShape2str({inputShapes[i].first}) << (i < inputShapes.size() - 1lu ? "_" : ""); + } + result << ")_TS="; + for (size_t i = 0lu; i < inputShapes.front().second.size(); i++) { + result << "{"; + for (size_t j = 0lu; j < inputShapes.size(); j++) { + result << CommonTestUtils::vec2str(inputShapes[j].second[i]) << (j < inputShapes.size() - 1lu ? "_" : ""); + } + result << "}_"; + } + result << "axis=" << std::get<0>(axisAndBatchDims) << "_"; + result << "batchDims=" << std::get<1>(axisAndBatchDims) << "_"; + result << "netPrc=" << netPrecision << "_"; + result << "constAx=" << (isAxisConstant ? 
"True" : "False") << "_"; + result << CPUTestsBase::getTestCaseName(cpuParams); + + if (!additionalConfig.empty()) { + result << "_PluginConf"; + for (auto &item : additionalConfig) { + if (item.second == InferenceEngine::PluginConfigParams::YES) + result << "_" << item.first << "=" << item.second; + } + } + + return result.str(); + } + +protected: + void SetUp() override { + std::vector inputShapes; + std::tuple axisAndBatchDims; + ElementType netPrecision; + bool isAxisConstant; + CPUSpecificParams cpuParams; + std::map additionalConfig; + const ElementType intInputsPrecision = ElementType::i64; + + std::tie(inputShapes, axisAndBatchDims, netPrecision, isAxisConstant, cpuParams, additionalConfig) = this->GetParam(); + std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; + axis = std::get<0>(axisAndBatchDims); + const int batchDims = std::get<1>(axisAndBatchDims); + targetDevice = CommonTestUtils::DEVICE_CPU; + init_input_shapes(inputShapes); + configuration.insert(additionalConfig.begin(), additionalConfig.end()); + + if (additionalConfig[InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16] == InferenceEngine::PluginConfigParams::YES) { + selectedType = makeSelectedTypeStr(selectedType, ElementType::bf16); + } else { + selectedType = makeSelectedTypeStr(selectedType, netPrecision); + } + + if (!isAxisConstant) { + inputDynamicShapes.push_back({1}); + for (size_t i = 0lu; i < targetStaticShapes.size(); i++) { + targetStaticShapes[i].push_back({1}); + } + } + + ngraph::ParameterVector params { + std::make_shared(netPrecision, inputDynamicShapes[0]), + std::make_shared(intInputsPrecision, inputDynamicShapes[1]) + }; + params[0]->set_friendly_name("data"); + params[1]->set_friendly_name("indices"); + if (!isAxisConstant) { + params.push_back(std::make_shared(intInputsPrecision, inputDynamicShapes[2])); + params[2]->set_friendly_name("axis"); + } + auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)); + std::shared_ptr gatherNode; + if (isAxisConstant) { + gatherNode = std::make_shared(paramOuts[0], paramOuts[1], + ov::op::v0::Constant::create(intInputsPrecision, ov::Shape({1}), { axis }), batchDims); + } else { + gatherNode = std::make_shared(paramOuts[0], paramOuts[1], paramOuts[2], batchDims); + } + + function = makeNgraphFunction(netPrecision, params, gatherNode, "GatherCPU"); + } + + void generate_inputs(const std::vector& targetInputStaticShapes) override { + const auto& funcInputs = function->inputs(); + inputs.clear(); + + const size_t normAxis = axis < 0 ? axis + targetInputStaticShapes[0].size() : axis; + const int32_t axisDim = targetInputStaticShapes[0][normAxis]; + + for (int i = 0; i < funcInputs.size(); ++i) { + const auto& funcInput = funcInputs[i]; + ov::runtime::Tensor tensor; + + if (funcInput.get_node()->get_friendly_name() == "data") { + const auto dataTypeSize = funcInput.get_element_type().size(); + const uint32_t range = dataTypeSize == 4 ? 0x7FFFFFFF : dataTypeSize == 2 ? 
0xFFFF : 0xFF; + tensor = ov::test::utils::create_and_fill_tensor( + funcInput.get_element_type(), targetInputStaticShapes[0], range, 0, 1); + } else if (funcInput.get_node()->get_friendly_name() == "indices") { + tensor = ov::test::utils::create_and_fill_tensor( + funcInput.get_element_type(), targetInputStaticShapes[1], axisDim * 2, -axisDim, 1); + } else if (funcInput.get_node()->get_friendly_name() == "axis") { + tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), {1}, 1, axis, 1); + } + inputs.insert({funcInput.get_node_shared_ptr(), tensor}); + } + } + + int64_t axis = 0; +}; + +TEST_P(GatherLayerTestCPU, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + run(); + CheckPluginRelatedResults(executableNetwork, "Gather"); +} + +namespace { +const std::vector netPrecisions = { + ElementType::f32, + ElementType::bf16, + ElementType::i8 +}; + +std::vector> additionalConfig + = {{{InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, InferenceEngine::PluginConfigParams::NO}}, + {{InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, InferenceEngine::PluginConfigParams::YES}}}; + +std::vector isAxisConst{true, false}; +const CPUSpecificParams cpuParamsRef{{}, {}, {"ref_any"}, "ref_any"}; + +std::vector getCPUInfo() { + std::vector resCPUParams; + if (InferenceEngine::with_cpu_x86_avx512f()) { + resCPUParams.push_back(CPUSpecificParams{{}, {}, {"jit_avx512"}, "jit_avx512"}); + } else if (InferenceEngine::with_cpu_x86_avx2()) { + resCPUParams.push_back(CPUSpecificParams{{}, {}, {"jit_avx2"}, "jit_avx2"}); + } else { + resCPUParams.push_back(CPUSpecificParams{{}, {}, {"ref"}, "ref"}); + } + return resCPUParams; +} + +///// 1D ///// +const std::vector> staticInputShapes1D = { + { { {}, { {1} } }, { {}, { {1} } } }, + { { {}, { {2} } }, { {}, { {2} } } }, + { { {}, { {3} } }, { {}, { {3} } } }, + { { {}, { {4} } }, { {}, { {4} } } }, + { { {}, { {5} } }, { {}, { {5} } } }, + { { {}, { {6} } }, { {}, { {6} } } }, + { { {}, { {7} } }, { {}, { {7} } } }, + { { {}, { {8} } }, { {}, { {8} } } }, + { { {}, { {9} } }, { {}, { {9} } } }, + { { {}, { {11} } }, { {}, { {11} } } }, + { { {}, { {13} } }, { {}, { {13} } } }, + { { {}, { {15} } }, { {}, { {15} } } }, + { { {}, { {16} } }, { {}, { {16} } } }, + { { {}, { {17} } }, { {}, { {17} } } }, + { { {}, { {19} } }, { {}, { {19} } } }, + { { {}, { {23} } }, { {}, { {23} } } }, + { { {}, { {24} } }, { {}, { {24} } } }, + { { {}, { {32} } }, { {}, { {32} } } }, + { { {}, { {33} } }, { {}, { {33} } } }, + { { {}, { {37} } }, { {}, { {37} } } }, + { { {}, { {41} } }, { {}, { {41} } } }, + { { {}, { {48} } }, { {}, { {48} } } }, + { { {}, { {51} } }, { {}, { {51} } } }, + { { {}, { {63} } }, { {}, { {63} } } }, + { { {}, { {64} } }, { {}, { {64} } } }, + { { {}, { {65} } }, { {}, { {65} } } } +}; + +INSTANTIATE_TEST_SUITE_P(smoke_static_1D, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(staticInputShapes1D), + ::testing::Values(std::tuple{0, 0}), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(true), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + +const std::vector> dynamicInputShapes1D = { + { { { ov::Dimension{1, 70} }, // Dynamic shape 0 + { {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {11}, {13}, {15}, {16}, {17}, {19}, {23}, {24}, {32}, {55}, {63}, {64}, {65} } }, // Target shapes + { { -1 }, // Dynamic shape 1 + { {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {11}, {13}, {15}, {16}, {17}, {19}, {23}, {24}, {32}, {55}, 
{63}, {64}, {65} } } } // Target shapes +}; + +INSTANTIATE_TEST_SUITE_P(smoke_dynamic_1D, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(dynamicInputShapes1D), + ::testing::Values(std::tuple{0, 0}), + ::testing::ValuesIn(netPrecisions), + ::testing::Values(true, false), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + +///// 4D JIT ///// +std::vector> get4DShapesJitStat() { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx2()) { + result = { + { { {}, { {18, 2, 2, 1} } }, // Static shapes + { {}, { {18, 2, 8} } } + }, + { { {}, { {17, 2, 2, 2} } }, // Static shapes + { {}, { {17, 2, 7} } } + }, + { { {}, { {16, 2, 2, 3} } }, // Static shapes + { {}, { {16, 2, 6} } } + }, + { { {}, { {15, 2, 2, 4} } }, // Static shapes + { {}, { {15, 2, 5} } } + }, + { { {}, { {14, 2, 2, 5} } }, // Static shapes + { {}, { {14, 2, 4} } } + }, + { { {}, { {13, 2, 2, 6} } }, // Static shapes + { {}, { {13, 2, 3} } } + }, + { { {}, { {12, 2, 2, 7} } }, // Static shapes + { {}, { {12, 2, 2} } } + }, + { { {}, { {11, 2, 2, 8} } }, // Static shapes + { {}, { {11, 2, 1} } } + } + }; + } + if (InferenceEngine::with_cpu_x86_avx512f()) { + std::vector> tmp = { + { { {}, { {19, 4, 2, 9} } }, // Static shapes + { {}, { {19, 4, 16} } } + }, + { { {}, { {20, 4, 2, 10} } }, // Static shapes + { {}, { {20, 4, 15} } }, + }, + { { {}, { {21, 4, 2, 11} } }, // Static shapes + { {}, { {21, 4, 14} } } + }, + { { {}, { {22, 4, 2, 12} } }, // Static shapes + { {}, { {22, 4, 13} } }, + }, + { { {}, { {23, 4, 2, 13} } }, // Static shapes + { {}, { {23, 4, 12} } }, + }, + { { {}, { {24, 4, 2, 14} } }, // Static shapes + { {}, { {24, 4, 11} } }, + }, + { { {}, { {25, 4, 2, 15} } }, // Static shapes + { {}, { {25, 4, 10} } }, + }, + { { {}, { {26, 4, 2, 16} } }, // Static shapes + { {}, { {26, 4, 9} } }, + } + }; + result.insert(result.end(), tmp.begin(), tmp.end()); + } + + return result; +} + +std::vector> get4DAxisBatchJitStat(ov::element::Type type) { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx512f()) { + if (type.size() == 4 || type.size() == 2 || type.size() == 1) + return std::vector>{{3, 0}, {3, 1}, {3, 2}, {2, 0}, {2, 1}, {2, 2}}; + } else if (InferenceEngine::with_cpu_x86_avx2()) { + if (type.size() == 4) + return std::vector>{{3, 0}, {3, 1}, {3, 2}, {2, 0}, {2, 1}, {2, 2}}; + else if (type.size() == 2 || type.size() == 1) + return std::vector>{{3, 0}, {3, 1}, {3, 2}}; + } + return {}; +} + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_jit32, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitStat()), + ::testing::ValuesIn(get4DAxisBatchJitStat(ElementType::f32)), + ::testing::Values(ElementType::f32), + ::testing::Values(true), + ::testing::ValuesIn(getCPUInfo()), + ::testing::ValuesIn(additionalConfig)), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_jit16, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitStat()), + ::testing::ValuesIn(get4DAxisBatchJitStat(ElementType::bf16)), + ::testing::Values(ElementType::bf16), + ::testing::Values(true), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_jit8, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitStat()), + ::testing::ValuesIn(get4DAxisBatchJitStat(ElementType::i8)), + ::testing::Values(ElementType::i8), + 
::testing::Values(true), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + + +std::vector> get4DShapesJitDyn() { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx2()) { + result = { + { { { ov::Dimension(5, 15), -1, -1, -1 }, // Dynamic shape 0 + { {8, 2, 2, 1}, {10, 2, 2, 2}, {8, 2, 2, 3}, {10, 2, 2, 4}} }, // Target shapes + { { ov::Dimension(4, 16), -1, -1 }, // Dynamic shape 1 + { {8, 2, 8}, {10, 2, 7}, {8, 2, 6}, {10, 2, 5} } } }, // Target shapes + { { { -1, -1, -1, -1 }, // Dynamic shape 0 + { {8, 2, 2, 5}, {10, 2, 2, 6}, {8, 2, 2, 7}, {10, 2, 2, 8}} }, // Target shapes + { { -1, -1, -1 }, // Dynamic shape 1 + { {8, 2, 4}, {10, 2, 3}, {8, 2, 2}, {10, 2, 1} } } }, // Target shapes + { { { ov::Dimension(5, 15), -1, -1, -1 }, // Dynamic shape 0 + { {10, 2, 2, 1}, {10, 2, 2, 2}, {10, 2, 2, 3}, {10, 2, 2, 4}} }, // Target shapes + { { 10, 2, 5 }, // Dynamic shape 1 + { {10, 2, 5}, {10, 2, 5}, {10, 2, 5}, {10, 2, 5} } } }, // Target shapes + { { { 8, 2, 2, 5 }, // Dynamic shape 0 + { {8, 2, 2, 5}, {8, 2, 2, 5}, {8, 2, 2, 5}, {8, 2, 2, 5}} }, // Target shapes + { { -1, -1, -1 }, // Dynamic shape 1 + { {8, 2, 4}, {8, 2, 3}, {8, 2, 2}, {8, 2, 1} } } } // Target shapes + }; + } + if (InferenceEngine::with_cpu_x86_avx512f()) { + std::vector> tmp = { + { { { ov::Dimension(5, 15), -1, -1, -1 }, // Dynamic shape 0 + { {8, 2, 2, 9}, {10, 2, 2, 10}, {8, 2, 2, 11}, {10, 2, 2, 12}} }, // Target shapes + { { ov::Dimension(4, 16), -1, -1 }, // Dynamic shape 1 + { {8, 2, 16}, {10, 2, 15}, {8, 2, 14}, {10, 2, 13} } } }, // Target shapes + { { { -1, -1, -1, -1 }, // Dynamic shape 0 + { {8, 2, 2, 13}, {10, 2, 2, 14}, {8, 2, 2, 15}, {10, 2, 2, 16}} }, // Target shapes + { { -1, -1, -1 }, // Dynamic shape 1 + { {8, 2, 12}, {10, 2, 11}, {8, 2, 10}, {10, 2, 9} } } }, // Target shapes + { { { ov::Dimension(5, 15), -1, -1, -1 }, // Dynamic shape 0 + { {10, 2, 2, 9}, {10, 2, 2, 10}, {10, 2, 2, 11}, {10, 2, 2, 12}} }, // Target shapes + { { 10, 2, 16 }, // Dynamic shape 1 + { {10, 2, 16}, {10, 2, 16}, {10, 2, 16}, {10, 2, 16} } } }, // Target shapes + { { { 8, 2, 2, 15 }, // Dynamic shape 0 + { {8, 2, 2, 15}, {8, 2, 2, 15}, {8, 2, 2, 15}, {8, 2, 2, 15}} }, // Target shapes + { { -1, -1, -1 }, // Dynamic shape 1 + { {8, 2, 12}, {8, 2, 11}, {8, 2, 10}, {8, 2, 9} } } } // Target shapes + }; + result.insert(result.end(), tmp.begin(), tmp.end()); + } + + return result; +} + +std::vector> get4DAxisBatchJitDyn(ov::element::Type type) { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx512f()) { + if (type.size() == 4 || type.size() == 2 || type.size() == 1) + return std::vector>{{3, 0}, {3, 1}, {3, 2}}; + } else if (InferenceEngine::with_cpu_x86_avx2()) { + if (type.size() == 4 || type.size() == 2 || type.size() == 1) + return std::vector>{{3, 0}, {3, 1}, {3, 2}}; + } + return {}; +} + +INSTANTIATE_TEST_SUITE_P(smoke_dynamic_4D_jit32, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitDyn()), + ::testing::ValuesIn(get4DAxisBatchJitDyn(ElementType::f32)), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(isAxisConst), + ::testing::ValuesIn(getCPUInfo()), + ::testing::ValuesIn(additionalConfig)), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_dynamic_4D_jit16, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitDyn()), + ::testing::ValuesIn(get4DAxisBatchJitDyn(ElementType::bf16)), + ::testing::Values(ElementType::bf16), + 
::testing::ValuesIn(isAxisConst), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_dynamic_4D_jit8, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesJitDyn()), + ::testing::ValuesIn(get4DAxisBatchJitDyn(ElementType::i8)), + ::testing::Values(ElementType::i8), + ::testing::ValuesIn(isAxisConst), + ::testing::ValuesIn(getCPUInfo()), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + + +///// 4D REFERENCE ///// +std::vector> get4DShapesRefStat() { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx2()) { + result = { + { { {}, { {10, 2, 9, 9} } }, // Static shapes + { {}, { {10, 2, 8} } } + }, + { { {}, { {11, 2, 9, 2} } }, // Static shapes + { {}, { {11, 2, 7} } } + }, + { { {}, { {12, 2, 9, 3} } }, // Static shapes + { {}, { {12, 2, 6} } } + }, + { { {}, { {13, 2, 9, 4} } }, // Static shapes + { {}, { {13, 2, 5} } } + }, + { { {}, { {14, 2, 9, 5} } }, // Static shapes + { {}, { {14, 2, 4} } } + }, + { { {}, { {15, 2, 9, 6} } }, // Static shapes + { {}, { {15, 2, 3} } } + }, + { { {}, { {16, 2, 9, 7} } }, // Static shapes + { {}, { {16, 2, 2} } } + }, + { { {}, { {17, 2, 9, 8} } }, // Static shapes + { {}, { {17, 2, 1} } } + } + }; + } + if (InferenceEngine::with_cpu_x86_avx512f()) { + std::vector> tmp = { + { { {}, { {25, 4, 4, 17} } }, // Static shapes + { {}, { {25, 4, 16} } } + }, + { { {}, { {24, 4, 4, 18} } }, // Static shapes + { {}, { {24, 4, 15} } }, + }, + { { {}, { {23, 4, 4, 19} } }, // Static shapes + { {}, { {23, 4, 14} } } + }, + { { {}, { {22, 4, 4, 20} } }, // Static shapes + { {}, { {22, 4, 13} } }, + }, + { { {}, { {21, 4, 4, 21} } }, // Static shapes + { {}, { {21, 4, 12} } }, + }, + { { {}, { {20, 4, 4, 22} } }, // Static shapes + { {}, { {20, 4, 11} } }, + }, + { { {}, { {19, 4, 4, 23} } }, // Static shapes + { {}, { {19, 4, 10} } }, + }, + { { {}, { {18, 4, 4, 24} } }, // Static shapes + { {}, { {18, 4, 9} } }, + } + }; + result.insert(result.end(), tmp.begin(), tmp.end()); + } + + return result; +} + +std::vector> get4DAxisBatchRefStat(ov::element::Type type) { + std::vector> result = {}; + if (InferenceEngine::with_cpu_x86_avx512f()) { + if (type.size() == 4) + return std::vector>{{1, 0}, {1, 1}, {0, 0}}; + else if (type.size() == 2 || type.size() == 1) + return std::vector>{{0, 0}}; + } else if (InferenceEngine::with_cpu_x86_avx2()) { + if (type.size() == 4) + return std::vector>{{1, 0}, {1, 1}, {0, 0}}; + else if (type.size() == 2 || type.size() == 1) + return std::vector>{{2, 0}, {2, 1}, {2, 2}, {1, 0}, {1, 1}, {0, 0}}; + } + return {}; +} + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_ref32, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesRefStat()), + ::testing::ValuesIn(get4DAxisBatchRefStat(ElementType::f32)), + ::testing::Values(ElementType::f32), + ::testing::Values(true), + ::testing::Values(cpuParamsRef), + ::testing::ValuesIn(additionalConfig)), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_ref16, GatherLayerTestCPU, + ::testing::Combine( + ::testing::ValuesIn(get4DShapesRefStat()), + ::testing::ValuesIn(get4DAxisBatchRefStat(ElementType::bf16)), + ::testing::Values(ElementType::bf16), + ::testing::Values(true), + ::testing::Values(cpuParamsRef), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_static_4D_ref8, GatherLayerTestCPU, + 
::testing::Combine( + ::testing::ValuesIn(get4DShapesRefStat()), + ::testing::ValuesIn(get4DAxisBatchRefStat(ElementType::i8)), + ::testing::Values(ElementType::i8), + ::testing::Values(true), + ::testing::Values(cpuParamsRef), + ::testing::Values(additionalConfig[0])), + GatherLayerTestCPU::getTestCaseName); +} // namespace +} // namespace CPULayerTestsDefinitions diff --git a/src/tests/functional/shared_test_classes/src/single_layer/gather.cpp b/src/tests/functional/shared_test_classes/src/single_layer/gather.cpp index 06331289065..033c900067a 100644 --- a/src/tests/functional/shared_test_classes/src/single_layer/gather.cpp +++ b/src/tests/functional/shared_test_classes/src/single_layer/gather.cpp @@ -104,9 +104,9 @@ std::string Gather8LayerTest::getTestCaseName(const testing::TestParamInfo(functionParams)); auto indicesNode = ngraph::builder::makeConstant(ngraph::element::i64, indicesShape, {}, true, inputShape[axis < 0 ? axis + inputShape.size() : axis] - 1, - 1 - static_cast(inputShape[axis < 0 ? axis + inputShape.size() : axis])); + -static_cast(inputShape[axis < 0 ? axis + inputShape.size() : axis])); auto axisNode = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape({}), { axis }); auto gather = std::make_shared(paramOuts[0], indicesNode, axisNode, batchIdx); ngraph::ResultVector results{ std::make_shared(gather) }; diff --git a/tests/layer_tests/onnx_tests/test_gather.py b/tests/layer_tests/onnx_tests/test_gather.py index 8f3b90f156c..addab999191 100644 --- a/tests/layer_tests/onnx_tests/test_gather.py +++ b/tests/layer_tests/onnx_tests/test_gather.py @@ -275,7 +275,6 @@ class TestGather(OnnxRuntimeLayerTest): dict(shape=[6, 8, 10, 12], axis=-1, indices=[[[2, -1], [3, 2]], [[5, -1], [3, -2]]], output_shape=[6, 8, 10, 2, 2, 2])] - @pytest.mark.xfail(reason='negative indices are not yet implemented on CPU: xxx-54630') @pytest.mark.parametrize("params", test_data_negative_indices) @pytest.mark.nightly def test_gather_nightly_negative_indices(self, params, ie_device, precision, ir_version,