[IE][VPU]: Enable new tests for adjust_data_batch pass (#1219)
* New tests for adjust_data_batch pass
This commit is contained in:
parent
00630127e7
commit
e38106239c
@ -0,0 +1,352 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vpu/stages/mx_stage.hpp>
|
||||
#include <vpu/utils/numeric.hpp>
|
||||
|
||||
#include "graph_transformer_tests.hpp"
|
||||
|
||||
using namespace vpu;
|
||||
|
||||
class VPU_AdjustDataBatchTest : public GraphTransformerTest {
|
||||
protected:
|
||||
const int batchSize = 4;
|
||||
TestModel testModel;
|
||||
|
||||
public:
|
||||
void SetUp() override {
|
||||
ASSERT_NO_FATAL_FAILURE(GraphTransformerTest::SetUp());
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(InitCompileEnv());
|
||||
|
||||
testModel = CreateTestModel();
|
||||
}
|
||||
|
||||
void RunPass() {
|
||||
PassSet pipeline;
|
||||
pipeline.addPass(passManager->dumpModel("initial"));
|
||||
pipeline.addPass(passManager->adjustDataBatch());
|
||||
pipeline.addPass(passManager->dumpModel("adjustDataBatch"));
|
||||
pipeline.run(testModel.getBaseModel());
|
||||
}
|
||||
|
||||
DataVector checkSingleLoopStart(const Data& data) {
|
||||
EXPECT_EQ(data->desc().dim(Dim::N), 4);
|
||||
EXPECT_EQ(data->numConsumers(), 2);
|
||||
|
||||
DataVector outputs;
|
||||
for (const auto& consumer : data->consumers()) {
|
||||
EXPECT_TRUE(consumer->type() == StageType::LoopStart || consumer->type() == StageType::LoopEnd);
|
||||
if (consumer->type() == StageType::LoopStart) {
|
||||
for (const auto& output : consumer->outputs()) {
|
||||
EXPECT_EQ(output->desc().dim(Dim::N), 1);
|
||||
outputs.push_back(output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
DataVector checkBranches(const Data& root, const std::vector<StageType>& consumersTypes) {
|
||||
auto successors = DataVector{};
|
||||
|
||||
const auto& consumers = root->consumers() | asVector();
|
||||
EXPECT_EQ(consumers.size(), consumersTypes.size());
|
||||
for (std::size_t i = 0; i < consumers.size(); ++i) {
|
||||
const auto& consumer = consumers[i];
|
||||
const auto& expected = consumersTypes[i];
|
||||
EXPECT_EQ(consumer->type(), expected);
|
||||
|
||||
EXPECT_EQ(consumer->numOutputs(), 1);
|
||||
const auto& output = consumer->output(0);
|
||||
successors.push_back(output);
|
||||
|
||||
if (expected == StageType::LoopStart) {
|
||||
EXPECT_EQ(consumer->numOutputs(), 1);
|
||||
EXPECT_EQ(output->desc().dim(Dim::N), 1);
|
||||
} else if (expected == StageType::LoopEnd) {
|
||||
EXPECT_EQ(output->desc().dim(Dim::N), 4);
|
||||
}
|
||||
}
|
||||
|
||||
return successors;
|
||||
}
|
||||
|
||||
DataVector checkSingleLoopEnd(const Data& data) {
|
||||
EXPECT_EQ(data->numConsumers(), 1);
|
||||
|
||||
const auto& consumer = data->singleConsumer();
|
||||
EXPECT_EQ(consumer->type(), StageType::LoopEnd);
|
||||
DataVector outputs;
|
||||
for (const auto& output : consumer->outputs()) {
|
||||
EXPECT_EQ(output->desc().dim(Dim::N), 4);
|
||||
outputs.push_back(output);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
static Data CheckSingleConnection(const Data& data, int testInd, int batch = 1) {
|
||||
EXPECT_EQ(data->numConsumers(), 1);
|
||||
|
||||
const auto& consumer = data->singleConsumer();
|
||||
EXPECT_EQ(consumer->type(), StageType::None);
|
||||
EXPECT_EQ(consumer->attrs().get<int>("test_ind"), testInd);
|
||||
EXPECT_EQ(consumer->numOutputs(), 1);
|
||||
const auto& output = consumer->output(0);
|
||||
EXPECT_EQ(output->desc().dim(Dim::N), batch);
|
||||
return output;
|
||||
}
|
||||
|
||||
static Data singleElement(const DataVector& dataObjects) {
|
||||
EXPECT_EQ(dataObjects.size(), 1);
|
||||
return dataObjects.front();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheEnd) {
|
||||
//
|
||||
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batched) -> [Output]
|
||||
//
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc});
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork(0)});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
|
||||
const auto& data1 = CheckSingleConnection(data0, 0);
|
||||
const auto& data2 = CheckSingleConnection(data1, 1);
|
||||
const auto& data3 = CheckSingleConnection(data2, 2);
|
||||
const auto& data4 = CheckSingleConnection(data3, 3);
|
||||
const auto& data5 = CheckSingleConnection(data4, 4);
|
||||
const auto& data6 = CheckSingleConnection(data5, 5);
|
||||
const auto& data7 = singleElement(checkSingleLoopEnd(data6));
|
||||
|
||||
const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
|
||||
|
||||
ASSERT_EQ(data8, testModel.getOutputs().at(0));
|
||||
}
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchSplitItems) {
|
||||
// -> (Batched) -> [Output]
|
||||
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split)
|
||||
// -> (Batched) -> [Output]
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc, desc});
|
||||
|
||||
for (int i = 0; i < 7; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
|
||||
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
|
||||
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
|
||||
const auto& data1 = CheckSingleConnection(data0, 0);
|
||||
const auto& data2 = CheckSingleConnection(data1, 1);
|
||||
const auto& data3 = CheckSingleConnection(data2, 2);
|
||||
const auto& data4 = CheckSingleConnection(data3, 3);
|
||||
const auto& data5 = CheckSingleConnection(data4, 4);
|
||||
const auto& data6 = CheckSingleConnection(data5, 5);
|
||||
const auto& data7 = CheckSingleConnection(data6, 6);
|
||||
const auto& data8 = singleElement(checkSingleLoopEnd(data7));
|
||||
|
||||
const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
|
||||
const auto& withBatch = branches[0];
|
||||
const auto& withBatch_1 = branches[1];
|
||||
|
||||
ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
|
||||
ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
|
||||
ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
|
||||
|
||||
ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
|
||||
ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
|
||||
ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
|
||||
}
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheBeginning) {
|
||||
//
|
||||
// [Input] -> (Batched) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> [Output]
|
||||
//
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc});
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
if (i > 0)
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
|
||||
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork()});
|
||||
testModel.setStageBatchInfo(6, {{0, BatchSupport::Split}});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data0 = CheckSingleConnection(testModel.getInputs().at(0), 0, batchSize);
|
||||
const auto& data7 = singleElement(checkSingleLoopStart(data0));
|
||||
const auto& data3 = CheckSingleConnection(data7, 1);
|
||||
const auto& data4 = CheckSingleConnection(data3, 2);
|
||||
const auto& data5 = CheckSingleConnection(data4, 3);
|
||||
const auto& data6 = CheckSingleConnection(data5, 4);
|
||||
const auto& data8 = CheckSingleConnection(data6, 5);
|
||||
const auto& data10 = CheckSingleConnection(data8, 6);
|
||||
const auto& data11 = checkSingleLoopEnd(data10);
|
||||
|
||||
ASSERT_EQ(data11, testModel.getOutputs());
|
||||
}
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchItemsInTheEnd) {
|
||||
// -> (Batched) -> [Output]
|
||||
// [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batch)
|
||||
// -> (Batched) -> [Output]
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc, desc});
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
|
||||
testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::intermediate(desc)});
|
||||
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)});
|
||||
testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
|
||||
const auto& data1 = CheckSingleConnection(data0, 0);
|
||||
const auto& data2 = CheckSingleConnection(data1, 1);
|
||||
const auto& data3 = CheckSingleConnection(data2, 2);
|
||||
const auto& data4 = CheckSingleConnection(data3, 3);
|
||||
const auto& data5 = CheckSingleConnection(data4, 4);
|
||||
const auto& data6 = CheckSingleConnection(data5, 5);
|
||||
const auto& data7 = singleElement(checkSingleLoopEnd(data6));
|
||||
|
||||
const auto& data8 = CheckSingleConnection(data7, 6, batchSize);
|
||||
|
||||
const auto& branches = checkBranches(data8, {StageType::None, StageType::None});
|
||||
const auto& withBatch = branches[0];
|
||||
const auto& withBatch_1 = branches[1];
|
||||
|
||||
ASSERT_EQ(withBatch->producer()->attrs().get<int>("test_ind"), 7);
|
||||
ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize);
|
||||
ASSERT_EQ(withBatch, testModel.getOutputs().at(0));
|
||||
|
||||
ASSERT_EQ(withBatch_1->producer()->attrs().get<int>("test_ind"), 8);
|
||||
ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize);
|
||||
ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1));
|
||||
}
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithSplitAndBatchItemsInTheEnd) {
|
||||
//
|
||||
// -> (Split) -> (Batched) -> [Output]
|
||||
// [Input] -> (Split) -> (Split) -> (Split)
|
||||
// -> (Split) -> [Output]
|
||||
//
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc, desc});
|
||||
|
||||
for (int i = 0; i < 5; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
if (i != 3)
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
|
||||
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
|
||||
testModel.setStageBatchInfo(5, {{0, BatchSupport::Split}});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
|
||||
const auto& data1 = CheckSingleConnection(data0, 0);
|
||||
const auto& data2 = CheckSingleConnection(data1, 1);
|
||||
const auto& data3 = CheckSingleConnection(data2, 2);
|
||||
|
||||
const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
|
||||
const auto& branch1 = branches[0];
|
||||
const auto& branch2 = branches[1];
|
||||
|
||||
const auto& data4 = CheckSingleConnection(branch1, 3);
|
||||
const auto& data7 = singleElement(checkSingleLoopEnd(data4));
|
||||
const auto& data5 = CheckSingleConnection(data7, 4, batchSize);
|
||||
ASSERT_EQ(data5, testModel.getOutputs().at(0));
|
||||
const auto& data6 = CheckSingleConnection(branch2, 5);
|
||||
ASSERT_EQ(data6, testModel.getOutputs().at(1));
|
||||
}
|
||||
|
||||
TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithBatchAndSplitItemsInTheEnd) {
|
||||
//
|
||||
// -> (Split) -> [Output]
|
||||
// [Input] -> (Split) -> (Split) -> (Split)
|
||||
// -> (Split) -> [Output]
|
||||
//
|
||||
const DataDesc desc{16, 16, 3, batchSize};
|
||||
|
||||
testModel.createInputs({desc});
|
||||
testModel.createOutputs({desc, desc});
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
if (i > 0)
|
||||
testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)});
|
||||
else
|
||||
testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)});
|
||||
testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
for (int i = 0; i < 2; i++) {
|
||||
testModel.addStage({InputInfo::fromNetwork(2)}, {OutputInfo::intermediate(desc)});
|
||||
testModel.setStageBatchInfo(3 + i, {{0, BatchSupport::Split}});
|
||||
}
|
||||
|
||||
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(0)});
|
||||
testModel.setStageBatchInfo(3, {{0, BatchSupport::Split}});
|
||||
testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)});
|
||||
testModel.setStageBatchInfo(4, {{0, BatchSupport::Split}});
|
||||
|
||||
RunPass();
|
||||
|
||||
const auto& data1 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0)));
|
||||
const auto& data2 = CheckSingleConnection(data1, 1);
|
||||
const auto& data3 = CheckSingleConnection(data2, 2);
|
||||
const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd});
|
||||
const auto& branch1 = branches[0];
|
||||
const auto& branch2 = branches[1];
|
||||
const auto& data4 = CheckSingleConnection(branch1, 3);
|
||||
const auto& data5 = CheckSingleConnection(branch2, 4);
|
||||
const auto& data6 = checkSingleLoopEnd(data5);
|
||||
}
|
Loading…
Reference in New Issue
Block a user