[CPU] FP16 support in EmbeddingBagOffsetSum and EmbeddingBagSum (#21433)
This commit is contained in:
@@ -51,7 +51,7 @@ void EmbeddingBagOffsetSum::initSupportedPrimitiveDescriptors() {
|
||||
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
|
||||
|
||||
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
|
||||
if (inDataPrecision == ov::element::bf16)
|
||||
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
|
||||
inDataPrecision = ov::element::f32;
|
||||
if (!supportedPrecisions.empty()) {
|
||||
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())
|
||||
|
||||
@@ -48,7 +48,7 @@ void EmbeddingBagPackedSum::initSupportedPrimitiveDescriptors() {
|
||||
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
|
||||
|
||||
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
|
||||
if (inDataPrecision == ov::element::bf16)
|
||||
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
|
||||
inDataPrecision = ov::element::f32;
|
||||
if (!supportedPrecisions.empty()) {
|
||||
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())
|
||||
|
||||
@@ -52,7 +52,7 @@ void EmbeddingSegmentsSum::initSupportedPrimitiveDescriptors() {
|
||||
{ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32};
|
||||
|
||||
auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX);
|
||||
if (inDataPrecision == ov::element::bf16)
|
||||
if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16))
|
||||
inDataPrecision = ov::element::f32;
|
||||
if (!supportedPrecisions.empty()) {
|
||||
if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end())
|
||||
|
||||
Reference in New Issue
Block a user