Add SoftSign to CPU plugin (#12034)

Pawel Raasz 2022-07-05 13:34:42 +02:00 committed by GitHub
parent 177d977449
commit e1bcfeca9d
10 changed files with 135 additions and 55 deletions

.gitignore

@@ -1,5 +1,7 @@
 # build/artifact dirs
 _*
+[Bb]uild*/
 # but ensure we don't skip __init__.py and __main__.py
 !__init__.py
 !__main__.py
@@ -9,12 +11,10 @@ _*
 # developer tools
 *.idea
 .vscode
-cmake-build-*
 .DS_Store
 **/tags
 compile_commands.json
 bin/
-build/
 .local_vimrc
 .gdb_history
 .vimspector.json
@@ -34,14 +34,13 @@ docs/IE_PLUGIN_DG/html/
 *.pydevproject
 *.settings
 */gen/
-__pycache__
 *.swp
 /config.xml
 # Python-specific
-*.env3
+*.?env*
 *.pyc
+__pycache__
 # Tests-specific
 *.coverage
 *htmlcov


@@ -10,6 +10,8 @@ from openvino.runtime import Shape, Type
 from tests.runtime import get_runtime
 from tests.test_ngraph.util import run_op_node
 
+R_TOLERANCE = 1e-6  # global relative tolerance
+
 
 @pytest.mark.parametrize(
     ("ng_api_fn", "numpy_fn", "range_start", "range_end"),
@@ -79,7 +81,8 @@ def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
 @pytest.mark.parametrize(
-    "input_data", [(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
+    "input_data",
+    [(np.array([True, False, True, False])), (np.array([True])), (np.array([False]))],
 )
 def test_logical_not(input_data):
     expected = np.logical_not(input_data)
@@ -245,38 +248,44 @@ def test_gelu_tanh_operator_with_array():
     assert np.allclose(result, expected, 1e-6, 1e-6)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        Type.f64,
-        Type.f32,
-        Type.f16,
-    ],
-)
-def test_softsign_with_parameters(data_type):
-    data = np.random.rand(4, 2).astype(data_type.to_dtype())
+type_tolerance = [
+    (np.float64, 1e-6),
+    (np.float32, 1e-6),
+    (np.float16, 1e-3),
+]
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_parameters(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
     runtime = get_runtime()
-    param = ov.parameter(data.shape, data_type, name="Data")
+    param = ov.parameter(data.shape, dtype, name="Data")
     result = runtime.computation(ov.softsign(param, "SoftSign"), param)(data)
-    assert np.allclose(result, expected, 1e-6, 1e-3)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_array(data_type):
-    data = np.random.rand(32, 5).astype(data_type)
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_array(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
     runtime = get_runtime()
     result = runtime.computation(ov.softsign(data, "SoftSign"))()
-    assert np.allclose(result, expected, 1e-6, 1e-6)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
+    expected = np.divide(data, np.abs(data) + 1)
+    result = run_op_node([data], ov.softsign)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
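The per-dtype tolerances above follow from the precision of each floating-point format. A quick NumPy sketch (illustrative only, not part of the commit) of why float16 gets the looser 1e-3 absolute tolerance while float32/float64 stay within 1e-6:

# Rough error check for the SoftSign reference x / (1 + |x|), assuming inputs
# drawn from [-1, 1) as in the tests above.
import numpy as np

x = np.random.uniform(-1.0, 1.0, 10000)
reference = x / (1.0 + np.abs(x))                 # float64 reference

x16 = x.astype(np.float16)
approx16 = x16 / (np.float16(1.0) + np.abs(x16))  # same formula in float16

# float16 keeps ~3 decimal digits, so the error lands between 1e-6 and 1e-3;
# float32 keeps ~7 digits and stays well below 1e-6.
print(np.max(np.abs(approx16.astype(np.float64) - reference)))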


@@ -9,6 +9,8 @@ from ngraph.impl import Shape, Type
 from tests_compatibility.runtime import get_runtime
 from tests_compatibility.test_ngraph.util import run_op_node
 
+R_TOLERANCE = 1e-6  # global relative tolerance
+
 
 @pytest.mark.parametrize(
     "ng_api_fn, numpy_fn, range_start, range_end",
@@ -244,38 +246,34 @@ def test_gelu_tanh_operator_with_array():
     assert np.allclose(result, expected, 1e-6, 1e-6)
 
 
-@pytest.mark.parametrize(
-    "numpy_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_parameters(numpy_type):
-    data = np.random.rand(4, 2).astype(numpy_type)
+type_tolerance = [
+    (np.float64, 1e-6),
+    (np.float32, 1e-6),
+    (np.float16, 1e-3),
+]
+
+
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_parameters(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
     runtime = get_runtime()
-    param = ng.parameter(data.shape, numpy_type, name="Data")
+    param = ng.parameter(data.shape, dtype, name="Data")
    result = runtime.computation(ng.softsign(param, "SoftSign"), param)(data)
-    assert np.allclose(result, expected, 1e-6, 1e-3)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)
 
 
-@pytest.mark.parametrize(
-    "data_type",
-    [
-        np.float64,
-        np.float32,
-        np.float16,
-    ],
-)
-def test_softsign_with_array(data_type):
-    data = np.random.rand(32, 5).astype(data_type)
+@pytest.mark.parametrize("type_tolerance", type_tolerance)
+def test_softsign_with_array(type_tolerance):
+    dtype, atol = type_tolerance
+    data = np.random.uniform(-1.0, 1.0, (32, 5)).astype(dtype)
     expected = np.divide(data, np.abs(data) + 1)
     runtime = get_runtime()
     result = runtime.computation(ng.softsign(data, "SoftSign"))()
-    assert np.allclose(result, expected, 1e-6, 1e-6)
+    assert np.allclose(result, expected, R_TOLERANCE, atol)


@@ -64,6 +64,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
         { "PRelu", Type::Eltwise },
         { "Erf", Type::Eltwise },
         { "SoftPlus", Type::Eltwise },
+        { "SoftSign", Type::Eltwise },
         { "Reshape", Type::Reshape },
         { "Squeeze", Type::Reshape },
         { "Unsqueeze", Type::Reshape },


@@ -169,6 +169,7 @@ enum class Algorithm {
     EltwiseRoundHalfToEven,
     EltwiseRoundHalfAwayFromZero,
     EltwiseErf,
+    EltwiseSoftSign,
 
     // FakeQuantize algorithms
     FQCommon,


@@ -1839,5 +1839,47 @@ size_t jit_erf_emitter::aux_vecs_count() const {
     return 5ul;
 }
 
+/// SOFT SIGN ///
+jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& node, Precision exec_prc)
+    : jit_emitter(host, host_isa, node, exec_prc) {
+    prepare_table();
+}
+jit_soft_sign_emitter::jit_soft_sign_emitter(jit_generator *host, cpu_isa_t host_isa, Precision exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
+    prepare_table();
+}
+
+size_t jit_soft_sign_emitter::get_inputs_num() const { return 1; }
+
+void jit_soft_sign_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
+                                      const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
+                                      const emitter_context *emit_context) const {
+    if (host_isa_ == cpu::x64::sse41) {
+        emit_isa<cpu::x64::sse41>(in_vec_idxs, out_vec_idxs);
+    } else if (host_isa_ == cpu::x64::avx2) {
+        emit_isa<cpu::x64::avx2>(in_vec_idxs, out_vec_idxs);
+    } else if (host_isa_ == cpu::x64::avx512_core) {
+        emit_isa<cpu::x64::avx512_core>(in_vec_idxs, out_vec_idxs);
+    } else {
+        assert(!"unsupported isa");
+    }
+}
+
+template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+void jit_soft_sign_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+    using Vmm = typename conditional3<isa == cpu::x64::sse41, Xmm, isa == cpu::x64::avx2, Ymm, Zmm>::type;
+    Vmm vmm_src = Vmm(in_vec_idxs[0]);
+    Vmm vmm_dst = Vmm(out_vec_idxs[0]);
+
+    h->uni_vmovups(vmm_dst, vmm_src);                             // y = x
+    h->uni_vandps(vmm_src, vmm_src, table_val("positive_mask"));  // x = abs(x)
+    h->uni_vaddps(vmm_src, vmm_src, table_val("one"));            // x++
+    h->uni_vdivps(vmm_dst, vmm_dst, vmm_src);                     // y = y/x
+}
+
+void jit_soft_sign_emitter::register_table_entries() {
+    push_arg_entry_of("one", 0x3f800000, true);
+    push_arg_entry_of("positive_mask", 0x7fffffff, true);
+}
+
 }   // namespace intel_cpu
 }   // namespace ov
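The two table entries are raw IEEE-754 single-precision bit patterns: 0x3f800000 encodes 1.0f, and AND-ing a value's bits with 0x7fffffff clears the sign bit, which is how uni_vandps produces |x| before the add and the divide. A small NumPy sketch (illustrative only, not part of the commit) of the same bit manipulation:

# Reproduce the emitter's constant table with NumPy bit views instead of SIMD registers.
import numpy as np

x = np.array([-0.75, 2.0, -3.5], dtype=np.float32)

bits = x.view(np.uint32)
abs_x = (bits & np.uint32(0x7FFFFFFF)).view(np.float32)             # "positive_mask": clear sign bit -> |x|
one = np.array([0x3F800000], dtype=np.uint32).view(np.float32)[0]   # "one": bit pattern of 1.0f

softsign = x / (abs_x + one)                                        # y = x / (|x| + 1)

assert np.allclose(abs_x, np.abs(x)) and one == np.float32(1.0)
print(softsign)  # approx. [-0.4286  0.6667 -0.7778]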


@@ -586,5 +586,25 @@ private:
     size_t aux_vecs_count() const override;
 };
 
+class jit_soft_sign_emitter : public jit_emitter {
+public:
+    jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                          InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
+    jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ngraph::Node>& n,
+                          InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32);
+
+    size_t get_inputs_num() const override;
+
+private:
+    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs,
+                   const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs,
+                   const emitter_context *emit_context) const override;
+
+    template <dnnl::impl::cpu::x64::cpu_isa_t isa>
+    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+
+    void register_table_entries() override;
+};
+
 }   // namespace intel_cpu
 }   // namespace ov


@@ -463,7 +463,8 @@ private:
         OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
         OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
         OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
-        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
+        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
+        OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
 
     if (precisions.empty())
         IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -520,7 +521,8 @@ private:
         OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter),
         OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter),
         OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
-        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter));
+        OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter),
+        OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter));
 
     if (!ctx.emitter)
         IE_THROW() << "Unsupported operation type for Eltwise emitter";
@@ -1022,6 +1024,9 @@ const std::map<const ngraph::DiscreteTypeInfo, Eltwise::Initializer> Eltwise::in
         node.algorithm = Algorithm::EltwiseSoftRelu;
         node.onednnAlgorithm = dnnl::algorithm::eltwise_soft_relu;
     }},
+    {ngraph::op::v9::SoftSign::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
+        node.algorithm = Algorithm::EltwiseSoftSign;
+    }},
 };
@@ -1505,6 +1510,7 @@ public:
         case Algorithm::EltwisePowerStatic: *dst_ptr_f = powf(_opData.beta * src_f[0] + _opData.gamma, _opData.alpha); break;
         case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : src_f[0] * src_f[1]; break;
         case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break;
+        case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break;
         default: IE_THROW() << "Unsupported operation type for Eltwise executor";
         }
     }
@@ -1608,6 +1614,7 @@ size_t Eltwise::getOpInputsNum() const {
     case Algorithm::EltwiseHsigmoid:
     case Algorithm::EltwiseRoundHalfToEven:
     case Algorithm::EltwiseRoundHalfAwayFromZero:
+    case Algorithm::EltwiseSoftSign:
         return 1;
     case Algorithm::EltwiseAdd:
     case Algorithm::EltwiseSubtract:


@@ -85,6 +85,7 @@
 #include "ngraph_transformations/snippets_mark_skipped.hpp"
 #include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
 #include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
+#include <transformations/op_conversions/softsign_decomposition.hpp>
 
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset2.hpp>
@@ -483,6 +484,7 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
     pass_config->disable<ngraph::pass::SliceToStridedSlice>();
     pass_config->disable<ngraph::pass::ConvertDetectionOutput8ToDetectionOutput1>();
     pass_config->disable<ngraph::pass::ConvertROIAlign9To3>();
+    pass_config->disable<ngraph::pass::SoftSignDecomposition>();
 
     pass_config->enable<ngraph::pass::NormalizeL2Decomposition>();
     pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
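With the dedicated Eltwise kernel in place, the CPU plugin disables the common SoftSignDecomposition pass, so a v9 SoftSign op reaches the plugin as a single node instead of being expanded into elementary operations (by definition, SoftSign(x) = x / (1 + |x|), i.e. an Abs, an Add with 1, and a Divide). A NumPy sketch (illustrative only, not part of the commit) of that equivalence:

# The fused op and the Abs/Add/Divide chain produce the same values.
import numpy as np

x = np.random.uniform(-5.0, 5.0, (4, 8)).astype(np.float32)

fused = x / (1.0 + np.abs(x))                       # SoftSign kept as one node
decomposed = np.divide(x, np.add(np.abs(x), 1.0))   # roughly what the disabled decomposition pass would build

assert np.allclose(fused, decomposed)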


@@ -243,7 +243,8 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
         {Tan, {{}}},
         {HardSigmoid, {{0.2f, 0.5f}}},
         {Selu, {{1.6732f, 1.0507f}}},
-        {Ceiling, {{}}}
+        {Ceiling, {{}}},
+        {SoftSign, {{}}}
 };
 
 const std::vector<InferenceEngine::Precision> netPrecisions = {