[IE][VPU]: Refactoring of SpecialStageProcessor (#2885)
* SpecialStageProcessor refactoring * Fix for Yolo-v3-pytorch and related test
This commit is contained in:
parent
9cb3c2a6be
commit
5bc74aac75
@ -5,16 +5,208 @@
|
||||
#include "vpu/middleend/special_stage_processor.hpp"
|
||||
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
namespace vpu {
|
||||
|
||||
namespace {
|
||||
|
||||
struct NeedCopyDesc {
|
||||
bool isCopyNeed = false;
|
||||
bool isCopyOptional = false;
|
||||
};
|
||||
|
||||
NeedCopyDesc isOutputCopyRequired(
|
||||
const Stage& stage,
|
||||
const StageOutput& outputEdge,
|
||||
const Data& inputData) {
|
||||
NeedCopyDesc needCopyDesc;
|
||||
auto output = outputEdge->output();
|
||||
if (output->usage() != DataUsage::Intermediate) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
} else if (output->parentDataToDataEdge() != nullptr) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
} else {
|
||||
//
|
||||
// Check output StridesRequirement
|
||||
//
|
||||
|
||||
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
||||
if (!checkStrides(output->desc(), inputData->strides(), output->requiredStrides())) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopyDesc.isCopyNeed) {
|
||||
for (const auto& consumerEdge : output->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
||||
if (!checkStrides(output->desc(), inputData->strides(), consumerStrideReqs)) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return needCopyDesc;
|
||||
}
|
||||
|
||||
NeedCopyDesc isInputCopyRequired(
|
||||
const Stage& stage,
|
||||
const StageInput& inputEdge,
|
||||
const Data& outputData) {
|
||||
auto input = inputEdge->input();
|
||||
NeedCopyDesc needCopyDesc;
|
||||
if (input->usage() != DataUsage::Intermediate) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
} else if (input->parentDataToDataEdge() != nullptr) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
} else {
|
||||
//
|
||||
// Check input StridesRequirement.
|
||||
//
|
||||
|
||||
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
||||
if (!checkStrides(input->desc(), outputData->strides(), input->requiredStrides())) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopyDesc.isCopyNeed) {
|
||||
for (const auto& consumerEdge : input->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), outputData->strides(), consumerStrideReqs)) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Check producer StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopyDesc.isCopyNeed) {
|
||||
if (auto producerEdge = input->producerEdge()) {
|
||||
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
||||
|
||||
if (producerInfo.hasOutput(producerEdge)) {
|
||||
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
||||
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), outputData->strides(), producerStrideReqs)) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!needCopyDesc.isCopyNeed) {
|
||||
//
|
||||
// To reduce the size of HW output (still can be optimized).
|
||||
//
|
||||
|
||||
if (producerEdge->producer()->category() == StageCategory::HW) {
|
||||
needCopyDesc.isCopyNeed = true;
|
||||
needCopyDesc.isCopyOptional = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return needCopyDesc;
|
||||
}
|
||||
|
||||
Data insertCopyOfInput(const Model& model,
|
||||
const Stage& stage,
|
||||
const StageInput& edge,
|
||||
const StageBuilder::Ptr& _stageBuilder,
|
||||
const NeedCopyDesc& desc) {
|
||||
auto data = edge->input();
|
||||
|
||||
Data copy;
|
||||
if (data->usage() == DataUsage::Const) {
|
||||
copy = model->addNewData(data->name() + "@copy", data->desc());
|
||||
} else {
|
||||
copy = model->duplicateData(data, "@copy");
|
||||
copy->resetRequiredStrides();
|
||||
}
|
||||
if (stage->type() == StageType::Reshape)
|
||||
copy->updateRequiredStrides(StridesRequirement::compact());
|
||||
|
||||
bool hasMultipleInputs = stage->numInputs() > 1;
|
||||
auto inputNumStr = hasMultipleInputs ? formatString("@input=%d", edge->portInd()) : "";
|
||||
std::stringstream typeAsString;
|
||||
typeAsString << stage->type();
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s%s@copy-for-%s", stage->name(), inputNumStr, typeAsString),
|
||||
stage->origLayer(),
|
||||
data,
|
||||
copy,
|
||||
formatString("special::%s", typeAsString));
|
||||
if (stage->type() != StageType::Reshape) {
|
||||
copyStage->attrs().set<bool>("optional", desc.isCopyOptional);
|
||||
}
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
model->replaceStageInput(edge, copy);
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
Data insertCopyOfOutput(const Model& model,
|
||||
const Stage& stage,
|
||||
const StageOutput& edge,
|
||||
const StageBuilder::Ptr& _stageBuilder) {
|
||||
auto data = edge->output();
|
||||
auto copy = model->duplicateData(data, "@copy");
|
||||
copy->resetRequiredStrides();
|
||||
|
||||
model->replaceStageOutput(edge, copy);
|
||||
|
||||
bool hasMultipleOutputs = stage->numOutputs() > 1;
|
||||
auto outputNumStr = hasMultipleOutputs ? formatString("@output=%d", edge->portInd()) : "";
|
||||
std::stringstream typeAsString;
|
||||
typeAsString << stage->type();
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s%s@copy-for-%s", stage->name(), outputNumStr, typeAsString),
|
||||
stage->origLayer(),
|
||||
copy,
|
||||
data,
|
||||
formatString("special::%s", typeAsString));
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
|
||||
void SpecialStageProcessor::processSplit(
|
||||
const Model& model,
|
||||
const Stage& stage) {
|
||||
IE_ASSERT(stage->type() == StageType::Split);
|
||||
|
||||
auto input = stage->input(0);
|
||||
|
||||
const auto& offsets = stage->attrs().get<std::vector<DimValues>>("offsets");
|
||||
@ -34,70 +226,9 @@ void SpecialStageProcessor::processSplit(
|
||||
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
||||
}
|
||||
|
||||
//
|
||||
// Check if we need to insert Copy stage
|
||||
//
|
||||
|
||||
bool needCopy = false;
|
||||
if (output->usage() != DataUsage::Intermediate) {
|
||||
needCopy = true;
|
||||
} else if (output->parentDataToDataEdge() != nullptr) {
|
||||
needCopy = true;
|
||||
} else {
|
||||
//
|
||||
// Check output StridesRequirement.
|
||||
//
|
||||
|
||||
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
||||
if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
|
||||
needCopy = true;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
for (const auto& consumerEdge : output->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
||||
|
||||
if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
|
||||
needCopy = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Insert Copy if needed
|
||||
//
|
||||
|
||||
if (needCopy) {
|
||||
auto outputCopy = model->duplicateData(output, "@copy");
|
||||
outputCopy->resetRequiredStrides();
|
||||
|
||||
auto outPortInd = outEdge->portInd();
|
||||
|
||||
model->replaceStageOutput(outEdge, outputCopy);
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s@output=%d@copy-for-split", stage->name(), outPortInd),
|
||||
stage->origLayer(),
|
||||
outputCopy,
|
||||
output,
|
||||
"special::split");
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
output = outputCopy;
|
||||
auto desc = isOutputCopyRequired(stage, outEdge, input);
|
||||
if (desc.isCopyNeed) {
|
||||
output = insertCopyOfOutput(model, stage, outEdge, _stageBuilder);
|
||||
}
|
||||
|
||||
//
|
||||
@ -136,113 +267,9 @@ void SpecialStageProcessor::processConcat(
|
||||
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
||||
}
|
||||
|
||||
//
|
||||
// Check if we need to insert Copy stage
|
||||
//
|
||||
|
||||
bool needCopy = false;
|
||||
bool optionalCopy = false;
|
||||
if (input->usage() != DataUsage::Intermediate) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
} else if (input->parentDataToDataEdge() != nullptr) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
} else {
|
||||
//
|
||||
// Check input StridesRequirement.
|
||||
//
|
||||
|
||||
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
||||
if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
for (const auto& consumerEdge : input->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Check producer StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
if (auto producerEdge = input->producerEdge()) {
|
||||
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
||||
|
||||
if (producerInfo.hasOutput(producerEdge)) {
|
||||
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
||||
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!needCopy) {
|
||||
//
|
||||
// To reduce the size of HW output (still can be optimized).
|
||||
//
|
||||
|
||||
if (producerEdge->producer()->category() == StageCategory::HW) {
|
||||
needCopy = true;
|
||||
optionalCopy = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Insert Copy if needed
|
||||
//
|
||||
|
||||
if (needCopy) {
|
||||
Data inputCopy;
|
||||
if (input->usage() == DataUsage::Const) {
|
||||
inputCopy = model->addNewData(
|
||||
input->name() + "@copy",
|
||||
input->desc());
|
||||
} else {
|
||||
inputCopy = model->duplicateData(
|
||||
input,
|
||||
"@copy");
|
||||
inputCopy->resetRequiredStrides();
|
||||
}
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s@input=%d@copy-for-concat", stage->name(), inEdge->portInd()),
|
||||
stage->origLayer(),
|
||||
input,
|
||||
inputCopy,
|
||||
"special::concat");
|
||||
copyStage->attrs().set<bool>("optional", optionalCopy);
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
model->replaceStageInput(inEdge, inputCopy);
|
||||
|
||||
input = inputCopy;
|
||||
NeedCopyDesc desc = isInputCopyRequired(stage, inEdge, output);
|
||||
if (desc.isCopyNeed) {
|
||||
input = insertCopyOfInput(model, stage, inEdge, _stageBuilder, desc);
|
||||
}
|
||||
|
||||
//
|
||||
@ -272,50 +299,12 @@ void SpecialStageProcessor::processReshape(
|
||||
IE_ASSERT(output->desc().dimsOrder() == DimsOrder::fromNumDims(output->desc().numDims()));
|
||||
IE_ASSERT(output->checkStrides(StridesRequirement::compact()));
|
||||
|
||||
//
|
||||
// Check if we need to insert Copy stage
|
||||
//
|
||||
|
||||
bool needCopy = false;
|
||||
if (input->usage() != DataUsage::Intermediate &&
|
||||
output->usage() != DataUsage::Intermediate) {
|
||||
needCopy = true;
|
||||
} else if (input->parentDataToDataEdge() != nullptr &&
|
||||
output->parentDataToDataEdge() != nullptr) {
|
||||
needCopy = true;
|
||||
}
|
||||
|
||||
//
|
||||
// Insert Copy if needed
|
||||
//
|
||||
|
||||
if (needCopy) {
|
||||
Data inputCopy;
|
||||
if (input->usage() == DataUsage::Const) {
|
||||
inputCopy = model->addNewData(
|
||||
input->name() + "@copy",
|
||||
input->desc());
|
||||
} else {
|
||||
inputCopy = model->duplicateData(
|
||||
input,
|
||||
"@copy");
|
||||
}
|
||||
inputCopy->updateRequiredStrides(StridesRequirement::compact());
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s@copy-for-reshape", stage->name()),
|
||||
stage->origLayer(),
|
||||
input,
|
||||
inputCopy,
|
||||
"special::reshape");
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
model->replaceStageInput(stage->inputEdge(0), inputCopy);
|
||||
|
||||
input = inputCopy;
|
||||
NeedCopyDesc desc;
|
||||
if ((input->usage() != DataUsage::Intermediate || input->parentDataToDataEdge() != nullptr) &&
|
||||
(output->usage() != DataUsage::Intermediate || output->parentDataToDataEdge() != nullptr))
|
||||
desc.isCopyNeed = true;
|
||||
if (desc.isCopyNeed) {
|
||||
input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
|
||||
}
|
||||
|
||||
//
|
||||
@ -330,16 +319,19 @@ void SpecialStageProcessor::processReshape(
|
||||
.mode(SharedDataMode::Reshape)
|
||||
.order(SharedDataOrder::ChildWritesToParent)
|
||||
.done();
|
||||
} else {
|
||||
IE_ASSERT(output->usage() == DataUsage::Intermediate);
|
||||
IE_ASSERT(output->parentDataToDataEdge() == nullptr);
|
||||
|
||||
} else if (output->usage() == DataUsage::Intermediate &&
|
||||
output->parentDataToDataEdge() == nullptr) {
|
||||
model->connectDataWithData()
|
||||
.parent(input)
|
||||
.child(output)
|
||||
.mode(SharedDataMode::Reshape)
|
||||
.order(SharedDataOrder::ParentWritesToChild)
|
||||
.done();
|
||||
} else {
|
||||
IE_ASSERT(input->usage() == DataUsage::Intermediate &&
|
||||
input->parentDataToDataEdge() == nullptr);
|
||||
IE_ASSERT(output->usage() == DataUsage::Intermediate &&
|
||||
output->parentDataToDataEdge() == nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -359,113 +351,9 @@ void SpecialStageProcessor::processExpand(
|
||||
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
||||
}
|
||||
|
||||
//
|
||||
// Check if we need to insert Copy stage
|
||||
//
|
||||
|
||||
bool needCopy = false;
|
||||
bool optionalCopy = false;
|
||||
if (input->usage() != DataUsage::Intermediate) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
} else if (input->parentDataToDataEdge() != nullptr) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
} else {
|
||||
//
|
||||
// Check input StridesRequirement.
|
||||
//
|
||||
|
||||
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
||||
if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
for (const auto& consumerEdge : input->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Check producer StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
if (auto producerEdge = input->producerEdge()) {
|
||||
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
||||
|
||||
if (producerInfo.hasOutput(producerEdge)) {
|
||||
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
||||
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
||||
|
||||
if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
|
||||
needCopy = true;
|
||||
optionalCopy = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!needCopy) {
|
||||
//
|
||||
// To reduce the size of HW output (still can be optimized).
|
||||
//
|
||||
|
||||
if (producerEdge->producer()->category() == StageCategory::HW) {
|
||||
needCopy = true;
|
||||
optionalCopy = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Insert Copy if needed
|
||||
//
|
||||
|
||||
if (needCopy) {
|
||||
Data inputCopy;
|
||||
if (input->usage() == DataUsage::Const) {
|
||||
inputCopy = model->addNewData(
|
||||
input->name() + "@copy",
|
||||
input->desc());
|
||||
} else {
|
||||
inputCopy = model->duplicateData(
|
||||
input,
|
||||
"@copy");
|
||||
inputCopy->resetRequiredStrides();
|
||||
}
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s@copy-for-expand", stage->name()),
|
||||
stage->origLayer(),
|
||||
input,
|
||||
inputCopy,
|
||||
"special::expand");
|
||||
copyStage->attrs().set<bool>("optional", optionalCopy);
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
model->replaceStageInput(stage->inputEdge(0), inputCopy);
|
||||
|
||||
input = inputCopy;
|
||||
auto desc = isInputCopyRequired(stage, stage->inputEdge(0), output);
|
||||
if (desc.isCopyNeed) {
|
||||
input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
|
||||
}
|
||||
|
||||
//
|
||||
@ -497,76 +385,11 @@ void SpecialStageProcessor::processCrop(
|
||||
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
||||
}
|
||||
|
||||
//
|
||||
// Check if we need to insert Copy for output
|
||||
//
|
||||
|
||||
bool needCopy = false;
|
||||
if (output->usage() != DataUsage::Intermediate) {
|
||||
needCopy = true;
|
||||
} else if (output->parentDataToDataEdge() != nullptr) {
|
||||
needCopy = true;
|
||||
} else {
|
||||
//
|
||||
// Check output StridesRequirement.
|
||||
//
|
||||
|
||||
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
||||
if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
|
||||
needCopy = true;
|
||||
}
|
||||
|
||||
//
|
||||
// Check consumers StridesRequirement.
|
||||
//
|
||||
|
||||
if (!needCopy) {
|
||||
for (const auto& consumerEdge : output->consumerEdges()) {
|
||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||
|
||||
if (consumerInfo.hasInput(consumerEdge)) {
|
||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
||||
|
||||
if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
|
||||
needCopy = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto desc = isOutputCopyRequired(stage, stage->outputEdge(0), input);
|
||||
if (desc.isCopyNeed) {
|
||||
output = insertCopyOfOutput(model, stage, stage->outputEdge(0), _stageBuilder);
|
||||
}
|
||||
|
||||
//
|
||||
// Insert output Copy if needed
|
||||
//
|
||||
|
||||
if (needCopy) {
|
||||
auto outputCopy = model->duplicateData(
|
||||
output,
|
||||
"@copy");
|
||||
outputCopy->resetRequiredStrides();
|
||||
|
||||
model->replaceStageOutput(stage->outputEdge(0), outputCopy);
|
||||
|
||||
auto copyStage = _stageBuilder->addCopyStage(
|
||||
model,
|
||||
formatString("%s@copy-output-for-crop", stage->name()),
|
||||
stage->origLayer(),
|
||||
outputCopy,
|
||||
output,
|
||||
"special::crop");
|
||||
if (stage->attrs().has("batchInd")) {
|
||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||
}
|
||||
|
||||
output = outputCopy;
|
||||
}
|
||||
|
||||
//
|
||||
// Add Data<->Data edge
|
||||
//
|
||||
|
||||
model->connectDataWithData()
|
||||
.parent(input)
|
||||
.child(output)
|
||||
|
@ -0,0 +1,86 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <functional_test_utils/layer_test_utils.hpp>
|
||||
#include "vpu/private_plugin_config.hpp"
|
||||
|
||||
#include <ngraph_functions/builders.hpp>
|
||||
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||
#include <vpu/myriad_plugin_config.hpp>
|
||||
|
||||
namespace {
|
||||
|
||||
using DataType = ngraph::element::Type_t;
|
||||
using DataDims = std::vector<std::vector<std::size_t>>;
|
||||
|
||||
using Parameters = std::tuple<
|
||||
DataType,
|
||||
DataDims,
|
||||
std::int64_t,
|
||||
std::vector<std::size_t>,
|
||||
LayerTestsUtils::TargetDevice>;
|
||||
|
||||
class Concat_Split_Transpose : public testing::WithParamInterface<Parameters>, virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
SetRefMode(LayerTestsUtils::RefMode::CONSTANT_FOLDING);
|
||||
configuration[InferenceEngine::MYRIAD_DISABLE_CONVERT_STAGES] = CONFIG_VALUE(YES);
|
||||
configuration[InferenceEngine::MYRIAD_DETECT_NETWORK_BATCH] = CONFIG_VALUE(NO);
|
||||
|
||||
const auto& dataType = std::get<0>(GetParam());
|
||||
const auto& dataDims = std::get<1>(GetParam());
|
||||
const auto& axis = std::get<2>(GetParam());
|
||||
const auto& length = std::get<3>(GetParam());
|
||||
targetDevice = std::get<4>(GetParam());
|
||||
|
||||
auto params = ngraph::builder::makeParams(dataType, dataDims);
|
||||
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||
|
||||
auto concat = std::make_shared<ngraph::opset1::Concat>(paramOuts, axis);
|
||||
|
||||
const auto lengthData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{length.size()},
|
||||
length);
|
||||
const auto axisData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{1},
|
||||
axis);
|
||||
auto split = std::make_shared<ngraph::opset3::VariadicSplit>(concat, axisData, lengthData);
|
||||
|
||||
auto permutation = std::vector<std::int64_t>(split->get_output_shape(0).size());
|
||||
std::iota(permutation.rbegin(), permutation.rend(), 0);
|
||||
const auto transposition = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{split->get_output_shape(0).size()},
|
||||
permutation);
|
||||
|
||||
ngraph::ResultVector results;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const auto transpose = std::make_shared<ngraph::opset3::Transpose>(split->output(i), transposition);
|
||||
results.push_back(std::make_shared<ngraph::opset1::Result>(transpose));
|
||||
}
|
||||
function = std::make_shared<ngraph::Function>(results, params, "concat-split-transpose");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(Concat_Split_Transpose, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
std::vector<DataDims> dims = {
|
||||
{{400, 1}, {600, 1}}
|
||||
};
|
||||
|
||||
std::vector<std::vector<std::size_t>> length = {
|
||||
{500, 500}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(SpecialStages, Concat_Split_Transpose,
|
||||
::testing::Combine(
|
||||
::testing::Values(ngraph::element::i32),
|
||||
::testing::ValuesIn(dims),
|
||||
::testing::Values(0),
|
||||
::testing::ValuesIn(length),
|
||||
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
|
||||
|
||||
} // namespace
|
Loading…
Reference in New Issue
Block a user