[CPU] Select dynamic done (#7818)
This commit is contained in:
parent
11ddd731b7
commit
984fda6305
@ -17,10 +17,6 @@ using namespace InferenceEngine;
|
||||
|
||||
bool MKLDNNSelectNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
if (isDynamicNgraphNode(op)) {
|
||||
errorMessage = "Doesn't support op with dynamic shapes";
|
||||
return false;
|
||||
}
|
||||
const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
|
||||
if (!select) {
|
||||
errorMessage = "Only opset1 Select operation is supported";
|
||||
@ -47,7 +43,7 @@ MKLDNNSelectNode::MKLDNNSelectNode(const std::shared_ptr<ngraph::Node>& op, cons
|
||||
errorPrefix = "Select layer with name '" + op->get_friendly_name() + "'";
|
||||
const auto select = std::dynamic_pointer_cast<const ngraph::opset1::Select>(op);
|
||||
|
||||
if (op->get_input_size() != numOfInputs || op->get_output_size() != 1)
|
||||
if (inputShapes.size() != numOfInputs || outputShapes.size() != 1)
|
||||
IE_THROW() << errorPrefix << " has incorrect number of input/output edges!";
|
||||
|
||||
const auto broadcast = select->get_auto_broadcast();
|
||||
@ -59,56 +55,43 @@ MKLDNNSelectNode::MKLDNNSelectNode(const std::shared_ptr<ngraph::Node>& op, cons
|
||||
IE_THROW() << errorPrefix << " has unsupported broadcast type: " + ngraph::as_string(broadcast.m_type);
|
||||
}
|
||||
|
||||
auto conditionShapes = op->get_input_shape(CONDITION);
|
||||
if (ngraph::is_scalar(conditionShapes))
|
||||
conditionShapes = ngraph::Shape{1};
|
||||
auto thenShapes = op->get_input_shape(THEN);
|
||||
if (ngraph::is_scalar(thenShapes))
|
||||
thenShapes = ngraph::Shape{1};
|
||||
auto elseShapes = op->get_input_shape(ELSE);
|
||||
if (ngraph::is_scalar(elseShapes))
|
||||
elseShapes = ngraph::Shape{1};
|
||||
auto outputShapes = op->get_output_shape(0);
|
||||
if (ngraph::is_scalar(outputShapes))
|
||||
outputShapes = ngraph::Shape{1};
|
||||
const auto &conditionShape = getInputShapeAtPort(CONDITION).getDims();
|
||||
const auto &thenShape = getInputShapeAtPort(THEN).getDims();
|
||||
const auto &elseShape = getInputShapeAtPort(ELSE).getDims();
|
||||
const auto &outputShape = getOutputShapeAtPort(0).getDims();
|
||||
|
||||
if (broadcastType == SelectBroadcastType::NONE && ((conditionShapes != outputShapes) || (thenShapes != outputShapes) ||
|
||||
(elseShapes != outputShapes)))
|
||||
if (broadcastType == SelectBroadcastType::NONE && (!dimsEqualWeak(conditionShape, outputShape) || !dimsEqualWeak(thenShape, outputShape) ||
|
||||
!dimsEqualWeak(elseShape, outputShape))) {
|
||||
IE_THROW() << errorPrefix << " and auto_broadcast='none' has input shapes mismatch";
|
||||
}
|
||||
|
||||
if (broadcastType == SelectBroadcastType::NUMPY) {
|
||||
if (outputShapes.size() < conditionShapes.size() || outputShapes.size() < thenShapes.size() || outputShapes.size() < elseShapes.size())
|
||||
if (outputShape.size() < conditionShape.size() || outputShape.size() < thenShape.size() || outputShape.size() < elseShape.size())
|
||||
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible input and output shapes";
|
||||
|
||||
for (int condIt = conditionShapes.size() - 1, outIt = outputShapes.size() - 1; condIt >= 0; condIt--, outIt--)
|
||||
if (conditionShapes[condIt] != outputShapes[outIt] && conditionShapes[condIt] != 1)
|
||||
for (int condIt = conditionShape.size() - 1, outIt = outputShape.size() - 1; condIt >= 0; condIt--, outIt--)
|
||||
if (!dimsEqualWeak(conditionShape[condIt], outputShape[outIt]) && !dimsEqualWeak(conditionShape[condIt], 1))
|
||||
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Condition' input and output shapes";
|
||||
|
||||
for (int thenIt = thenShapes.size() - 1, outIt = outputShapes.size() - 1; thenIt >= 0; thenIt--, outIt--)
|
||||
if (thenShapes[thenIt] != outputShapes[outIt] && thenShapes[thenIt] != 1)
|
||||
for (int thenIt = thenShape.size() - 1, outIt = outputShape.size() - 1; thenIt >= 0; thenIt--, outIt--)
|
||||
if (!dimsEqualWeak(thenShape[thenIt], outputShape[outIt]) && !dimsEqualWeak(thenShape[thenIt], 1))
|
||||
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Then' input and output shapes";
|
||||
|
||||
for (int elseIt = elseShapes.size() - 1, outIt = outputShapes.size() - 1; elseIt >= 0; elseIt--, outIt--)
|
||||
if (elseShapes[elseIt] != outputShapes[outIt] && elseShapes[elseIt] != 1)
|
||||
for (int elseIt = elseShape.size() - 1, outIt = outputShape.size() - 1; elseIt >= 0; elseIt--, outIt--)
|
||||
if (!dimsEqualWeak(elseShape[elseIt], outputShape[outIt]) && !dimsEqualWeak(elseShape[elseIt], 1))
|
||||
IE_THROW() << errorPrefix << " and auto_broadcast='numpy' has incompatible 'Else' input and output shapes";
|
||||
}
|
||||
|
||||
resDims.resize(numOfDims, 1);
|
||||
std::copy(std::begin(outputShapes), std::end(outputShapes), std::begin(resDims) + (numOfDims - outputShapes.size()));
|
||||
if (broadcastType == SelectBroadcastType::NUMPY) {
|
||||
calcOutOffset(resOffset, resDims);
|
||||
resOffset.resize(numOfDims);
|
||||
condOffset.resize(numOfDims);
|
||||
thenOffset.resize(numOfDims);
|
||||
elseOffset.resize(numOfDims);
|
||||
|
||||
std::vector<size_t> condDims(numOfDims, 1);
|
||||
std::copy(std::begin(conditionShapes), std::end(conditionShapes), std::begin(condDims) + (numOfDims - conditionShapes.size()));
|
||||
calcInOffset(condOffset, condDims, resDims);
|
||||
|
||||
std::vector<size_t> thenDims(numOfDims, 1);
|
||||
std::copy(std::begin(thenShapes), std::end(thenShapes), std::begin(thenDims) + (numOfDims - thenShapes.size()));
|
||||
calcInOffset(thenOffset, thenDims, resDims);
|
||||
|
||||
std::vector<size_t> elseDims(numOfDims, 1);
|
||||
std::copy(std::begin(elseShapes), std::end(elseShapes), std::begin(elseDims) + (numOfDims - elseShapes.size()));
|
||||
calcInOffset(elseOffset, elseDims, resDims);
|
||||
condDims.resize(numOfDims, 1);
|
||||
thenDims.resize(numOfDims, 1);
|
||||
elseDims.resize(numOfDims, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -140,8 +123,47 @@ void MKLDNNSelectNode::initSupportedPrimitiveDescriptors() {
|
||||
impl_desc_type::ref_any);
|
||||
}
|
||||
|
||||
void MKLDNNSelectNode::calcOutOffset(std::vector<size_t>& offset, const std::vector<size_t>& dims) {
|
||||
offset.resize(numOfDims);
|
||||
void MKLDNNSelectNode::prepareParams() {
|
||||
if (!inputShapesDefined()) {
|
||||
IE_THROW() << "Can't prepare params for eltwise node with name: " << getName();
|
||||
}
|
||||
|
||||
const auto &_conditionDims = getParentEdgesAtPort(CONDITION)[0]->getMemory().getStaticDims();
|
||||
const auto &_thenDims = getParentEdgesAtPort(THEN)[0]->getMemory().getStaticDims();
|
||||
const auto &_elseDims = getParentEdgesAtPort(ELSE)[0]->getMemory().getStaticDims();
|
||||
const auto &_outputDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
|
||||
std::fill(resDims.begin(), resDims.end(), 1);
|
||||
std::copy(std::begin(_outputDims), std::end(_outputDims), std::begin(resDims) + (numOfDims - _outputDims.size()));
|
||||
if (broadcastType == SelectBroadcastType::NUMPY) {
|
||||
std::fill(resOffset.begin(), resOffset.end(), 1);
|
||||
calcOutOffset(resOffset, resDims);
|
||||
|
||||
std::fill(condDims.begin(), condDims.end(), 1);
|
||||
std::copy(std::begin(_conditionDims), std::end(_conditionDims), std::begin(condDims) + (numOfDims - _conditionDims.size()));
|
||||
std::fill(condOffset.begin(), condOffset.end(), 1);
|
||||
calcInOffset(condOffset, condDims, resDims);
|
||||
|
||||
std::fill(thenDims.begin(), thenDims.end(), 1);
|
||||
std::copy(std::begin(_thenDims), std::end(_thenDims), std::begin(thenDims) + (numOfDims - _thenDims.size()));
|
||||
std::fill(thenOffset.begin(), thenOffset.end(), 1);
|
||||
calcInOffset(thenOffset, thenDims, resDims);
|
||||
|
||||
std::fill(elseDims.begin(), elseDims.end(), 1);
|
||||
std::copy(std::begin(_elseDims), std::end(_elseDims), std::begin(elseDims) + (numOfDims - _elseDims.size()));
|
||||
std::fill(elseOffset.begin(), elseOffset.end(), 1);
|
||||
calcInOffset(elseOffset, elseDims, resDims);
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNSelectNode::createPrimitive() {
|
||||
if (inputShapesDefined()) {
|
||||
prepareParams();
|
||||
updateLastInputDims();
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNSelectNode::calcOutOffset(VectorDims& offset, const VectorDims& dims) {
|
||||
int k = 1;
|
||||
for (int i = dims.size() - 1; i >= 0; i--) {
|
||||
offset[i] = k;
|
||||
@ -149,8 +171,7 @@ void MKLDNNSelectNode::calcOutOffset(std::vector<size_t>& offset, const std::vec
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNSelectNode::calcInOffset(std::vector<size_t>& offset, const std::vector<size_t>& inDims, const std::vector<size_t>& outDims) {
|
||||
offset.resize(numOfDims);
|
||||
void MKLDNNSelectNode::calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims) {
|
||||
int k = 1;
|
||||
for (int i = inDims.size() - 1; i >= 0; i--) {
|
||||
offset[i] = (inDims[i] == outDims[i]) ? k : 0;
|
||||
|
@ -18,10 +18,13 @@ public:
|
||||
|
||||
void getSupportedDescriptors() override {};
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void createPrimitive() override {};
|
||||
void createPrimitive() override;
|
||||
void execute(mkldnn::stream strm) override;
|
||||
bool created() const override;
|
||||
|
||||
void executeDynamicImpl(mkldnn::stream strm) override { execute(strm); }
|
||||
void prepareParams() override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
private:
|
||||
@ -33,16 +36,20 @@ private:
|
||||
};
|
||||
|
||||
SelectBroadcastType broadcastType;
|
||||
std::vector<size_t> resDims;
|
||||
std::vector<size_t> resOffset;
|
||||
std::vector<size_t> condOffset;
|
||||
std::vector<size_t> thenOffset;
|
||||
std::vector<size_t> elseOffset;
|
||||
VectorDims resDims;
|
||||
VectorDims resOffset;
|
||||
VectorDims condOffset;
|
||||
VectorDims thenOffset;
|
||||
VectorDims elseOffset;
|
||||
|
||||
VectorDims condDims;
|
||||
VectorDims thenDims;
|
||||
VectorDims elseDims;
|
||||
|
||||
std::string errorPrefix;
|
||||
|
||||
void calcOutOffset(std::vector<size_t>& offset, const std::vector<size_t>& dims);
|
||||
void calcInOffset(std::vector<size_t>& offset, const std::vector<size_t>& inDims, const std::vector<size_t>& outDims);
|
||||
void calcOutOffset(VectorDims& offset, const VectorDims& dims);
|
||||
void calcInOffset(VectorDims& offset, const VectorDims& inDims, const VectorDims& outDims);
|
||||
template <typename COND_T, typename DATA_T>
|
||||
void execute_impl();
|
||||
};
|
||||
|
@ -0,0 +1,174 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
namespace CPULayerTestsDefinitions {
|
||||
|
||||
using selectParams = std::tuple<
|
||||
std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>, // input shapes
|
||||
ngraph::op::AutoBroadcastSpec>; // broadcast
|
||||
|
||||
class SelectLayerCPUTest : public testing::WithParamInterface<selectParams>, public LayerTestsUtils::LayerTestsCommon, public CPUTestsBase {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<selectParams> obj) {
|
||||
std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
|
||||
ngraph::op::AutoBroadcastSpec broadcast;
|
||||
std::tie(shapes, broadcast) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=" << CommonTestUtils::partialShape2str(shapes.first) << "_";
|
||||
result << "TS=";
|
||||
for (const auto& shape : shapes.second) {
|
||||
result << "(";
|
||||
for (const auto& item : shape) {
|
||||
result << CommonTestUtils::vec2str(item) << "_";
|
||||
}
|
||||
result << ")_";
|
||||
}
|
||||
result << "Broadcast=" << broadcast.m_type;
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>> shapes;
|
||||
ngraph::op::AutoBroadcastSpec broadcast;
|
||||
std::tie(shapes, broadcast) = this->GetParam();
|
||||
|
||||
for (size_t i = 0; i < shapes.second.size(); i++) {
|
||||
targetStaticShapes.push_back(shapes.second[i]);
|
||||
}
|
||||
inputDynamicShapes = shapes.first;
|
||||
|
||||
selectedType = std::string("ref_any_") + Precision(Precision::I8).name();
|
||||
|
||||
ngraph::ParameterVector paramNodesVector;
|
||||
auto paramNode = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::Type_t::boolean, ngraph::Shape(targetStaticShapes[0][0]));
|
||||
paramNodesVector.push_back(paramNode);
|
||||
auto inType = ngraph::element::Type_t::f32;
|
||||
for (size_t i = 1; i < targetStaticShapes[0].size(); i++) {
|
||||
paramNode = std::make_shared<ngraph::opset1::Parameter>(inType, ngraph::Shape(targetStaticShapes[0][i]));
|
||||
paramNodesVector.push_back(paramNode);
|
||||
}
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramNodesVector));
|
||||
|
||||
auto select = ngraph::builder::makeSelect(paramOuts, broadcast);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(select, paramNodesVector, "SelectLayerCPUTest");
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(SelectLayerCPUTest, CompareWithRefs) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
|
||||
Run();
|
||||
CheckPluginRelatedResults(executableNetwork, "Select");
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNumpy = {
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1, -1, -1},
|
||||
{-1, -1, -1, -1, -1},
|
||||
{-1, -1, -1, -1}
|
||||
},
|
||||
|
||||
// target
|
||||
{
|
||||
{{5, 1, 2, 1}, {8, 1, 9, 1, 1}, {5, 1, 2, 1}},
|
||||
{{1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1}},
|
||||
{{5, 9, 8, 7}, {21, 5, 9, 8, 7}, {1, 1, 1, 1}},
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{-1, -1},
|
||||
{-1, -1, -1, -1, -1},
|
||||
{-1, -1, -1}
|
||||
},
|
||||
|
||||
// target
|
||||
{
|
||||
{{8, 1}, {2, 1, 1, 8, 1}, {9, 1, 1}},
|
||||
{{10, 5}, {7, 8, 3, 10, 5}, {3, 10, 5}},
|
||||
{{8, 7}, {1, 1, 1, 8, 1}, {1, 1, 7}},
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{{2, 8}, {3, 7}, {1, 10}, {1, 6}, {1, 10}},
|
||||
{-1, -1, -1, -1, -1},
|
||||
{{1, 5}, {1, 11}, {5, 5}, {1, 8}}
|
||||
},
|
||||
|
||||
// target
|
||||
{
|
||||
{{5, 4, 1, 1, 1}, {5, 1, 8, 1, 1}, {1, 1, 5, 1}},
|
||||
{{8, 5, 5, 5, 1}, {8, 1, 1, 1, 8}, {5, 5, 5, 8}},
|
||||
{{2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, {3, 4, 5, 6}},
|
||||
}
|
||||
},
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{{1, 10}},
|
||||
{{1, 15}, {2, 7}, {1, 6}, {5, 12}, {1, 20}},
|
||||
{{2, 10}, {1, 16}}
|
||||
},
|
||||
|
||||
// target
|
||||
{
|
||||
{{4}, {8, 5, 6, 6, 1}, {6, 4}},
|
||||
{{10}, {15, 7, 6, 10, 10}, {10, 10}},
|
||||
{{1}, {2, 5, 4, 5, 3}, {5, 1}},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const auto numpyCases = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapesDynamicNumpy),
|
||||
::testing::Values(ngraph::op::AutoBroadcastSpec::NUMPY)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNumpy_dynamic, SelectLayerCPUTest, numpyCases, SelectLayerCPUTest::getTestCaseName);
|
||||
|
||||
std::vector<std::pair<std::vector<ngraph::PartialShape>, std::vector<std::vector<ngraph::Shape>>>> inShapesDynamicNone = {
|
||||
{
|
||||
// dynamic
|
||||
{
|
||||
{{1, 10}, -1, {10, 20}, {1, 5}},
|
||||
{-1, {16, 16}, -1, -1},
|
||||
{-1, -1, -1, -1}
|
||||
},
|
||||
|
||||
// target
|
||||
{
|
||||
{{3, 16, 15, 5}, {3, 16, 15, 5}, {3, 16, 15, 5}},
|
||||
{{1, 16, 10, 1}, {1, 16, 10, 1}, {1, 16, 10, 1}},
|
||||
{{10, 16, 20, 5}, {10, 16, 20, 5}, {10, 16, 20, 5}}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const auto noneCases = ::testing::Combine(
|
||||
::testing::ValuesIn(inShapesDynamicNone),
|
||||
::testing::Values(ngraph::op::AutoBroadcastSpec::NONE)
|
||||
);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefsNone_dynamic, SelectLayerCPUTest, noneCases, SelectLayerCPUTest::getTestCaseName);
|
||||
|
||||
} // namespace CPULayerTestsDefinitions
|
Loading…
Reference in New Issue
Block a user