[IE][VPU]: Enable new tests for adjust_data_batch pass (#1219)

* New tests for adjust_data_batch pass
This commit is contained in:
Daria Mityagina 2020-07-30 13:19:33 +03:00 committed by GitHub
parent 00630127e7
commit e38106239c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -0,0 +1,352 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vpu/stages/mx_stage.hpp>
#include <vpu/utils/numeric.hpp>
#include "graph_transformer_tests.hpp"
using namespace vpu;
class VPU_AdjustDataBatchTest : public GraphTransformerTest {
protected:
const int batchSize = 4;
TestModel testModel;
public:
void SetUp() override {
ASSERT_NO_FATAL_FAILURE(GraphTransformerTest::SetUp());
ASSERT_NO_FATAL_FAILURE(InitCompileEnv());
testModel = CreateTestModel();
}
void RunPass() {
PassSet pipeline;
pipeline.addPass(passManager->dumpModel("initial"));
pipeline.addPass(passManager->adjustDataBatch());
pipeline.addPass(passManager->dumpModel("adjustDataBatch"));
pipeline.run(testModel.getBaseModel());
}
DataVector checkSingleLoopStart(const Data& data) {
EXPECT_EQ(data->desc().dim(Dim::N), 4);
EXPECT_EQ(data->numConsumers(), 2);
DataVector outputs;
for (const auto& consumer : data->consumers()) {
EXPECT_TRUE(consumer->type() == StageType::LoopStart || consumer->type() == StageType::LoopEnd);
if (consumer->type() == StageType::LoopStart) {
for (const auto& output : consumer->outputs()) {
EXPECT_EQ(output->desc().dim(Dim::N), 1);
outputs.push_back(output);
}
}
}
return outputs;
}
DataVector checkBranches(const Data& root, const std::vector<StageType>& consumersTypes) {
auto successors = DataVector{};
const auto& consumers = root->consumers() | asVector();
EXPECT_EQ(consumers.size(), consumersTypes.size());
for (std::size_t i = 0; i < consumers.size(); ++i) {
const auto& consumer = consumers[i];
const auto& expected = consumersTypes[i];
EXPECT_EQ(consumer->type(), expected);
EXPECT_EQ(consumer->numOutputs(), 1);
const auto& output = consumer->output(0);
successors.push_back(output);
if (expected == StageType::LoopStart) {
EXPECT_EQ(consumer->numOutputs(), 1);
EXPECT_EQ(output->desc().dim(Dim::N), 1);
} else if (expected == StageType::LoopEnd) {
EXPECT_EQ(output->desc().dim(Dim::N), 4);
}
}
return successors;
}
DataVector checkSingleLoopEnd(const Data& data) {
EXPECT_EQ(data->numConsumers(), 1);
const auto& consumer = data->singleConsumer();
EXPECT_EQ(consumer->type(), StageType::LoopEnd);
DataVector outputs;
for (const auto& output : consumer->outputs()) {
EXPECT_EQ(output->desc().dim(Dim::N), 4);
outputs.push_back(output);
}
return outputs;
}
static Data CheckSingleConnection(const Data& data, int testInd, int batch = 1) {
EXPECT_EQ(data->numConsumers(), 1);
const auto& consumer = data->singleConsumer();
EXPECT_EQ(consumer->type(), StageType::None);
EXPECT_EQ(consumer->attrs().get<int>("test_ind"), testInd);
EXPECT_EQ(consumer->numOutputs(), 1);
const auto& output = consumer->output(0);
EXPECT_EQ(output->desc().dim(Dim::N), batch);
return output;
}
static Data singleElement(const DataVector& dataObjects) {
EXPECT_EQ(dataObjects.size(), 1);
return dataObjects.front();
}
};
TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheEnd) {
//
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batched) -> [Output]
//
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc});
for (int i = 0; i < 6; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork(0)});
RunPass();
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
const auto& data1 = CheckSingleConnection(data0, 0);
const auto& data2 = CheckSingleConnection(data1, 1);
const auto& data3 = CheckSingleConnection(data2, 2);
const auto& data4 = CheckSingleConnection(data3, 3);
const auto& data5 = CheckSingleConnection(data4, 4);
const auto& data6 = CheckSingleConnection(data5, 5);
const auto& data7 = singleElement(checkSingleLoopEnd(data6));
const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
ASSERT_EQ(data8, testModel.getOutputs().at(0));
}
TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchSplitItems) {
// -> (Batched) -> [Output]
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split)
// -> (Batched) -> [Output]
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc, desc});
for (int i = 0; i < 7; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
RunPass();
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
const auto& data1 = CheckSingleConnection(data0, 0);
const auto& data2 = CheckSingleConnection(data1, 1);
const auto& data3 = CheckSingleConnection(data2, 2);
const auto& data4 = CheckSingleConnection(data3, 3);
const auto& data5 = CheckSingleConnection(data4, 4);
const auto& data6 = CheckSingleConnection(data5, 5);
const auto& data7 = CheckSingleConnection(data6, 6);
const auto& data8 = singleElement(checkSingleLoopEnd(data7));
const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
const auto& withBatch = branches[0];
const auto& withBatch_1 = branches[1];
ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
}
TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheBeginning) {
//
// [Input] -> (Batched) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> [Output]
//
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc});
for (int i = 0; i < 6; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
if (i > 0)
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork()});
testModel.setStageBatchInfo(6, {{0, BatchSupport::Split}});
RunPass();
const auto& data0 = CheckSingleConnection(testModel.getInputs().at(0), 0, batchSize);
const auto& data7 = singleElement(checkSingleLoopStart(data0));
const auto& data3 = CheckSingleConnection(data7, 1);
const auto& data4 = CheckSingleConnection(data3, 2);
const auto& data5 = CheckSingleConnection(data4, 3);
const auto& data6 = CheckSingleConnection(data5, 4);
const auto& data8 = CheckSingleConnection(data6, 5);
const auto& data10 = CheckSingleConnection(data8, 6);
const auto& data11 = checkSingleLoopEnd(data10);
ASSERT_EQ(data11, testModel.getOutputs());
}
TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchItemsInTheEnd) {
// -> (Batched) -> [Output]
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batch)
// -> (Batched) -> [Output]
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc, desc});
for (int i = 0; i < 6; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::intermediate(desc)});
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
RunPass();
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
const auto& data1 = CheckSingleConnection(data0, 0);
const auto& data2 = CheckSingleConnection(data1, 1);
const auto& data3 = CheckSingleConnection(data2, 2);
const auto& data4 = CheckSingleConnection(data3, 3);
const auto& data5 = CheckSingleConnection(data4, 4);
const auto& data6 = CheckSingleConnection(data5, 5);
const auto& data7 = singleElement(checkSingleLoopEnd(data6));
const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
const auto& withBatch = branches[0];
const auto& withBatch_1 = branches[1];
ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
}
TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithSplitAndBatchItemsInTheEnd) {
//
// -> (Split) -> (Batched) -> [Output]
// [Input] -> (Split) -> (Split) -> (Split)
// -> (Split) -> [Output]
//
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc, desc});
for (int i = 0; i < 5; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
if (i != 3)
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
testModel.setStageBatchInfo(5, {{0, BatchSupport::Split}});
RunPass();
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
const auto& data1 = CheckSingleConnection(data0, 0);
const auto& data2 = CheckSingleConnection(data1, 1);
const auto& data3 = CheckSingleConnection(data2, 2);
const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
const auto& branch1 = branches[0];
const auto& branch2 = branches[1];
const auto& data4 = CheckSingleConnection(branch1, 3);
const auto& data7 = singleElement(checkSingleLoopEnd(data4));
const auto& data5 = CheckSingleConnection(data7, 4, batchSize);
ASSERT_EQ(data5, testModel.getOutputs().at(0));
const auto& data6 = CheckSingleConnection(branch2, 5);
ASSERT_EQ(data6, testModel.getOutputs().at(1));
}
TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithBatchAndSplitItemsInTheEnd) {
//
// -> (Split) -> [Output]
// [Input] -> (Split) -> (Split) -> (Split)
// -> (Split) -> [Output]
//
const DataDesc desc{16, 16, 3, batchSize};
testModel.createInputs({desc});
testModel.createOutputs({desc, desc});
for (int i = 0; i < 3; i++) {
if (i > 0)
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
else
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
}
for (int i = 0; i < 2; i++) {
testModel.addStage({InputInfo::fromNetwork(2)}, {OutputInfo::intermediate(desc)});
testModel.setStageBatchInfo(3 + i, {{0, BatchSupport::Split}});
}
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(0)});
testModel.setStageBatchInfo(3, {{0, BatchSupport::Split}});
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
testModel.setStageBatchInfo(4, {{0, BatchSupport::Split}});
RunPass();
const auto& data1 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
const auto& data2 = CheckSingleConnection(data1, 1);
const auto& data3 = CheckSingleConnection(data2, 2);
const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
const auto& branch1 = branches[0];
const auto& branch2 = branches[1];
const auto& data4 = CheckSingleConnection(branch1, 3);
const auto& data5 = CheckSingleConnection(branch2, 4);
const auto& data6 = checkSingleLoopEnd(data5);
}