diff --git a/.gitignore b/.gitignore
index c208937e682..e6ab6132920 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,7 @@
 # build/artifact dirs
 _*
+[Bb]uild*/
+
 # but ensure we don't skip __init__.py and __main__.py
 !__init__.py
 !__main__.py
@@ -9,12 +11,10 @@ _*
 # developer tools
 *.idea
 .vscode
-cmake-build-*
 .DS_Store
 **/tags
 compile_commands.json
 bin/
-build/
 .local_vimrc
 .gdb_history
 .vimspector.json
@@ -34,14 +34,13 @@ docs/IE_PLUGIN_DG/html/
 *.pydevproject
 *.settings
 */gen/
-__pycache__
 *.swp
 /config.xml
 
 # Python-specific
-*.env3
+*.?env*
 *.pyc
-
+__pycache__
 # Tests-specific
 *.coverage
 *htmlcov
diff --git a/src/bindings/python/tests/test_ngraph/test_ops_unary.py b/src/bindings/python/tests/test_ngraph/test_ops_unary.py
index 4f076286a45..3131a6e2ff7 100644
--- a/src/bindings/python/tests/test_ngraph/test_ops_unary.py
+++ b/src/bindings/python/tests/test_ngraph/test_ops_unary.py
@@ -10,6 +10,8 @@ from openvino.runtime import Shape, Type
 from tests.runtime import get_runtime
 from tests.test_ngraph.util import run_op_node
 
+R_TOLERANCE = 1e-6  # global relative tolerance
+
 
 @pytest.mark.parametrize(
     ("ng_api_fn", "numpy_fn", "range_start", "range_end"),
@@ -79,7 +81,8 @@ def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
 
 
 @pytest.mark.parametrize(
-    "input_data", [(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
+    "input_data",
+    [(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
 )
 def test_logical_not(input_data):
     expected = np.logical_not(input_data)
@@ -245,38 +248,44 @@ def test_gelu_tanh_operator_with_array():
     assert np.allclose(result, expected, 1e-6, 1e-6)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        Type.f64,
-        Type.f32,
-        Type.f16,
-    ],
-)
-def test_softsign_with_parameters(data_type):
-    data = np.random.rand(4, 2).astype(data_type.to_dtype())
+type_tolerance = [
+    (np.float64, 1e-6),
+    (np.float32, 1e-6),
+    (np.float16, 1e-3),
+]
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_parameters(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
 
     runtime = get_runtime()
-    param = ov.parameter(data.shape, data_type, name="Data")
+    param = ov.parameter(data.shape, dtype, name="Data")
     result = runtime.computation(ov.softsign(param, "SoftSign"), param)(data)
-    assert np.allclose(result, expected, 1e-6, 1e-3)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_array(data_type):
-    data = np.random.rand(32, 5).astype(data_type)
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_array(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
 
     runtime = get_runtime()
     result = runtime.computation(ov.softsign(data, "SoftSign"))()
-    assert np.allclose(result, expected, 1e-6, 1e-6)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
+    expected = np.divide(data, np.abs(data) + 1)
+
+    result = run_op_node([data], ov.softsign)
+
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
diff --git a/src/bindings/python/tests_compatibility/test_ngraph/test_ops_unary.py b/src/bindings/python/tests_compatibility/test_ngraph/test_ops_unary.py
index faa47b5028a..a051e018bc4 100644
--- a/src/bindings/python/tests_compatibility/test_ngraph/test_ops_unary.py
+++ b/src/bindings/python/tests_compatibility/test_ngraph/test_ops_unary.py
@@ -9,6 +9,8 @@ from ngraph.impl import Shape, Type
 from tests_compatibility.runtime import get_runtime
 from tests_compatibility.test_ngraph.util import run_op_node
 
+R_TOLERANCE = 1e-6  # global relative tolerance
+
 
 @pytest.mark.parametrize(
     "ng_api_fn, numpy_fn, range_start, range_end",
@@ -244,38 +246,34 @@ def test_gelu_tanh_operator_with_array():
     assert np.allclose(result, expected, 1e-6, 1e-6)
 
 
-@pytest.mark.parametrize(
-    "numpy_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_parameters(numpy_type):
-    data = np.random.rand(4, 2).astype(numpy_type)
+type_tolerance = [
+    (np.float64, 1e-6),
+    (np.float32, 1e-6),
+    (np.float16, 1e-3),
+]
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_parameters(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
+    expected = np.divide(data, np.abs(data) + 1)
 
     runtime = get_runtime()
-    param = ng.parameter(data.shape, numpy_type, name="Data")
+    param = ng.parameter(data.shape, dtype, name="Data")
     result = runtime.computation(ng.softsign(param, "SoftSign"), param)(data)
-    assert np.allclose(result, expected, 1e-6, 1e-3)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_array(data_type):
-    data = np.random.rand(32, 5).astype(data_type)
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_array(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
 
     runtime = get_runtime()
     result = runtime.computation(ng.softsign(data, "SoftSign"))()
-    assert np.allclose(result, expected, 1e-6, 1e-6)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp
index ab2b1c1f0b2..f824285e4d5 100644
--- a/src/plugins/intel_cpu/src/cpu_types.cpp
+++ b/src/plugins/intel_cpu/src/cpu_types.cpp
@@ -64,6 +64,7 @@ const InferenceEngine::details::caseless_unordered_map type_t
         { "PRelu", Type::Eltwise },
         { "Erf", Type::Eltwise },
         { "SoftPlus", Type::Eltwise },
+        { "SoftSign", Type::Eltwise },
         { "Reshape", Type::Reshape },
         { "Squeeze", Type::Reshape },
         { "Unsqueeze", Type::Reshape },
diff --git a/src/plugins/intel_cpu/src/cpu_types.h b/src/plugins/intel_cpu/src/cpu_types.h
index 5554eb7521a..88b64e9ad9e 100644
--- a/src/plugins/intel_cpu/src/cpu_types.h
+++ b/src/plugins/intel_cpu/src/cpu_types.h
@@ -169,6 +169,7 @@ enum class Algorithm {
     EltwiseRoundHalfToEven,
     EltwiseRoundHalfAwayFromZero,
     EltwiseErf,
+    EltwiseSoftSign,
 
     // FakeQuantize algorithms
     FQCommon,
diff --git a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp
index decfbd6c772..1592c814bb8 100644
--- a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp
+++ b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.cpp
@@ -1839,5 +1839,47 @@ size_t jit_erf_emitter::aux_vecs_count() const {
     return 5ul;
 }
 
+/// SOFT SIGN ///
+jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& node, Precision exec_prc)
+: jit_emitter(host, host_isa, node, exec_prc) {
+    prepare_table();
+}
+jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, Precision exec_prc)
+: jit_emitter(host, host_isa, exec_prc) {
+    prepare_table();
+}
+
+size_t jit_soft_sign_emitter::get_inputs_num() const { return 1; }
+
+void jit_soft_sign_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
+                                      const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
+                                      const emitter_context *emit_context) const {
+    if (host_isa_ == cpu::x64::sse41) {
+        emit_isa<cpu::x64::sse41>(in_vec_idxs, out_vec_idxs);
+    } else if (host_isa_ == cpu::x64::avx2) {
+        emit_isa<cpu::x64::avx2>(in_vec_idxs, out_vec_idxs);
+    } else if (host_isa_ == cpu::x64::avx512_core) {
+        emit_isa<cpu::x64::avx512_core>(in_vec_idxs, out_vec_idxs);
+    } else {
+        assert(!"unsupported isa");
+    }
+}
+
+template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+void jit_soft_sign_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+    using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+    Vmm vmm_src = Vmm(in_vec_idxs[0]);
+    Vmm vmm_dst = Vmm(out_vec_idxs[0]);
+
+    h->uni_vmovups(vmm_dst, vmm_src);                             // y = x
+    h->uni_vandps(vmm_src, vmm_src, table_val("positive_mask"));  // x = abs(x)
+    h->uni_vaddps(vmm_src, vmm_src, table_val("one"));            // x++
+    h->uni_vdivps(vmm_dst, vmm_dst, vmm_src);                     // y = y/x
+}
+
+void jit_soft_sign_emitter::register_table_entries() {
+    push_arg_entry_of("one", 0x3f800000, true);
+    push_arg_entry_of("positive_mask", 0x7fffffff, true);
+}
 }   // namespace intel_cpu
 }   // namespace ov
diff --git a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp
index 46a82d6c033..349bafe5d43 100644
--- a/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp
+++ b/src/plugins/intel_cpu/src/emitters/jit_eltwise_emitters.hpp
@@ -586,5 +586,25 @@ private:
     size_t aux_vecs_count() const override;
 };
 
+class jit_soft_sign_emitter : public jit_emitter {
+public:
+    jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                          InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
+    jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& n,
+                          InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
+
+    size_t get_inputs_num() const override;
+
+private:
+    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
+                   const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
+                   const emitter_context *emit_context) const override;
+
+    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+
+    void register_table_entries() override;
+};
+
 }   // namespace intel_cpu
 }   // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp
index 7e1a2dbb820..619ce9f76e9 100644
--- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp
+++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp
@@ -463,7 +463,8 @@ private:
         OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
         OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
         OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
-        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
+        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
+        OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
 
         if (precisions.empty())
             IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -520,7 +521,8 @@ private:
         OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
         OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
         OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
-        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
+        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
+        OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
 
         if (!ctx.emitter)
             IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -1022,6 +1024,9 @@ const std::map Eltwise::in
             node.algorithm = Algorithm::EltwiseSoftRelu;
             node.onednnAlgorithm = dnnl::algorithm::eltwise_soft_relu;
         }},
+        {ngraph::op::v9::SoftSign::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
+            node.algorithm = Algorithm::EltwiseSoftSign;
+        }},
 };
@@ -1505,6 +1510,7 @@ public:
             case Algorithm::EltwisePowerStatic: *dst_ptr_f = powf(_opData.beta * src_f[0] + _opData.gamma, _opData.alpha); break;
             case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : src_f[0] * src_f[1]; break;
             case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break;
+            case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break;
             default: IE_THROW() << "Unsupported operation type for Eltwise executor";
         }
     }
@@ -1608,6 +1614,7 @@ size_t Eltwise::getOpInputsNum() const {
     case Algorithm::EltwiseHsigmoid:
     case Algorithm::EltwiseRoundHalfToEven:
     case Algorithm::EltwiseRoundHalfAwayFromZero:
+    case Algorithm::EltwiseSoftSign:
         return 1;
     case Algorithm::EltwiseAdd:
    case Algorithm::EltwiseSubtract:
diff --git a/src/plugins/intel_cpu/src/plugin.cpp b/src/plugins/intel_cpu/src/plugin.cpp
index b6c8eb6cc46..e6a5fb9712f 100644
--- a/src/plugins/intel_cpu/src/plugin.cpp
+++ b/src/plugins/intel_cpu/src/plugin.cpp
@@ -85,6 +85,7 @@
 #include "ngraph_transformations/snippets_mark_skipped.hpp"
 #include
+#include <transformations/op_conversions/softsign_decomposition.hpp>
 #include
 #include
@@ -483,6 +484,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr
     pass_config->disable();
     pass_config->disable();
     pass_config->disable();
+    pass_config->disable<ngraph::pass::SoftSignDecomposition>();
 
     pass_config->enable();
     pass_config->enable();
diff --git a/src/tests/functional/plugin/cpu/single_layer_tests/activation.cpp b/src/tests/functional/plugin/cpu/single_layer_tests/activation.cpp
index 8b44d10dbfa..ec472b9111e 100644
--- a/src/tests/functional/plugin/cpu/single_layer_tests/activation.cpp
+++ b/src/tests/functional/plugin/cpu/single_layer_tests/activation.cpp
@@ -243,7 +243,8 @@ const std::map>> activationTypes
         {Tan, {{}}},
         {HardSigmoid, {{0.2f, 0.5f}}},
         {Selu, {{1.6732f, 1.0507f}}},
-        {Ceiling, {{}}}
+        {Ceiling, {{}}},
+        {SoftSign, {{}}}
 };
 
 const std::vector netPrecisions = {
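
For reference, the formula exercised throughout this patch is SoftSign(x) = x / (1 + |x|). The NumPy sketch below only illustrates that the two formulations appearing in the diff — np.divide(data, np.abs(data) + 1) in the Python tests and src_f[0] / (1 + std::fabs(src_f[0])) in the scalar Eltwise executor — agree within the per-dtype tolerances the tests use. The helper names softsign_reference and softsign_scalar are illustrative only and are not part of the patch or of any OpenVINO API.

import numpy as np

R_TOLERANCE = 1e-6  # mirrors the global relative tolerance introduced in the tests


def softsign_reference(x):
    # Formulation used by the Python tests: x / (|x| + 1).
    return np.divide(x, np.abs(x) + 1)


def softsign_scalar(x):
    # Formulation used by the CPU scalar executor: x / (1 + |x|).
    return x / (1.0 + np.fabs(x))


# Same dtype/absolute-tolerance pairs as the type_tolerance table in the tests.
for dtype, atol in [(np.float64, 1e-6), (np.float32, 1e-6), (np.float16, 1e-3)]:
    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
    assert np.allclose(softsign_reference(data), softsign_scalar(data), R_TOLERANCE, atol)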