[CPU] Added BF16 support for CumSum as well (#9030)

Author: Alexandra Sidorova, 2021-12-16 10:24:57 +03:00 (committed by GitHub)
parent d5f84ad783
commit 58be795970
3 changed files with 75 additions and 163 deletions

mkldnn_cum_sum_node.cpp

@@ -12,6 +12,7 @@
 #include "ie_precision.hpp"
 #include <ie_ngraph_utils.hpp>
 #include "mkldnn_cum_sum_node.h"
+#include "utils/bfloat16.hpp"

 using namespace MKLDNNPlugin;
 using namespace InferenceEngine;
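The new include brings in the plugin's bfloat16_t type, which the BF16 case added below maps to Precision::BF16. As a rough sketch of what such a type looks like (the real utils/bfloat16.hpp also handles rounding and NaN; this truncating version is illustrative only):

    #include <cstdint>
    #include <cstring>

    // Illustrative bfloat16: the top 16 bits of an IEEE-754 float32
    // (sign, 8 exponent bits, 7 mantissa bits).
    struct bfloat16_t {
        uint16_t raw = 0;

        bfloat16_t() = default;
        bfloat16_t(float f) {
            uint32_t bits;
            std::memcpy(&bits, &f, sizeof(bits));
            raw = static_cast<uint16_t>(bits >> 16);  // drop the low mantissa bits
        }
        operator float() const {
            uint32_t bits = static_cast<uint32_t>(raw) << 16;
            float f;
            std::memcpy(&f, &bits, sizeof(f));
            return f;
        }
    };

Because a type like this converts implicitly to and from float, the templated exec<bfloat16_t>() compiles unchanged: additions promote to float, and each store truncates back to bfloat16.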
@@ -70,8 +71,7 @@ void MKLDNNCumSumNode::initSupportedPrimitiveDescriptors() {
         return;

     dataPrecision = getOriginalInputPrecisionAtPort(CUM_SUM_DATA);
-    if (dataPrecision != Precision::I8 && dataPrecision != Precision::U8 && dataPrecision != Precision::I16 && dataPrecision != Precision::I32 &&
-            dataPrecision != Precision::FP32 && dataPrecision != Precision::I64 && dataPrecision != Precision::U64 && dataPrecision != Precision::BF16)
+    if (!one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::BF16, Precision::I32, Precision::FP32, Precision::I64, Precision::U64))
         IE_THROW() << errorPrefix << " has unsupported 'data' input precision: " << dataPrecision.name();

     if (inputShapes.size() == numOfInputs) {
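The rewritten check relies on a variadic one_of helper instead of a chain of inequalities. A minimal sketch of such a helper (the plugin's actual utility may differ in namespace and signature):

    // Returns true if `val` compares equal to any of the remaining arguments.
    template <typename T, typename U>
    bool one_of(const T& val, const U& item) {
        return val == item;
    }

    template <typename T, typename U, typename... Rest>
    bool one_of(const T& val, const U& item, const Rest&... rest) {
        return val == item || one_of(val, rest...);
    }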
@@ -95,43 +95,17 @@ void MKLDNNCumSumNode::execute(mkldnn::stream strm) {
     if (inputShapes.size() == numOfInputs)
         axis = getAxis(getParentEdgeAt(AXIS)->getMemory(), getParentEdgeAt(CUM_SUM_DATA)->getMemory());

-    switch (dataPrecision) {
-        case Precision::I8 : {
-            exec<int8_t>();
-            break;
-        }
-        case Precision::U8 : {
-            exec<uint8_t>();
-            break;
-        }
-        case Precision::I16 : {
-            exec<int16_t>();
-            break;
-        }
-        case Precision::I32 : {
-            exec<int32_t>();
-            break;
-        }
-        case Precision::FP32 : {
-            exec<float>();
-            break;
-        }
-        case Precision::I64 : {
-            exec<int64_t>();
-            break;
-        }
-        case Precision::U64 : {
-            exec<uint64_t>();
-            break;
-        }
-        default : {
-            std::string errorMsg = errorPrefix + " has unsupported 'data' input precision: " + dataPrecision.name();
-            IE_THROW() << errorMsg;
-        }
-    }
+    OV_SWITCH(MKLDNNPlugin, CumSumExecute, this, dataPrecision,
+              OV_CASE(Precision::I8, int8_t),
+              OV_CASE(Precision::U8, uint8_t),
+              OV_CASE(Precision::I16, int16_t),
+              OV_CASE(Precision::BF16, bfloat16_t),
+              OV_CASE(Precision::I32, int32_t),
+              OV_CASE(Precision::FP32, float),
+              OV_CASE(Precision::I64, int64_t),
+              OV_CASE(Precision::U64, uint64_t))
 }

 template <typename dataType>
 void MKLDNNCumSumNode::exec() {
     const auto *input = reinterpret_cast<const dataType *>(getParentEdgeAt(CUM_SUM_DATA)->getMemoryPtr()->GetPtr());
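The templated exec<dataType>() (only its first line is shown above) walks the tensor along the chosen axis and accumulates. For reference, CumSum semantics on a 1-D sequence, including the exclusive and reverse flags exercised by the tests below, can be sketched as:

    #include <cstddef>
    #include <vector>

    // Reference semantics of CumSum on a 1-D sequence (sketch; the node's
    // exec<>() generalizes this along an arbitrary axis of an N-D tensor).
    template <typename T>
    std::vector<T> cumSumRef(const std::vector<T>& in, bool exclusive, bool reverse) {
        const size_t n = in.size();
        std::vector<T> out(n);
        T acc = T(0);
        for (size_t step = 0; step < n; ++step) {
            const size_t i = reverse ? n - 1 - step : step;  // walk backwards if reverse
            if (exclusive) {            // exclusive: sum of the elements before this one
                out[i] = acc;
                acc = acc + in[i];
            } else {                    // inclusive: sum including this element
                acc = acc + in[i];
                out[i] = acc;
            }
        }
        return out;
    }

For {1, 2, 3} this yields {1, 3, 6} inclusively and {0, 1, 3} with exclusive set; reverse runs the same accumulation from the tail, giving {6, 5, 3} inclusively.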

mkldnn_cum_sum_node.h

@@ -47,6 +47,13 @@ private:
     InferenceEngine::Precision dataPrecision;
     std::string errorPrefix;

+    template <typename T>
+    struct CumSumExecute {
+        void operator()(MKLDNNCumSumNode* node) {
+            node->exec<T>();
+        }
+    };
+
 };

 } // namespace MKLDNNPlugin
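CumSumExecute is the functor that OV_SWITCH instantiates in execute() above: each OV_CASE pairs a runtime Precision value with a C++ type, and the first matching case invokes CumSumExecute<T> on the node. A simplified sketch of such a dispatch mechanism, assuming an enum-like Precision tag for brevity (InferenceEngine::Precision is a class in reality, and this is not the plugin's actual macro implementation):

    enum class Precision { I8, U8, I16, BF16, I32, FP32, I64, U64 };  // stand-in tag

    // Compile-time pairing of a runtime tag with a C++ type.
    template <Precision P, typename T>
    struct Case {
        static constexpr Precision value = P;
        using type = T;
    };

    // Base case: no Case matched the runtime precision.
    template <template <typename> class Functor, typename Node>
    bool dispatch(Precision, Node*) { return false; }

    // Peel one Case off the pack; instantiate the functor on a match.
    template <template <typename> class Functor, typename Node,
              typename C, typename... Rest>
    bool dispatch(Precision p, Node* node) {
        if (p == C::value) {
            Functor<typename C::type>{}(node);
            return true;
        }
        return dispatch<Functor, Node, Rest...>(p, node);
    }

    // Usage, mirroring the OV_SWITCH call above (sketch):
    // dispatch<CumSumExecute, MKLDNNCumSumNode,
    //          Case<Precision::FP32, float>,
    //          Case<Precision::BF16, bfloat16_t>>(dataPrecision, this);

This keeps the precision-to-type table in one place, so adding BF16 is a single OV_CASE line rather than another switch case.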

cum_sum.cpp (CPU single-layer tests)

@@ -9,57 +9,55 @@
 using namespace ngraph;
 using namespace InferenceEngine;
 using namespace CPUTestUtils;
+using namespace ov;
+using namespace test;

 namespace CPULayerTestsDefinitions {

-using cumSumShape = std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>;
 using cumSumParams = std::tuple<
     ngraph::element::Type, // data precision
-    cumSumShape,           // input shape
+    InputShape,            // input shape
     std::int64_t,          // axis
     bool,                  // exclusive
     bool>;                 // reverse

-class CumSumLayerCPUTest : public testing::WithParamInterface<cumSumParams>, public ov::test::SubgraphBaseTest, public CPUTestsBase {
+class CumSumLayerCPUTest : public testing::WithParamInterface<cumSumParams>,
+                           public SubgraphBaseTest, public CPUTestsBase {
 public:
     static std::string getTestCaseName(testing::TestParamInfo<cumSumParams> obj) {
         ngraph::element::Type inputPrecision;
-        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
+        InputShape shapes;
         std::int64_t axis;
         bool exclusive;
         bool reverse;
         std::tie(inputPrecision, shapes, axis, exclusive, reverse) = obj.param;

-        std::ostringstream result;
-        result << inputPrecision << "_" << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_" << "TS=";
-        for (const auto& shape : shapes.second) {
-            result << "(";
-            for (const auto& item : shape) {
-                result << CommonTestUtils::vec2str(item) << "_";
-            }
-            result << ")_";
-        }
-        result << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : "");
-        return result.str();
+        std::ostringstream results;
+        results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
+        results << "TS=";
+        for (const auto& item : shapes.second) {
+            results << CommonTestUtils::vec2str(item) << "_";
+        }
+        results << "Prc=" << inputPrecision << "_";
+        results << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : "");
+        return results.str();
     }

 protected:
     void SetUp() override {
         targetDevice = CommonTestUtils::DEVICE_CPU;
-        ngraph::element::Type inputPrecision;
-        std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
+        InputShape shapes;
         std::int64_t axis;
         bool exclusive;
         bool reverse;
-        std::tie(inputPrecision, shapes, axis, exclusive, reverse) = this->GetParam();
+        std::tie(inType, shapes, axis, exclusive, reverse) = this->GetParam();
+        if (inType == ElementType::bf16)
+            rel_threshold = 0.05f;

-        for (size_t i = 0; i < shapes.second.size(); i++) {
-            targetStaticShapes.push_back(shapes.second[i]);
-        }
-        inputDynamicShapes = shapes.first;
+        selectedType = makeSelectedTypeStr("ref_any", inType);
+        init_input_shapes({shapes});

-        auto params = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes.front() });
+        auto params = ngraph::builder::makeDynamicParams(inType, inputDynamicShapes);
         auto axisNode = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{}, std::vector<int64_t>{axis})->output(0);
         auto cumSum = ngraph::builder::makeCumSum(params[0], axisNode, exclusive, reverse);
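The relaxed rel_threshold for bf16 reflects the format's precision: bfloat16 keeps only 8 significant bits of mantissa (roughly 2-3 decimal digits), and a cumulative sum compounds the per-element rounding error along the axis. A quick self-contained illustration of the per-element error from truncating a float32 to bfloat16 precision (illustrative, not test code):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        float x = 3.14159f;
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        bits &= 0xFFFF0000u;              // truncate to bfloat16 precision
        float r;
        std::memcpy(&r, &bits, sizeof(r));
        std::printf("%f -> %f, relative error %.2e\n", x, r, (x - r) / x);
        // Roughly: 3.141590 -> 3.140625, relative error 3.07e-04. A cumulative
        // sum applies such a rounding on every partial result, so the test
        // tolerates up to 5% relative error for bf16.
    }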
@@ -72,15 +70,12 @@ TEST_P(CumSumLayerCPUTest, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED()
     run();
-    // TODO: Should be uncommented after updating the CheckPluginRelatedResults() method
-    //CheckPluginRelatedResults(executableNetwork, "CumSum");
+    CheckPluginRelatedResults(executableNetwork, "CumSum");
 }

 const ngraph::element::TypeVector inputPrecision = {
     ngraph::element::i8,
     ngraph::element::u8,
     ngraph::element::i16,
     ngraph::element::i32,
+    ngraph::element::bf16,
     ngraph::element::f32
 };
@@ -90,97 +85,33 @@ const std::vector<int64_t> negativeAxes = { -1, -2, -3, -4, -5, -6 };
 const std::vector<bool> exclusive = { true, false };
 const std::vector<bool> reverse = { true, false };

-const std::vector<cumSumShape> inShapes = {
-    {
-        // dynamic
-        {
-            {-1}
-        },
-        // target
-        {
-            {{16}, {18}, {12}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1}
-        },
-        // target
-        {
-            {{9, 15}, {18, 12}, {12, 12}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1, -1}
-        },
-        // target
-        {
-            {{16, 10, 12}, {18, 12, 10}, {12, 18, 10}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1, -1, -1}
-        },
-        // target
-        {
-            {{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1, -1, -1, -1}
-        },
-        // target
-        {
-            {{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1, -1, -1, -1, -1}
-        },
-        // target
-        {
-            {{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {-1, -1, -1, -1, -1, -1, -1}
-        },
-        // target
-        {
-            {{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}}
-        },
-        // target
-        {
-            {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}
-        }
-    },
-    {
-        // dynamic
-        {
-            {{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1}
-        },
-        // target
-        {
-            {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}
-        }
-    },
-};
+const std::vector<InputShape> inShapes = {
+    {{-1},
+     {{16}, {18}, {12}}},
+    {{-1, -1},
+     {{9, 15}, {18, 12}, {12, 12}}},
+    {{-1, -1, -1},
+     {{16, 10, 12}, {18, 12, 10}, {12, 18, 10}}},
+    {{-1, -1, -1, -1},
+     {{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}}},
+    {{-1, -1, -1, -1, -1},
+     {{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}}},
+    {{-1, -1, -1, -1, -1, -1},
+     {{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}}},
+    {{{-1, -1, -1, -1, -1, -1, -1}},
+     {{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}}},
+    {{{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}},
+     {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}},
+    {{{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1},
+     {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}}
+};
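Each InputShape entry above pairs one dynamic shape, where -1 is a fully dynamic dimension and {2, 5} a bounded interval, with the static shapes the test reshapes to at run time; the test body reads them as shapes.first and shapes.second. Reading one entry (a sketch, assuming the pair-like layout used above):

    // Rank-2 case: both dimensions dynamic; three concrete shapes are inferred.
    InputShape example = {
        {-1, -1},                        // dynamic shape the model is built with
        {{9, 15}, {18, 12}, {12, 12}}    // target static shapes exercised at run time
    };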
 const auto testCasesAxis_0 = ::testing::Combine(
@@ -193,7 +124,7 @@ const auto testCasesAxis_0 = ::testing::Combine(
 const auto testCasesAxis_1 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 1, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 1, inShapes.end())),
     ::testing::Values(axes[1]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -201,7 +132,7 @@ const auto testCasesAxis_1 = ::testing::Combine(
 const auto testCasesAxis_2 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 2, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 2, inShapes.end())),
     ::testing::Values(axes[2]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -209,7 +140,7 @@ const auto testCasesAxis_2 = ::testing::Combine(
 const auto testCasesAxis_3 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 3, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 3, inShapes.end())),
     ::testing::Values(axes[3]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -217,7 +148,7 @@ const auto testCasesAxis_3 = ::testing::Combine(
 const auto testCasesAxis_4 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 4, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 4, inShapes.end())),
     ::testing::Values(axes[4]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -225,7 +156,7 @@ const auto testCasesAxis_4 = ::testing::Combine(
 const auto testCasesAxis_5 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 5, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 5, inShapes.end())),
     ::testing::Values(axes[5]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -233,7 +164,7 @@ const auto testCasesAxis_5 = ::testing::Combine(
 const auto testCasesAxis_6 = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 6, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 6, inShapes.end())),
     ::testing::Values(axes[6]),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)
@@ -241,7 +172,7 @@ const auto testCasesAxis_6 = ::testing::Combine(
 const auto testCasesAxis_negative = ::testing::Combine(
     ::testing::ValuesIn(inputPrecision),
-    ::testing::ValuesIn(std::vector<cumSumShape>(inShapes.begin() + 6, inShapes.end())),
+    ::testing::ValuesIn(std::vector<InputShape>(inShapes.begin() + 6, inShapes.end())),
     ::testing::ValuesIn(negativeAxes),
     ::testing::ValuesIn(exclusive),
     ::testing::ValuesIn(reverse)