[CPU] Added BF16 support for CumSum as well (#9030)
This commit is contained in:
parent
d5f84ad783
commit
58be795970
@ -12,6 +12,7 @@
|
||||
#include "ie_precision.hpp"
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
#include "mkldnn_cum_sum_node.h"
|
||||
#include "utils/bfloat16.hpp"
|
||||
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
@ -70,8 +71,7 @@ void MKLDNNCumSumNode::initSupportedPrimitiveDescriptors() {
|
||||
return;
|
||||
|
||||
dataPrecision = getOriginalInputPrecisionAtPort(CUM_SUM_DATA);
|
||||
if (dataPrecision != Precision::I8 && dataPrecision != Precision::U8 && dataPrecision != Precision::I16 && dataPrecision != Precision::I32 &&
|
||||
dataPrecision != Precision::FP32 && dataPrecision != Precision::I64 && dataPrecision != Precision::U64 && dataPrecision != Precision::BF16)
|
||||
if (!one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::BF16, Precision::I32, Precision::FP32, Precision::I64, Precision::U64))
|
||||
IE_THROW() << errorPrefix << " has unsupported 'data' input precision: " << dataPrecision.name();
|
||||
|
||||
if (inputShapes.size() == numOfInputs) {
|
||||
@ -95,43 +95,17 @@ void MKLDNNCumSumNode::execute(mkldnn::stream strm) {
|
||||
if (inputShapes.size() == numOfInputs)
|
||||
axis = getAxis(getParentEdgeAt(AXIS)->getMemory(), getParentEdgeAt(CUM_SUM_DATA)->getMemory());
|
||||
|
||||
switch (dataPrecision) {
|
||||
case Precision::I8 : {
|
||||
exec<int8_t>();
|
||||
break;
|
||||
}
|
||||
case Precision::U8 : {
|
||||
exec<uint8_t>();
|
||||
break;
|
||||
}
|
||||
case Precision::I16 : {
|
||||
exec<int16_t>();
|
||||
break;
|
||||
}
|
||||
case Precision::I32 : {
|
||||
exec<int32_t>();
|
||||
break;
|
||||
}
|
||||
case Precision::FP32 : {
|
||||
exec<float>();
|
||||
break;
|
||||
}
|
||||
case Precision::I64 : {
|
||||
exec<int64_t>();
|
||||
break;
|
||||
}
|
||||
case Precision::U64 : {
|
||||
exec<uint64_t>();
|
||||
break;
|
||||
}
|
||||
default : {
|
||||
std::string errorMsg = errorPrefix + " has unsupported 'data' input precision: " + dataPrecision.name();
|
||||
IE_THROW() << errorMsg;
|
||||
}
|
||||
}
|
||||
OV_SWITCH(MKLDNNPlugin, CumSumExecute, this, dataPrecision,
|
||||
OV_CASE(Precision::I8, int8_t),
|
||||
OV_CASE(Precision::U8, uint8_t),
|
||||
OV_CASE(Precision::I16, int16_t),
|
||||
OV_CASE(Precision::BF16, bfloat16_t),
|
||||
OV_CASE(Precision::I32, int32_t),
|
||||
OV_CASE(Precision::FP32, float),
|
||||
OV_CASE(Precision::I64, int64_t),
|
||||
OV_CASE(Precision::U64, uint64_t))
|
||||
}
|
||||
|
||||
|
||||
template <typename dataType>
|
||||
void MKLDNNCumSumNode::exec() {
|
||||
const auto *input = reinterpret_cast<const dataType *>(getParentEdgeAt(CUM_SUM_DATA)->getMemoryPtr()->GetPtr());
|
||||
|
@ -47,6 +47,13 @@ private:
|
||||
|
||||
InferenceEngine::Precision dataPrecision;
|
||||
std::string errorPrefix;
|
||||
|
||||
template<typename T>
|
||||
struct CumSumExecute {
|
||||
void operator()(MKLDNNCumSumNode* node) {
|
||||
node->exec<T>();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -9,57 +9,55 @@
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace ov;
|
||||
using namespace test;
|
||||
|
||||
namespace CPULayerTestsDefinitions {
|
||||
|
||||
using cumSumShape = std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>;
|
||||
using cumSumParams = std::tuple<
|
||||
ngraph::element::Type, // data precision
|
||||
cumSumShape, // input shape
|
||||
InputShape, // input shape
|
||||
std::int64_t, // axis
|
||||
bool, // exclusive
|
||||
bool>; // reverse
|
||||
|
||||
class CumSumLayerCPUTest : public testing::WithParamInterface<cumSumParams>, public ov::test::SubgraphBaseTest, public CPUTestsBase {
|
||||
class CumSumLayerCPUTest : public testing::WithParamInterface<cumSumParams>,
|
||||
public SubgraphBaseTest, public CPUTestsBase {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<cumSumParams> obj) {
|
||||
ngraph::element::Type inputPrecision;
|
||||
std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
|
||||
InputShape shapes;
|
||||
std::int64_t axis;
|
||||
bool exclusive;
|
||||
bool reverse;
|
||||
std::tie(inputPrecision, shapes, axis, exclusive, reverse) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << inputPrecision << "_" << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_" << "TS=";
|
||||
for (const auto& shape : shapes.second) {
|
||||
result << "(";
|
||||
for (const auto& item : shape) {
|
||||
result << CommonTestUtils::vec2str(item) << "_";
|
||||
}
|
||||
result << ")_";
|
||||
std::ostringstream results;
|
||||
results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
|
||||
results << "TS=";
|
||||
for (const auto& item : shapes.second) {
|
||||
results << CommonTestUtils::vec2str(item) << "_";
|
||||
}
|
||||
|
||||
result << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : "");
|
||||
return result.str();
|
||||
results << "Prc=" << inputPrecision << "_";
|
||||
results << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : "");
|
||||
return results.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
ngraph::element::Type inputPrecision;
|
||||
std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
|
||||
InputShape shapes;
|
||||
std::int64_t axis;
|
||||
bool exclusive;
|
||||
bool reverse;
|
||||
std::tie(inputPrecision, shapes, axis, exclusive, reverse) = this->GetParam();
|
||||
std::tie(inType, shapes, axis, exclusive, reverse) = this->GetParam();
|
||||
if (inType == ElementType::bf16)
|
||||
rel_threshold = 0.05f;
|
||||
|
||||
for (size_t i = 0; i < shapes.second.size(); i++) {
|
||||
targetStaticShapes.push_back(shapes.second[i]);
|
||||
}
|
||||
inputDynamicShapes = shapes.first;
|
||||
selectedType = makeSelectedTypeStr("ref_any", inType);
|
||||
init_input_shapes({shapes});
|
||||
|
||||
auto params = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes.front() });
|
||||
auto params = ngraph::builder::makeDynamicParams(inType, inputDynamicShapes);
|
||||
auto axisNode = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{}, std::vector<int64_t>{axis})->output(0);
|
||||
auto cumSum = ngraph::builder::makeCumSum(params[0], axisNode, exclusive, reverse);
|
||||
|
||||
@ -72,15 +70,12 @@ TEST_P(CumSumLayerCPUTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
run();
|
||||
// TODO: Should be uncommented after updating the CheckPluginRelatedResults() method
|
||||
//CheckPluginRelatedResults(executableNetwork, "CumSum");
|
||||
CheckPluginRelatedResults(executableNetwork, "CumSum");
|
||||
}
|
||||
|
||||
const ngraph::element::TypeVector inputPrecision = {
|
||||
ngraph::element::i8,
|
||||
ngraph::element::u8,
|
||||
ngraph::element::i16,
|
||||
ngraph::element::i32,
|
||||
ngraph::element::bf16,
|
||||
ngraph::element::f32
|
||||
};
|
||||
|
||||
@ -90,97 +85,33 @@ const std::vector<int64_t> negativeAxes = { -1, -2, -3, -4, -5, -6 };
|
||||
const std::vector<bool> exclusive = { true, false };
|
||||
const std::vector<bool> reverse = { true, false };
|
||||
|
||||
const std::vector<cumSumShape> inShapes = {
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{16}, {18}, {12}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{9, 15}, {18, 12}, {12, 12}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{16, 10, 12}, {18, 12, 10}, {12, 18, 10}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1, -1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1, -1, -1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1, -1, -1, -1, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1}
|
||||
},
|
||||
// target
|
||||
{
|
||||
{{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}
|
||||
}
|
||||
},
|
||||
const std::vector<InputShape> inShapes = {
|
||||
{{-1},
|
||||
{{16}, {18}, {12}}},
|
||||
|
||||
{{-1, -1},
|
||||
{{9, 15}, {18, 12}, {12, 12}}},
|
||||
|
||||
{{-1, -1, -1},
|
||||
{{16, 10, 12}, {18, 12, 10}, {12, 18, 10}}},
|
||||
|
||||
{{-1, -1, -1, -1},
|
||||
{{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}}},
|
||||
|
||||
{{-1, -1, -1, -1, -1},
|
||||
{{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}}},
|
||||
|
||||
{{-1, -1, -1, -1, -1, -1},
|
||||
{{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}}},
|
||||
|
||||
{{{-1, -1, -1, -1, -1, -1, -1}},
|
||||
{{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}}},
|
||||
|
||||
{{{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}},
|
||||
{{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}},
|
||||
|
||||
{{{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1},
|
||||
{{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}}
|
||||
};
|
||||
|
||||
const auto testCasesAxis_0 = ::testing::Combine(
|
||||
@ -193,7 +124,7 @@ const auto testCasesAxis_0 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_1 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 1, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 1, inShapes.end())),
|
||||
::testing::Values(axes[1]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -201,7 +132,7 @@ const auto testCasesAxis_1 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_2 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 2, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 2, inShapes.end())),
|
||||
::testing::Values(axes[2]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -209,7 +140,7 @@ const auto testCasesAxis_2 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_3 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 3, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 3, inShapes.end())),
|
||||
::testing::Values(axes[3]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -217,7 +148,7 @@ const auto testCasesAxis_3 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_4 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 4, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 4, inShapes.end())),
|
||||
::testing::Values(axes[4]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -225,7 +156,7 @@ const auto testCasesAxis_4 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_5 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 5, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 5, inShapes.end())),
|
||||
::testing::Values(axes[5]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -233,7 +164,7 @@ const auto testCasesAxis_5 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_6 = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 6, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 6, inShapes.end())),
|
||||
::testing::Values(axes[6]),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
@ -241,7 +172,7 @@ const auto testCasesAxis_6 = ::testing::Combine(
|
||||
|
||||
const auto testCasesAxis_negative = ::testing::Combine(
|
||||
::testing::ValuesIn(inputPrecision),
|
||||
::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 6, inShapes.end())),
|
||||
::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 6, inShapes.end())),
|
||||
::testing::ValuesIn(negativeAxes),
|
||||
::testing::ValuesIn(exclusive),
|
||||
::testing::ValuesIn(reverse)
|
||||
|
Loading…
Reference in New Issue
Block a user