diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.cpp index 68fd40d9a56..a99e30a1db0 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.cpp @@ -12,6 +12,7 @@ #include "ie_precision.hpp" #include #include "mkldnn_cum_sum_node.h" +#include "utils/bfloat16.hpp" using namespace MKLDNNPlugin; using namespace InferenceEngine; @@ -70,8 +71,7 @@ void MKLDNNCumSumNode::initSupportedPrimitiveDescriptors() { return; dataPrecision = getOriginalInputPrecisionAtPort(CUM_SUM_DATA); - if (dataPrecision != Precision::I8 && dataPrecision != Precision::U8 && dataPrecision != Precision::I16 && dataPrecision != Precision::I32 && - dataPrecision != Precision::FP32 && dataPrecision != Precision::I64 && dataPrecision != Precision::U64 && dataPrecision != Precision::BF16) + if (!one_of(dataPrecision, Precision::I8, Precision::U8, Precision::I16, Precision::BF16, Precision::I32, Precision::FP32, Precision::I64, Precision::U64)) IE_THROW() << errorPrefix << " has unsupported 'data' input precision: " << dataPrecision.name(); if (inputShapes.size() == numOfInputs) { @@ -95,43 +95,17 @@ void MKLDNNCumSumNode::execute(mkldnn::stream strm) { if (inputShapes.size() == numOfInputs) axis = getAxis(getParentEdgeAt(AXIS)->getMemory(), getParentEdgeAt(CUM_SUM_DATA)->getMemory()); - switch (dataPrecision) { - case Precision::I8 : { - exec(); - break; - } - case Precision::U8 : { - exec(); - break; - } - case Precision::I16 : { - exec(); - break; - } - case Precision::I32 : { - exec(); - break; - } - case Precision::FP32 : { - exec(); - break; - } - case Precision::I64 : { - exec(); - break; - } - case Precision::U64 : { - exec(); - break; - } - default : { - std::string errorMsg = errorPrefix + " has unsupported 'data' input precision: " + dataPrecision.name(); - IE_THROW() << errorMsg; - } - } + OV_SWITCH(MKLDNNPlugin, CumSumExecute, this, dataPrecision, + OV_CASE(Precision::I8, int8_t), + OV_CASE(Precision::U8, uint8_t), + OV_CASE(Precision::I16, int16_t), + OV_CASE(Precision::BF16, bfloat16_t), + OV_CASE(Precision::I32, int32_t), + OV_CASE(Precision::FP32, float), + OV_CASE(Precision::I64, int64_t), + OV_CASE(Precision::U64, uint64_t)) } - template void MKLDNNCumSumNode::exec() { const auto *input = reinterpret_cast(getParentEdgeAt(CUM_SUM_DATA)->getMemoryPtr()->GetPtr()); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.h index 2e5ebfaf7d8..f917a53ef34 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_cum_sum_node.h @@ -47,6 +47,13 @@ private: InferenceEngine::Precision dataPrecision; std::string errorPrefix; + + template + struct CumSumExecute { + void operator()(MKLDNNCumSumNode* node) { + node->exec(); + } + }; }; } // namespace MKLDNNPlugin diff --git a/src/tests/functional/plugin/cpu/single_layer_tests/cum_sum.cpp b/src/tests/functional/plugin/cpu/single_layer_tests/cum_sum.cpp index 90a1d12b5df..bd84c142971 100644 --- a/src/tests/functional/plugin/cpu/single_layer_tests/cum_sum.cpp +++ b/src/tests/functional/plugin/cpu/single_layer_tests/cum_sum.cpp @@ -9,57 +9,55 @@ using namespace ngraph; using namespace InferenceEngine; using namespace CPUTestUtils; +using namespace ov; +using namespace test; namespace CPULayerTestsDefinitions { -using cumSumShape = std::pair, std::vector>>; using cumSumParams = std::tuple< ngraph::element::Type, // data precision - cumSumShape, // input shape + InputShape, // input shape std::int64_t, // axis bool, // exclusive bool>; // reverse -class CumSumLayerCPUTest : public testing::WithParamInterface, public ov::test::SubgraphBaseTest, public CPUTestsBase { +class CumSumLayerCPUTest : public testing::WithParamInterface, + public SubgraphBaseTest, public CPUTestsBase { public: static std::string getTestCaseName(testing::TestParamInfo obj) { ngraph::element::Type inputPrecision; - std::pair, std::vector>> shapes; + InputShape shapes; std::int64_t axis; bool exclusive; bool reverse; std::tie(inputPrecision, shapes, axis, exclusive, reverse) = obj.param; - std::ostringstream result; - result << inputPrecision << "_" << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_" << "TS="; - for (const auto& shape : shapes.second) { - result << "("; - for (const auto& item : shape) { - result << CommonTestUtils::vec2str(item) << "_"; - } - result << ")_"; + std::ostringstream results; + results << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_"; + results << "TS="; + for (const auto& item : shapes.second) { + results << CommonTestUtils::vec2str(item) << "_"; } - - result << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : ""); - return result.str(); + results << "Prc=" << inputPrecision << "_"; + results << "Axis=" << axis << "_" << (exclusive ? "exclusive" : "") << "_" << (reverse ? "reverse" : ""); + return results.str(); } protected: void SetUp() override { targetDevice = CommonTestUtils::DEVICE_CPU; - ngraph::element::Type inputPrecision; - std::pair, std::vector>> shapes; + InputShape shapes; std::int64_t axis; bool exclusive; bool reverse; - std::tie(inputPrecision, shapes, axis, exclusive, reverse) = this->GetParam(); + std::tie(inType, shapes, axis, exclusive, reverse) = this->GetParam(); + if (inType == ElementType::bf16) + rel_threshold = 0.05f; - for (size_t i = 0; i < shapes.second.size(); i++) { - targetStaticShapes.push_back(shapes.second[i]); - } - inputDynamicShapes = shapes.first; + selectedType = makeSelectedTypeStr("ref_any", inType); + init_input_shapes({shapes}); - auto params = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes.front() }); + auto params = ngraph::builder::makeDynamicParams(inType, inputDynamicShapes); auto axisNode = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{}, std::vector{axis})->output(0); auto cumSum = ngraph::builder::makeCumSum(params[0], axisNode, exclusive, reverse); @@ -72,15 +70,12 @@ TEST_P(CumSumLayerCPUTest, CompareWithRefs) { SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); - // TODO: Should be uncommented after updating the CheckPluginRelatedResults() method - //CheckPluginRelatedResults(executableNetwork, "CumSum"); + CheckPluginRelatedResults(executableNetwork, "CumSum"); } const ngraph::element::TypeVector inputPrecision = { ngraph::element::i8, - ngraph::element::u8, - ngraph::element::i16, - ngraph::element::i32, + ngraph::element::bf16, ngraph::element::f32 }; @@ -90,97 +85,33 @@ const std::vector negativeAxes = { -1, -2, -3, -4, -5, -6 }; const std::vector exclusive = { true, false }; const std::vector reverse = { true, false }; -const std::vector inShapes = { - { - // dynamic - { - {-1} - }, - // target - { - {{16}, {18}, {12}} - } - }, - { - // dynamic - { - {-1, -1} - }, - // target - { - {{9, 15}, {18, 12}, {12, 12}} - } - }, - { - // dynamic - { - {-1, -1, -1} - }, - // target - { - {{16, 10, 12}, {18, 12, 10}, {12, 18, 10}} - } - }, - { - // dynamic - { - {-1, -1, -1, -1} - }, - // target - { - {{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}} - } - }, - { - // dynamic - { - {-1, -1, -1, -1, -1} - }, - // target - { - {{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}} - } - }, - { - // dynamic - { - {-1, -1, -1, -1, -1, -1} - }, - // target - { - {{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}} - } - }, - { - // dynamic - { - {-1, -1, -1, -1, -1, -1, -1} - }, - // target - { - {{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}} - } - }, - { - // dynamic - { - {{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}} - }, - // target - { - {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}} - } - }, - { - // dynamic - { - {{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1} - }, - // target - { - {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}} - } - }, +const std::vector inShapes = { + {{-1}, + {{16}, {18}, {12}}}, + + {{-1, -1}, + {{9, 15}, {18, 12}, {12, 12}}}, + + {{-1, -1, -1}, + {{16, 10, 12}, {18, 12, 10}, {12, 18, 10}}}, + + {{-1, -1, -1, -1}, + {{18, 20, 14, 12}, {19, 20, 14, 12}, {20, 22, 23, 25}}}, + + {{-1, -1, -1, -1, -1}, + {{2, 4, 6, 2, 4}, {3, 5, 6, 3, 5}, {1, 4, 2, 6, 8}}}, + + {{-1, -1, -1, -1, -1, -1}, + {{2, 4, 6, 2, 4, 2}, {3, 5, 6, 3, 5, 3}, {1, 4, 2, 6, 8, 1}}}, + + {{{-1, -1, -1, -1, -1, -1, -1}}, + {{2, 4, 6, 2, 4, 2, 4}, {3, 5, 6, 3, 5, 3, 5}, {1, 4, 2, 6, 8, 1, 4}}}, + + {{{2, 5}, {3, 7}, {4, 8}, {5, 7}, {2, 5}, {3, 7}, {1, 2}}, + {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}}, + + {{{2, 5}, -1, {4, 8}, -1, -1, {3, 7}, -1}, + {{2, 4, 6, 5, 4, 3, 1}, {3, 5, 6, 6, 5, 3, 1}, {5, 7, 4, 6, 3, 7, 2}}} }; const auto testCasesAxis_0 = ::testing::Combine( @@ -193,7 +124,7 @@ const auto testCasesAxis_0 = ::testing::Combine( const auto testCasesAxis_1 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 1, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 1, inShapes.end())), ::testing::Values(axes[1]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -201,7 +132,7 @@ const auto testCasesAxis_1 = ::testing::Combine( const auto testCasesAxis_2 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 2, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 2, inShapes.end())), ::testing::Values(axes[2]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -209,7 +140,7 @@ const auto testCasesAxis_2 = ::testing::Combine( const auto testCasesAxis_3 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 3, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 3, inShapes.end())), ::testing::Values(axes[3]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -217,7 +148,7 @@ const auto testCasesAxis_3 = ::testing::Combine( const auto testCasesAxis_4 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 4, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 4, inShapes.end())), ::testing::Values(axes[4]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -225,7 +156,7 @@ const auto testCasesAxis_4 = ::testing::Combine( const auto testCasesAxis_5 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 5, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 5, inShapes.end())), ::testing::Values(axes[5]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -233,7 +164,7 @@ const auto testCasesAxis_5 = ::testing::Combine( const auto testCasesAxis_6 = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 6, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 6, inShapes.end())), ::testing::Values(axes[6]), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse) @@ -241,7 +172,7 @@ const auto testCasesAxis_6 = ::testing::Combine( const auto testCasesAxis_negative = ::testing::Combine( ::testing::ValuesIn(inputPrecision), - ::testing::ValuesIn(std::vector(inShapes.begin() + 6, inShapes.end())), + ::testing::ValuesIn(std::vector(inShapes.begin() + 6, inShapes.end())), ::testing::ValuesIn(negativeAxes), ::testing::ValuesIn(exclusive), ::testing::ValuesIn(reverse)