Add SoftSign to CPU plugin (#12034)
This commit is contained in:
parent
177d977449
commit
e1bcfeca9d
9
.gitignore
vendored
9
.gitignore
vendored
@ -1,5 +1,7 @@
|
||||
# build/artifact dirs
|
||||
_*
|
||||
[Bb]uild*/
|
||||
|
||||
# but ensure we don't skip __init__.py and __main__.py
|
||||
!__init__.py
|
||||
!__main__.py
|
||||
@ -9,12 +11,10 @@ _*
|
||||
# developer tools
|
||||
*.idea
|
||||
.vscode
|
||||
cmake-build-*
|
||||
.DS_Store
|
||||
**/tags
|
||||
compile_commands.json
|
||||
bin/
|
||||
build/
|
||||
.local_vimrc
|
||||
.gdb_history
|
||||
.vimspector.json
|
||||
@ -34,14 +34,13 @@ docs/IE_PLUGIN_DG/html/
|
||||
*.pydevproject
|
||||
*.settings
|
||||
*/gen/
|
||||
__pycache__
|
||||
*.swp
|
||||
/config.xml
|
||||
|
||||
# Python-specific
|
||||
*.env3
|
||||
*.?env*
|
||||
*.pyc
|
||||
|
||||
__pycache__
|
||||
# Tests-specific
|
||||
*.coverage
|
||||
*htmlcov
|
||||
|
@ -10,6 +10,8 @@ from openvino.runtime import Shape, Type
|
||||
from tests.runtime import get_runtime
|
||||
from tests.test_ngraph.util import run_op_node
|
||||
|
||||
R_TOLERANCE = 1e-6 # global relative tolerance
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("ng_api_fn", "numpy_fn", "range_start", "range_end"),
|
||||
@ -79,7 +81,8 @@ def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_data", [(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
|
||||
"input_data",
|
||||
[(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
|
||||
)
|
||||
def test_logical_not(input_data):
|
||||
expected = np.logical_not(input_data)
|
||||
@ -245,38 +248,44 @@ def test_gelu_tanh_operator_with_array():
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_type",
|
||||
[
|
||||
Type.f64,
|
||||
Type.f32,
|
||||
Type.f16,
|
||||
],
|
||||
)
|
||||
def test_softsign_with_parameters(data_type):
|
||||
data = np.random.rand(4, 2).astype(data_type.to_dtype())
|
||||
type_tolerance = [
|
||||
(np.float64, 1e-6),
|
||||
(np.float32, 1e-6),
|
||||
(np.float16, 1e-3),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_tolerance", type_tolerance)
|
||||
def test_softsign_with_parameters(type_tolerance):
|
||||
dtype, atol = type_tolerance
|
||||
data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
|
||||
expected = np.divide(data, np.abs(data) + 1)
|
||||
|
||||
runtime = get_runtime()
|
||||
param = ov.parameter(data.shape, data_type, name="Data")
|
||||
param = ov.parameter(data.shape, dtype, name="Data")
|
||||
result = runtime.computation(ov.softsign(param, "SoftSign"), param)(data)
|
||||
|
||||
assert np.allclose(result, expected, 1e-6, 1e-3)
|
||||
assert np.allclose(result, expected, R_TOLERANCE, atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_type",
|
||||
[
|
||||
np.float64,
|
||||
np.float32,
|
||||
np.float16,
|
||||
],
|
||||
)
|
||||
def test_softsign_with_array(data_type):
|
||||
data = np.random.rand(32, 5).astype(data_type)
|
||||
@pytest.mark.parametrize("type_tolerance", type_tolerance)
|
||||
def test_softsign_with_array(type_tolerance):
|
||||
dtype, atol = type_tolerance
|
||||
data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
|
||||
expected = np.divide(data, np.abs(data) + 1)
|
||||
|
||||
runtime = get_runtime()
|
||||
result = runtime.computation(ov.softsign(data, "SoftSign"))()
|
||||
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
assert np.allclose(result, expected, R_TOLERANCE, atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_tolerance", type_tolerance)
|
||||
def test_softsign(type_tolerance):
|
||||
dtype, atol = type_tolerance
|
||||
data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
|
||||
expected = np.divide(data, np.abs(data) + 1)
|
||||
|
||||
result = run_op_node([data], ov.softsign)
|
||||
|
||||
assert np.allclose(result, expected, R_TOLERANCE, atol)
|
||||
|
@ -9,6 +9,8 @@ from ngraph.impl import Shape, Type
|
||||
from tests_compatibility.runtime import get_runtime
|
||||
from tests_compatibility.test_ngraph.util import run_op_node
|
||||
|
||||
R_TOLERANCE = 1e-6 # global relative tolerance
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ng_api_fn, numpy_fn, range_start, range_end",
|
||||
@ -244,38 +246,34 @@ def test_gelu_tanh_operator_with_array():
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"numpy_type",
|
||||
[
|
||||
np.float64,
|
||||
np.float32,
|
||||
np.float16,
|
||||
],
|
||||
)
|
||||
def test_softsign_with_parameters(numpy_type):
|
||||
data = np.random.rand(4, 2).astype(numpy_type)
|
||||
type_tolerance = [
|
||||
(np.float64, 1e-6),
|
||||
(np.float32, 1e-6),
|
||||
(np.float16, 1e-3),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_tolerance", type_tolerance)
|
||||
def test_softsign_with_parameters(type_tolerance):
|
||||
dtype, atol = type_tolerance
|
||||
data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
|
||||
|
||||
expected = np.divide(data, np.abs(data) + 1)
|
||||
|
||||
runtime = get_runtime()
|
||||
param = ng.parameter(data.shape, numpy_type, name="Data")
|
||||
param = ng.parameter(data.shape, dtype, name="Data")
|
||||
result = runtime.computation(ng.softsign(param, "SoftSign"), param)(data)
|
||||
|
||||
assert np.allclose(result, expected, 1e-6, 1e-3)
|
||||
assert np.allclose(result, expected, R_TOLERANCE, atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data_type",
|
||||
[
|
||||
np.float64,
|
||||
np.float32,
|
||||
np.float16,
|
||||
],
|
||||
)
|
||||
def test_softsign_with_array(data_type):
|
||||
data = np.random.rand(32, 5).astype(data_type)
|
||||
@pytest.mark.parametrize("type_tolerance", type_tolerance)
|
||||
def test_softsign_with_array(type_tolerance):
|
||||
dtype, atol = type_tolerance
|
||||
data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
|
||||
expected = np.divide(data, np.abs(data) + 1)
|
||||
|
||||
runtime = get_runtime()
|
||||
result = runtime.computation(ng.softsign(data, "SoftSign"))()
|
||||
|
||||
assert np.allclose(result, expected, 1e-6, 1e-6)
|
||||
assert np.allclose(result, expected, R_TOLERANCE, atol)
|
||||
|
@ -64,6 +64,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
|
||||
{ "PRelu", Type::Eltwise },
|
||||
{ "Erf", Type::Eltwise },
|
||||
{ "SoftPlus", Type::Eltwise },
|
||||
{ "SoftSign", Type::Eltwise },
|
||||
{ "Reshape", Type::Reshape },
|
||||
{ "Squeeze", Type::Reshape },
|
||||
{ "Unsqueeze", Type::Reshape },
|
||||
|
@ -169,6 +169,7 @@ enum class Algorithm {
|
||||
EltwiseRoundHalfToEven,
|
||||
EltwiseRoundHalfAwayFromZero,
|
||||
EltwiseErf,
|
||||
EltwiseSoftSign,
|
||||
|
||||
// FakeQuantize algorithms
|
||||
FQCommon,
|
||||
|
@ -1839,5 +1839,47 @@ size_t jit_erf_emitter::aux_vecs_count() const {
|
||||
return 5ul;
|
||||
}
|
||||
|
||||
/// SOFT SIGN ///
|
||||
jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& node, Precision exec_prc)
|
||||
: jit_emitter(host, host_isa, node, exec_prc) {
|
||||
prepare_table();
|
||||
}
|
||||
jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, Precision exec_prc)
|
||||
: jit_emitter(host, host_isa, exec_prc) {
|
||||
prepare_table();
|
||||
}
|
||||
|
||||
size_t jit_soft_sign_emitter::get_inputs_num() const { return 1; }
|
||||
|
||||
void jit_soft_sign_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
||||
const emitter_context *emit_context) const {
|
||||
if (host_isa_ == cpu::x64::sse41) {
|
||||
emit_isa<cpu::x64::sse41>(in_vec_idxs, out_vec_idxs);
|
||||
} else if (host_isa_ == cpu::x64::avx2) {
|
||||
emit_isa<cpu::x64::avx2>(in_vec_idxs, out_vec_idxs);
|
||||
} else if (host_isa_ == cpu::x64::avx512_core) {
|
||||
emit_isa<cpu::x64::avx512_core>(in_vec_idxs, out_vec_idxs);
|
||||
} else {
|
||||
assert(!"unsupported isa");
|
||||
}
|
||||
}
|
||||
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
void jit_soft_sign_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
|
||||
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
|
||||
Vmm vmm_src = Vmm(in_vec_idxs[0]);
|
||||
Vmm vmm_dst = Vmm(out_vec_idxs[0]);
|
||||
|
||||
h->uni_vmovups(vmm_dst, vmm_src); // y = x
|
||||
h->uni_vandps(vmm_src, vmm_src, table_val("positive_mask")); // x = abs(x)
|
||||
h->uni_vaddps(vmm_src, vmm_src, table_val("one")); // x++
|
||||
h->uni_vdivps(vmm_dst, vmm_dst, vmm_src); // y = y/x
|
||||
}
|
||||
|
||||
void jit_soft_sign_emitter::register_table_entries() {
|
||||
push_arg_entry_of("one", 0x3f800000, true);
|
||||
push_arg_entry_of("positive_mask", 0x7fffffff, true);
|
||||
}
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -586,5 +586,25 @@ private:
|
||||
size_t aux_vecs_count() const override;
|
||||
};
|
||||
|
||||
class jit_soft_sign_emitter : public jit_emitter {
|
||||
public:
|
||||
jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
|
||||
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
|
||||
jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& n,
|
||||
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
|
||||
|
||||
size_t get_inputs_num() const override;
|
||||
|
||||
private:
|
||||
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
|
||||
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
|
||||
const emitter_context *emit_context) const override;
|
||||
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
|
||||
|
||||
void register_table_entries() override;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -463,7 +463,8 @@ private:
|
||||
OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
|
||||
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
|
||||
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
|
||||
OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
|
||||
OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
|
||||
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
|
||||
|
||||
if (precisions.empty())
|
||||
IE_THROW() << "Unsupported operation type for Eltwise emitter";
|
||||
@ -520,7 +521,8 @@ private:
|
||||
OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
|
||||
OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
|
||||
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
|
||||
OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
|
||||
OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
|
||||
OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
|
||||
|
||||
if (!ctx.emitter)
|
||||
IE_THROW() << "Unsupported operation type for Eltwise emitter";
|
||||
@ -1022,6 +1024,9 @@ const std::map<const ngraph::DiscreteTypeInfo, Eltwise::Initializer> Eltwise::in
|
||||
node.algorithm = Algorithm::EltwiseSoftRelu;
|
||||
node.onednnAlgorithm = dnnl::algorithm::eltwise_soft_relu;
|
||||
}},
|
||||
{ngraph::op::v9::SoftSign::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
|
||||
node.algorithm = Algorithm::EltwiseSoftSign;
|
||||
}},
|
||||
};
|
||||
|
||||
|
||||
@ -1505,6 +1510,7 @@ public:
|
||||
case Algorithm::EltwisePowerStatic: *dst_ptr_f = powf(_opData.beta * src_f[0] + _opData.gamma, _opData.alpha); break;
|
||||
case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : src_f[0] * src_f[1]; break;
|
||||
case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break;
|
||||
case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break;
|
||||
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
|
||||
}
|
||||
}
|
||||
@ -1608,6 +1614,7 @@ size_t Eltwise::getOpInputsNum() const {
|
||||
case Algorithm::EltwiseHsigmoid:
|
||||
case Algorithm::EltwiseRoundHalfToEven:
|
||||
case Algorithm::EltwiseRoundHalfAwayFromZero:
|
||||
case Algorithm::EltwiseSoftSign:
|
||||
return 1;
|
||||
case Algorithm::EltwiseAdd:
|
||||
case Algorithm::EltwiseSubtract:
|
||||
|
@ -85,6 +85,7 @@
|
||||
#include "ngraph_transformations/snippets_mark_skipped.hpp"
|
||||
#include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
|
||||
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
|
||||
#include <transformations/op_conversions/softsign_decomposition.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
@ -483,6 +484,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
|
||||
pass_config->disable<ngraph::pass::SliceToStridedSlice>();
|
||||
pass_config->disable<ngraph::pass::ConvertDetectionOutput8ToDetectionOutput1>();
|
||||
pass_config->disable<ngraph::pass::ConvertROIAlign9To3>();
|
||||
pass_config->disable<ngraph::pass::SoftSignDecomposition>();
|
||||
|
||||
pass_config->enable<ngraph::pass::NormalizeL2Decomposition>();
|
||||
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
|
||||
|
@ -243,7 +243,8 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
|
||||
{Tan, {{}}},
|
||||
{HardSigmoid, {{0.2f, 0.5f}}},
|
||||
{Selu, {{1.6732f, 1.0507f}}},
|
||||
{Ceiling, {{}}}
|
||||
{Ceiling, {{}}},
|
||||
{SoftSign, {{}}}
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
|
Loading…
Reference in New Issue
Block a user