[IE][VPU]: Refactoring of SpecialStageProcessor (#2885)
* SpecialStageProcessor refactoring * Fix for Yolo-v3-pytorch and related test
This commit is contained in:
parent
9cb3c2a6be
commit
5bc74aac75
@ -5,16 +5,208 @@
|
|||||||
#include "vpu/middleend/special_stage_processor.hpp"
|
#include "vpu/middleend/special_stage_processor.hpp"
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <set>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace vpu {
|
namespace vpu {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct NeedCopyDesc {
|
||||||
|
bool isCopyNeed = false;
|
||||||
|
bool isCopyOptional = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
NeedCopyDesc isOutputCopyRequired(
|
||||||
|
const Stage& stage,
|
||||||
|
const StageOutput& outputEdge,
|
||||||
|
const Data& inputData) {
|
||||||
|
NeedCopyDesc needCopyDesc;
|
||||||
|
auto output = outputEdge->output();
|
||||||
|
if (output->usage() != DataUsage::Intermediate) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
} else if (output->parentDataToDataEdge() != nullptr) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
} else {
|
||||||
|
//
|
||||||
|
// Check output StridesRequirement
|
||||||
|
//
|
||||||
|
|
||||||
|
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
||||||
|
if (!checkStrides(output->desc(), inputData->strides(), output->requiredStrides())) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Check consumers StridesRequirement.
|
||||||
|
//
|
||||||
|
|
||||||
|
if (!needCopyDesc.isCopyNeed) {
|
||||||
|
for (const auto& consumerEdge : output->consumerEdges()) {
|
||||||
|
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||||
|
if (consumerInfo.hasInput(consumerEdge)) {
|
||||||
|
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||||
|
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
||||||
|
if (!checkStrides(output->desc(), inputData->strides(), consumerStrideReqs)) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return needCopyDesc;
|
||||||
|
}
|
||||||
|
|
||||||
|
NeedCopyDesc isInputCopyRequired(
|
||||||
|
const Stage& stage,
|
||||||
|
const StageInput& inputEdge,
|
||||||
|
const Data& outputData) {
|
||||||
|
auto input = inputEdge->input();
|
||||||
|
NeedCopyDesc needCopyDesc;
|
||||||
|
if (input->usage() != DataUsage::Intermediate) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
} else if (input->parentDataToDataEdge() != nullptr) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
} else {
|
||||||
|
//
|
||||||
|
// Check input StridesRequirement.
|
||||||
|
//
|
||||||
|
|
||||||
|
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
||||||
|
if (!checkStrides(input->desc(), outputData->strides(), input->requiredStrides())) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Check consumers StridesRequirement.
|
||||||
|
//
|
||||||
|
|
||||||
|
if (!needCopyDesc.isCopyNeed) {
|
||||||
|
for (const auto& consumerEdge : input->consumerEdges()) {
|
||||||
|
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
||||||
|
|
||||||
|
if (consumerInfo.hasInput(consumerEdge)) {
|
||||||
|
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
||||||
|
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
||||||
|
|
||||||
|
if (!checkStrides(input->desc(), outputData->strides(), consumerStrideReqs)) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Check producer StridesRequirement.
|
||||||
|
//
|
||||||
|
|
||||||
|
if (!needCopyDesc.isCopyNeed) {
|
||||||
|
if (auto producerEdge = input->producerEdge()) {
|
||||||
|
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
||||||
|
|
||||||
|
if (producerInfo.hasOutput(producerEdge)) {
|
||||||
|
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
||||||
|
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
||||||
|
|
||||||
|
if (!checkStrides(input->desc(), outputData->strides(), producerStrideReqs)) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!needCopyDesc.isCopyNeed) {
|
||||||
|
//
|
||||||
|
// To reduce the size of HW output (still can be optimized).
|
||||||
|
//
|
||||||
|
|
||||||
|
if (producerEdge->producer()->category() == StageCategory::HW) {
|
||||||
|
needCopyDesc.isCopyNeed = true;
|
||||||
|
needCopyDesc.isCopyOptional = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return needCopyDesc;
|
||||||
|
}
|
||||||
|
|
||||||
|
Data insertCopyOfInput(const Model& model,
|
||||||
|
const Stage& stage,
|
||||||
|
const StageInput& edge,
|
||||||
|
const StageBuilder::Ptr& _stageBuilder,
|
||||||
|
const NeedCopyDesc& desc) {
|
||||||
|
auto data = edge->input();
|
||||||
|
|
||||||
|
Data copy;
|
||||||
|
if (data->usage() == DataUsage::Const) {
|
||||||
|
copy = model->addNewData(data->name() + "@copy", data->desc());
|
||||||
|
} else {
|
||||||
|
copy = model->duplicateData(data, "@copy");
|
||||||
|
copy->resetRequiredStrides();
|
||||||
|
}
|
||||||
|
if (stage->type() == StageType::Reshape)
|
||||||
|
copy->updateRequiredStrides(StridesRequirement::compact());
|
||||||
|
|
||||||
|
bool hasMultipleInputs = stage->numInputs() > 1;
|
||||||
|
auto inputNumStr = hasMultipleInputs ? formatString("@input=%d", edge->portInd()) : "";
|
||||||
|
std::stringstream typeAsString;
|
||||||
|
typeAsString << stage->type();
|
||||||
|
|
||||||
|
auto copyStage = _stageBuilder->addCopyStage(
|
||||||
|
model,
|
||||||
|
formatString("%s%s@copy-for-%s", stage->name(), inputNumStr, typeAsString),
|
||||||
|
stage->origLayer(),
|
||||||
|
data,
|
||||||
|
copy,
|
||||||
|
formatString("special::%s", typeAsString));
|
||||||
|
if (stage->type() != StageType::Reshape) {
|
||||||
|
copyStage->attrs().set<bool>("optional", desc.isCopyOptional);
|
||||||
|
}
|
||||||
|
if (stage->attrs().has("batchInd")) {
|
||||||
|
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||||
|
}
|
||||||
|
|
||||||
|
model->replaceStageInput(edge, copy);
|
||||||
|
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
Data insertCopyOfOutput(const Model& model,
|
||||||
|
const Stage& stage,
|
||||||
|
const StageOutput& edge,
|
||||||
|
const StageBuilder::Ptr& _stageBuilder) {
|
||||||
|
auto data = edge->output();
|
||||||
|
auto copy = model->duplicateData(data, "@copy");
|
||||||
|
copy->resetRequiredStrides();
|
||||||
|
|
||||||
|
model->replaceStageOutput(edge, copy);
|
||||||
|
|
||||||
|
bool hasMultipleOutputs = stage->numOutputs() > 1;
|
||||||
|
auto outputNumStr = hasMultipleOutputs ? formatString("@output=%d", edge->portInd()) : "";
|
||||||
|
std::stringstream typeAsString;
|
||||||
|
typeAsString << stage->type();
|
||||||
|
|
||||||
|
auto copyStage = _stageBuilder->addCopyStage(
|
||||||
|
model,
|
||||||
|
formatString("%s%s@copy-for-%s", stage->name(), outputNumStr, typeAsString),
|
||||||
|
stage->origLayer(),
|
||||||
|
copy,
|
||||||
|
data,
|
||||||
|
formatString("special::%s", typeAsString));
|
||||||
|
if (stage->attrs().has("batchInd")) {
|
||||||
|
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
||||||
|
}
|
||||||
|
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
|
||||||
void SpecialStageProcessor::processSplit(
|
void SpecialStageProcessor::processSplit(
|
||||||
const Model& model,
|
const Model& model,
|
||||||
const Stage& stage) {
|
const Stage& stage) {
|
||||||
IE_ASSERT(stage->type() == StageType::Split);
|
IE_ASSERT(stage->type() == StageType::Split);
|
||||||
|
|
||||||
auto input = stage->input(0);
|
auto input = stage->input(0);
|
||||||
|
|
||||||
const auto& offsets = stage->attrs().get<std::vector<DimValues>>("offsets");
|
const auto& offsets = stage->attrs().get<std::vector<DimValues>>("offsets");
|
||||||
@ -34,70 +226,9 @@ void SpecialStageProcessor::processSplit(
|
|||||||
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
auto desc = isOutputCopyRequired(stage, outEdge, input);
|
||||||
// Check if we need to insert Copy stage
|
if (desc.isCopyNeed) {
|
||||||
//
|
output = insertCopyOfOutput(model, stage, outEdge, _stageBuilder);
|
||||||
|
|
||||||
bool needCopy = false;
|
|
||||||
if (output->usage() != DataUsage::Intermediate) {
|
|
||||||
needCopy = true;
|
|
||||||
} else if (output->parentDataToDataEdge() != nullptr) {
|
|
||||||
needCopy = true;
|
|
||||||
} else {
|
|
||||||
//
|
|
||||||
// Check output StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
|
||||||
if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
|
|
||||||
needCopy = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Check consumers StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
for (const auto& consumerEdge : output->consumerEdges()) {
|
|
||||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (consumerInfo.hasInput(consumerEdge)) {
|
|
||||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
|
||||||
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Insert Copy if needed
|
|
||||||
//
|
|
||||||
|
|
||||||
if (needCopy) {
|
|
||||||
auto outputCopy = model->duplicateData(output, "@copy");
|
|
||||||
outputCopy->resetRequiredStrides();
|
|
||||||
|
|
||||||
auto outPortInd = outEdge->portInd();
|
|
||||||
|
|
||||||
model->replaceStageOutput(outEdge, outputCopy);
|
|
||||||
|
|
||||||
auto copyStage = _stageBuilder->addCopyStage(
|
|
||||||
model,
|
|
||||||
formatString("%s@output=%d@copy-for-split", stage->name(), outPortInd),
|
|
||||||
stage->origLayer(),
|
|
||||||
outputCopy,
|
|
||||||
output,
|
|
||||||
"special::split");
|
|
||||||
if (stage->attrs().has("batchInd")) {
|
|
||||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
|
||||||
}
|
|
||||||
|
|
||||||
output = outputCopy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -136,113 +267,9 @@ void SpecialStageProcessor::processConcat(
|
|||||||
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
NeedCopyDesc desc = isInputCopyRequired(stage, inEdge, output);
|
||||||
// Check if we need to insert Copy stage
|
if (desc.isCopyNeed) {
|
||||||
//
|
input = insertCopyOfInput(model, stage, inEdge, _stageBuilder, desc);
|
||||||
|
|
||||||
bool needCopy = false;
|
|
||||||
bool optionalCopy = false;
|
|
||||||
if (input->usage() != DataUsage::Intermediate) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
} else if (input->parentDataToDataEdge() != nullptr) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
} else {
|
|
||||||
//
|
|
||||||
// Check input StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Check consumers StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
for (const auto& consumerEdge : input->consumerEdges()) {
|
|
||||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (consumerInfo.hasInput(consumerEdge)) {
|
|
||||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
|
||||||
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Check producer StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
if (auto producerEdge = input->producerEdge()) {
|
|
||||||
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (producerInfo.hasOutput(producerEdge)) {
|
|
||||||
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
|
||||||
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
//
|
|
||||||
// To reduce the size of HW output (still can be optimized).
|
|
||||||
//
|
|
||||||
|
|
||||||
if (producerEdge->producer()->category() == StageCategory::HW) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Insert Copy if needed
|
|
||||||
//
|
|
||||||
|
|
||||||
if (needCopy) {
|
|
||||||
Data inputCopy;
|
|
||||||
if (input->usage() == DataUsage::Const) {
|
|
||||||
inputCopy = model->addNewData(
|
|
||||||
input->name() + "@copy",
|
|
||||||
input->desc());
|
|
||||||
} else {
|
|
||||||
inputCopy = model->duplicateData(
|
|
||||||
input,
|
|
||||||
"@copy");
|
|
||||||
inputCopy->resetRequiredStrides();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto copyStage = _stageBuilder->addCopyStage(
|
|
||||||
model,
|
|
||||||
formatString("%s@input=%d@copy-for-concat", stage->name(), inEdge->portInd()),
|
|
||||||
stage->origLayer(),
|
|
||||||
input,
|
|
||||||
inputCopy,
|
|
||||||
"special::concat");
|
|
||||||
copyStage->attrs().set<bool>("optional", optionalCopy);
|
|
||||||
if (stage->attrs().has("batchInd")) {
|
|
||||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
|
||||||
}
|
|
||||||
|
|
||||||
model->replaceStageInput(inEdge, inputCopy);
|
|
||||||
|
|
||||||
input = inputCopy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -272,50 +299,12 @@ void SpecialStageProcessor::processReshape(
|
|||||||
IE_ASSERT(output->desc().dimsOrder() == DimsOrder::fromNumDims(output->desc().numDims()));
|
IE_ASSERT(output->desc().dimsOrder() == DimsOrder::fromNumDims(output->desc().numDims()));
|
||||||
IE_ASSERT(output->checkStrides(StridesRequirement::compact()));
|
IE_ASSERT(output->checkStrides(StridesRequirement::compact()));
|
||||||
|
|
||||||
//
|
NeedCopyDesc desc;
|
||||||
// Check if we need to insert Copy stage
|
if ((input->usage() != DataUsage::Intermediate || input->parentDataToDataEdge() != nullptr) &&
|
||||||
//
|
(output->usage() != DataUsage::Intermediate || output->parentDataToDataEdge() != nullptr))
|
||||||
|
desc.isCopyNeed = true;
|
||||||
bool needCopy = false;
|
if (desc.isCopyNeed) {
|
||||||
if (input->usage() != DataUsage::Intermediate &&
|
input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
|
||||||
output->usage() != DataUsage::Intermediate) {
|
|
||||||
needCopy = true;
|
|
||||||
} else if (input->parentDataToDataEdge() != nullptr &&
|
|
||||||
output->parentDataToDataEdge() != nullptr) {
|
|
||||||
needCopy = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Insert Copy if needed
|
|
||||||
//
|
|
||||||
|
|
||||||
if (needCopy) {
|
|
||||||
Data inputCopy;
|
|
||||||
if (input->usage() == DataUsage::Const) {
|
|
||||||
inputCopy = model->addNewData(
|
|
||||||
input->name() + "@copy",
|
|
||||||
input->desc());
|
|
||||||
} else {
|
|
||||||
inputCopy = model->duplicateData(
|
|
||||||
input,
|
|
||||||
"@copy");
|
|
||||||
}
|
|
||||||
inputCopy->updateRequiredStrides(StridesRequirement::compact());
|
|
||||||
|
|
||||||
auto copyStage = _stageBuilder->addCopyStage(
|
|
||||||
model,
|
|
||||||
formatString("%s@copy-for-reshape", stage->name()),
|
|
||||||
stage->origLayer(),
|
|
||||||
input,
|
|
||||||
inputCopy,
|
|
||||||
"special::reshape");
|
|
||||||
if (stage->attrs().has("batchInd")) {
|
|
||||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
|
||||||
}
|
|
||||||
|
|
||||||
model->replaceStageInput(stage->inputEdge(0), inputCopy);
|
|
||||||
|
|
||||||
input = inputCopy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -330,16 +319,19 @@ void SpecialStageProcessor::processReshape(
|
|||||||
.mode(SharedDataMode::Reshape)
|
.mode(SharedDataMode::Reshape)
|
||||||
.order(SharedDataOrder::ChildWritesToParent)
|
.order(SharedDataOrder::ChildWritesToParent)
|
||||||
.done();
|
.done();
|
||||||
} else {
|
} else if (output->usage() == DataUsage::Intermediate &&
|
||||||
IE_ASSERT(output->usage() == DataUsage::Intermediate);
|
output->parentDataToDataEdge() == nullptr) {
|
||||||
IE_ASSERT(output->parentDataToDataEdge() == nullptr);
|
|
||||||
|
|
||||||
model->connectDataWithData()
|
model->connectDataWithData()
|
||||||
.parent(input)
|
.parent(input)
|
||||||
.child(output)
|
.child(output)
|
||||||
.mode(SharedDataMode::Reshape)
|
.mode(SharedDataMode::Reshape)
|
||||||
.order(SharedDataOrder::ParentWritesToChild)
|
.order(SharedDataOrder::ParentWritesToChild)
|
||||||
.done();
|
.done();
|
||||||
|
} else {
|
||||||
|
IE_ASSERT(input->usage() == DataUsage::Intermediate &&
|
||||||
|
input->parentDataToDataEdge() == nullptr);
|
||||||
|
IE_ASSERT(output->usage() == DataUsage::Intermediate &&
|
||||||
|
output->parentDataToDataEdge() == nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -359,113 +351,9 @@ void SpecialStageProcessor::processExpand(
|
|||||||
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
IE_ASSERT(p.second + input->desc().dim(p.first) <= output->desc().dim(p.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
auto desc = isInputCopyRequired(stage, stage->inputEdge(0), output);
|
||||||
// Check if we need to insert Copy stage
|
if (desc.isCopyNeed) {
|
||||||
//
|
input = insertCopyOfInput(model, stage, stage->inputEdge(0), _stageBuilder, desc);
|
||||||
|
|
||||||
bool needCopy = false;
|
|
||||||
bool optionalCopy = false;
|
|
||||||
if (input->usage() != DataUsage::Intermediate) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
} else if (input->parentDataToDataEdge() != nullptr) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
} else {
|
|
||||||
//
|
|
||||||
// Check input StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
IE_ASSERT(input->checkStrides(input->requiredStrides()));
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), input->requiredStrides())) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Check consumers StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
for (const auto& consumerEdge : input->consumerEdges()) {
|
|
||||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (consumerInfo.hasInput(consumerEdge)) {
|
|
||||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
|
||||||
IE_ASSERT(input->checkStrides(consumerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), consumerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Check producer StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
if (auto producerEdge = input->producerEdge()) {
|
|
||||||
const auto& producerInfo = producerEdge->producer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (producerInfo.hasOutput(producerEdge)) {
|
|
||||||
const auto& producerStrideReqs = producerInfo.getOutput(producerEdge);
|
|
||||||
IE_ASSERT(input->checkStrides(producerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(input->desc(), output->strides(), producerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
//
|
|
||||||
// To reduce the size of HW output (still can be optimized).
|
|
||||||
//
|
|
||||||
|
|
||||||
if (producerEdge->producer()->category() == StageCategory::HW) {
|
|
||||||
needCopy = true;
|
|
||||||
optionalCopy = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Insert Copy if needed
|
|
||||||
//
|
|
||||||
|
|
||||||
if (needCopy) {
|
|
||||||
Data inputCopy;
|
|
||||||
if (input->usage() == DataUsage::Const) {
|
|
||||||
inputCopy = model->addNewData(
|
|
||||||
input->name() + "@copy",
|
|
||||||
input->desc());
|
|
||||||
} else {
|
|
||||||
inputCopy = model->duplicateData(
|
|
||||||
input,
|
|
||||||
"@copy");
|
|
||||||
inputCopy->resetRequiredStrides();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto copyStage = _stageBuilder->addCopyStage(
|
|
||||||
model,
|
|
||||||
formatString("%s@copy-for-expand", stage->name()),
|
|
||||||
stage->origLayer(),
|
|
||||||
input,
|
|
||||||
inputCopy,
|
|
||||||
"special::expand");
|
|
||||||
copyStage->attrs().set<bool>("optional", optionalCopy);
|
|
||||||
if (stage->attrs().has("batchInd")) {
|
|
||||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
|
||||||
}
|
|
||||||
|
|
||||||
model->replaceStageInput(stage->inputEdge(0), inputCopy);
|
|
||||||
|
|
||||||
input = inputCopy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -497,76 +385,11 @@ void SpecialStageProcessor::processCrop(
|
|||||||
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
IE_ASSERT(p.second + output->desc().dim(p.first) <= input->desc().dim(p.first));
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
auto desc = isOutputCopyRequired(stage, stage->outputEdge(0), input);
|
||||||
// Check if we need to insert Copy for output
|
if (desc.isCopyNeed) {
|
||||||
//
|
output = insertCopyOfOutput(model, stage, stage->outputEdge(0), _stageBuilder);
|
||||||
|
|
||||||
bool needCopy = false;
|
|
||||||
if (output->usage() != DataUsage::Intermediate) {
|
|
||||||
needCopy = true;
|
|
||||||
} else if (output->parentDataToDataEdge() != nullptr) {
|
|
||||||
needCopy = true;
|
|
||||||
} else {
|
|
||||||
//
|
|
||||||
// Check output StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
IE_ASSERT(output->checkStrides(output->requiredStrides()));
|
|
||||||
if (!checkStrides(output->desc(), input->strides(), output->requiredStrides())) {
|
|
||||||
needCopy = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// Check consumers StridesRequirement.
|
|
||||||
//
|
|
||||||
|
|
||||||
if (!needCopy) {
|
|
||||||
for (const auto& consumerEdge : output->consumerEdges()) {
|
|
||||||
const auto& consumerInfo = consumerEdge->consumer()->getDataStridesRequirements();
|
|
||||||
|
|
||||||
if (consumerInfo.hasInput(consumerEdge)) {
|
|
||||||
const auto& consumerStrideReqs = consumerInfo.getInput(consumerEdge);
|
|
||||||
IE_ASSERT(output->checkStrides(consumerStrideReqs));
|
|
||||||
|
|
||||||
if (!checkStrides(output->desc(), input->strides(), consumerStrideReqs)) {
|
|
||||||
needCopy = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Insert output Copy if needed
|
|
||||||
//
|
|
||||||
|
|
||||||
if (needCopy) {
|
|
||||||
auto outputCopy = model->duplicateData(
|
|
||||||
output,
|
|
||||||
"@copy");
|
|
||||||
outputCopy->resetRequiredStrides();
|
|
||||||
|
|
||||||
model->replaceStageOutput(stage->outputEdge(0), outputCopy);
|
|
||||||
|
|
||||||
auto copyStage = _stageBuilder->addCopyStage(
|
|
||||||
model,
|
|
||||||
formatString("%s@copy-output-for-crop", stage->name()),
|
|
||||||
stage->origLayer(),
|
|
||||||
outputCopy,
|
|
||||||
output,
|
|
||||||
"special::crop");
|
|
||||||
if (stage->attrs().has("batchInd")) {
|
|
||||||
copyStage->attrs().set("batchInd", stage->attrs().get<int>("batchInd"));
|
|
||||||
}
|
|
||||||
|
|
||||||
output = outputCopy;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// Add Data<->Data edge
|
|
||||||
//
|
|
||||||
|
|
||||||
model->connectDataWithData()
|
model->connectDataWithData()
|
||||||
.parent(input)
|
.parent(input)
|
||||||
.child(output)
|
.child(output)
|
||||||
|
@ -0,0 +1,86 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <functional_test_utils/layer_test_utils.hpp>
|
||||||
|
#include "vpu/private_plugin_config.hpp"
|
||||||
|
|
||||||
|
#include <ngraph_functions/builders.hpp>
|
||||||
|
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||||
|
#include <vpu/myriad_plugin_config.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using DataType = ngraph::element::Type_t;
|
||||||
|
using DataDims = std::vector<std::vector<std::size_t>>;
|
||||||
|
|
||||||
|
using Parameters = std::tuple<
|
||||||
|
DataType,
|
||||||
|
DataDims,
|
||||||
|
std::int64_t,
|
||||||
|
std::vector<std::size_t>,
|
||||||
|
LayerTestsUtils::TargetDevice>;
|
||||||
|
|
||||||
|
class Concat_Split_Transpose : public testing::WithParamInterface<Parameters>, virtual public LayerTestsUtils::LayerTestsCommon {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
SetRefMode(LayerTestsUtils::RefMode::CONSTANT_FOLDING);
|
||||||
|
configuration[InferenceEngine::MYRIAD_DISABLE_CONVERT_STAGES] = CONFIG_VALUE(YES);
|
||||||
|
configuration[InferenceEngine::MYRIAD_DETECT_NETWORK_BATCH] = CONFIG_VALUE(NO);
|
||||||
|
|
||||||
|
const auto& dataType = std::get<0>(GetParam());
|
||||||
|
const auto& dataDims = std::get<1>(GetParam());
|
||||||
|
const auto& axis = std::get<2>(GetParam());
|
||||||
|
const auto& length = std::get<3>(GetParam());
|
||||||
|
targetDevice = std::get<4>(GetParam());
|
||||||
|
|
||||||
|
auto params = ngraph::builder::makeParams(dataType, dataDims);
|
||||||
|
auto paramOuts = ngraph::helpers::convert2OutputVector(
|
||||||
|
ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(params));
|
||||||
|
|
||||||
|
auto concat = std::make_shared<ngraph::opset1::Concat>(paramOuts, axis);
|
||||||
|
|
||||||
|
const auto lengthData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{length.size()},
|
||||||
|
length);
|
||||||
|
const auto axisData = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{1},
|
||||||
|
axis);
|
||||||
|
auto split = std::make_shared<ngraph::opset3::VariadicSplit>(concat, axisData, lengthData);
|
||||||
|
|
||||||
|
auto permutation = std::vector<std::int64_t>(split->get_output_shape(0).size());
|
||||||
|
std::iota(permutation.rbegin(), permutation.rend(), 0);
|
||||||
|
const auto transposition = std::make_shared<ngraph::opset3::Constant>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{split->get_output_shape(0).size()},
|
||||||
|
permutation);
|
||||||
|
|
||||||
|
ngraph::ResultVector results;
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
const auto transpose = std::make_shared<ngraph::opset3::Transpose>(split->output(i), transposition);
|
||||||
|
results.push_back(std::make_shared<ngraph::opset1::Result>(transpose));
|
||||||
|
}
|
||||||
|
function = std::make_shared<ngraph::Function>(results, params, "concat-split-transpose");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(Concat_Split_Transpose, CompareWithRefs) {
|
||||||
|
Run();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<DataDims> dims = {
|
||||||
|
{{400, 1}, {600, 1}}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<std::size_t>> length = {
|
||||||
|
{500, 500}
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(SpecialStages, Concat_Split_Transpose,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(ngraph::element::i32),
|
||||||
|
::testing::ValuesIn(dims),
|
||||||
|
::testing::Values(0),
|
||||||
|
::testing::ValuesIn(length),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_MYRIAD)));
|
||||||
|
|
||||||
|
} // namespace
|
Loading…
Reference in New Issue
Block a user