[CPU] FP16 support in EmbeddingBagOffsetSum and EmbeddingBagSum (#21433)

This commit is contained in:
Aleksandr Voron
2023-12-06 19:21:02 +01:00
committed by GitHub
parent 152b4dfc0d
commit 60109caf63
3 changed files with 3 additions and 3 deletions

View File

@@ -51,7 +51,7 @@ void EmbeddingBagOffsetSum::initSupportedPrimitiveDescriptors() {
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
if (inDataPrecision == ov::element::bf16)
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
inDataPrecision = ov::element::f32;
if (!supportedPrecisions.empty()) {
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())

View File

@@ -48,7 +48,7 @@ void EmbeddingBagPackedSum::initSupportedPrimitiveDescriptors() {
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
if (inDataPrecision == ov::element::bf16)
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
inDataPrecision = ov::element::f32;
if (!supportedPrecisions.empty()) {
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())

View File

@@ -52,7 +52,7 @@ void EmbeddingSegmentsSum::initSupportedPrimitiveDescriptors() {
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
if (inDataPrecision == ov::element::bf16)
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
inDataPrecision = ov::element::f32;
if (!supportedPrecisions.empty()) {
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())