From e38106239cfdcce715db54379cffbec93cfc6c0a Mon Sep 17 00:00:00 2001 From: Daria Mityagina Date: Thu, 30 Jul 2020 13:19:33 +0300 Subject: [PATCH] [IE][VPU]: Enable new tests for adjust_data_batch pass (#1219) * New tests for adjust_data_batch pass --- .../passes_tests/adjust_data_batch_tests.cpp | 352 ++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp diff --git a/inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp b/inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp new file mode 100644 index 00000000000..67e62174fac --- /dev/null +++ b/inference-engine/tests/unit/vpu/middleend_tests/passes_tests/adjust_data_batch_tests.cpp @@ -0,0 +1,352 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "graph_transformer_tests.hpp" + +using namespace vpu; + +class VPU_AdjustDataBatchTest : public GraphTransformerTest { +protected: + const int batchSize = 4; + TestModel testModel; + +public: + void SetUp() override { + ASSERT_NO_FATAL_FAILURE(GraphTransformerTest::SetUp()); + + ASSERT_NO_FATAL_FAILURE(InitCompileEnv()); + + testModel = CreateTestModel(); + } + + void RunPass() { + PassSet pipeline; + pipeline.addPass(passManager->dumpModel("initial")); + pipeline.addPass(passManager->adjustDataBatch()); + pipeline.addPass(passManager->dumpModel("adjustDataBatch")); + pipeline.run(testModel.getBaseModel()); + } + + DataVector checkSingleLoopStart(const Data& data) { + EXPECT_EQ(data->desc().dim(Dim::N), 4); + EXPECT_EQ(data->numConsumers(), 2); + + DataVector outputs; + for (const auto& consumer : data->consumers()) { + EXPECT_TRUE(consumer->type() == StageType::LoopStart || consumer->type() == StageType::LoopEnd); + if (consumer->type() == StageType::LoopStart) { + for (const auto& output : consumer->outputs()) { + EXPECT_EQ(output->desc().dim(Dim::N), 1); + outputs.push_back(output); + } + } + } + + return outputs; + } + + DataVector checkBranches(const Data& root, const std::vector& consumersTypes) { + auto successors = DataVector{}; + + const auto& consumers = root->consumers() | asVector(); + EXPECT_EQ(consumers.size(), consumersTypes.size()); + for (std::size_t i = 0; i < consumers.size(); ++i) { + const auto& consumer = consumers[i]; + const auto& expected = consumersTypes[i]; + EXPECT_EQ(consumer->type(), expected); + + EXPECT_EQ(consumer->numOutputs(), 1); + const auto& output = consumer->output(0); + successors.push_back(output); + + if (expected == StageType::LoopStart) { + EXPECT_EQ(consumer->numOutputs(), 1); + EXPECT_EQ(output->desc().dim(Dim::N), 1); + } else if (expected == StageType::LoopEnd) { + EXPECT_EQ(output->desc().dim(Dim::N), 4); + } + } + + return successors; + } + + DataVector checkSingleLoopEnd(const Data& data) { + EXPECT_EQ(data->numConsumers(), 1); + + const auto& consumer = data->singleConsumer(); + EXPECT_EQ(consumer->type(), StageType::LoopEnd); + DataVector outputs; + for (const auto& output : consumer->outputs()) { + EXPECT_EQ(output->desc().dim(Dim::N), 4); + outputs.push_back(output); + } + + return outputs; + } + + static Data CheckSingleConnection(const Data& data, int testInd, int batch = 1) { + EXPECT_EQ(data->numConsumers(), 1); + + const auto& consumer = data->singleConsumer(); + EXPECT_EQ(consumer->type(), StageType::None); + EXPECT_EQ(consumer->attrs().get("test_ind"), testInd); + EXPECT_EQ(consumer->numOutputs(), 1); + const auto& output = consumer->output(0); + EXPECT_EQ(output->desc().dim(Dim::N), batch); + return output; + } + + static Data singleElement(const DataVector& dataObjects) { + EXPECT_EQ(dataObjects.size(), 1); + return dataObjects.front(); + } +}; + +TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheEnd) { + // + // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batched) -> [Output] + // + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc}); + + for (int i = 0; i < 6; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork(0)}); + + RunPass(); + + const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0))); + const auto& data1 = CheckSingleConnection(data0, 0); + const auto& data2 = CheckSingleConnection(data1, 1); + const auto& data3 = CheckSingleConnection(data2, 2); + const auto& data4 = CheckSingleConnection(data3, 3); + const auto& data5 = CheckSingleConnection(data4, 4); + const auto& data6 = CheckSingleConnection(data5, 5); + const auto& data7 = singleElement(checkSingleLoopEnd(data6)); + + const auto& data8 = CheckSingleConnection(data7, 6, batchSize); + + ASSERT_EQ(data8, testModel.getOutputs().at(0)); +} + +TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchSplitItems) { + // -> (Batched) -> [Output] + // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) + // -> (Batched) -> [Output] + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc, desc}); + + for (int i = 0; i < 7; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + + testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)}); + testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)}); + + RunPass(); + + const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0))); + const auto& data1 = CheckSingleConnection(data0, 0); + const auto& data2 = CheckSingleConnection(data1, 1); + const auto& data3 = CheckSingleConnection(data2, 2); + const auto& data4 = CheckSingleConnection(data3, 3); + const auto& data5 = CheckSingleConnection(data4, 4); + const auto& data6 = CheckSingleConnection(data5, 5); + const auto& data7 = CheckSingleConnection(data6, 6); + const auto& data8 = singleElement(checkSingleLoopEnd(data7)); + + const auto& branches = checkBranches(data8, {StageType::None, StageType::None}); + const auto& withBatch = branches[0]; + const auto& withBatch_1 = branches[1]; + + ASSERT_EQ(withBatch->producer()->attrs().get("test_ind"), 7); + ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize); + ASSERT_EQ(withBatch, testModel.getOutputs().at(0)); + + ASSERT_EQ(withBatch_1->producer()->attrs().get("test_ind"), 8); + ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize); + ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1)); +} + +TEST_F(VPU_AdjustDataBatchTest, LinearWithBatchedInTheBeginning) { + // + // [Input] -> (Batched) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> [Output] + // + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc}); + + for (int i = 0; i < 6; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + if (i > 0) + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + + testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::fromNetwork()}); + testModel.setStageBatchInfo(6, {{0, BatchSupport::Split}}); + + RunPass(); + + const auto& data0 = CheckSingleConnection(testModel.getInputs().at(0), 0, batchSize); + const auto& data7 = singleElement(checkSingleLoopStart(data0)); + const auto& data3 = CheckSingleConnection(data7, 1); + const auto& data4 = CheckSingleConnection(data3, 2); + const auto& data5 = CheckSingleConnection(data4, 3); + const auto& data6 = CheckSingleConnection(data5, 4); + const auto& data8 = CheckSingleConnection(data6, 5); + const auto& data10 = CheckSingleConnection(data8, 6); + const auto& data11 = checkSingleLoopEnd(data10); + + ASSERT_EQ(data11, testModel.getOutputs()); +} + +TEST_F(VPU_AdjustDataBatchTest, BranchedWithBatchItemsInTheEnd) { + // -> (Batched) -> [Output] + // [Input] -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Split) -> (Batch) + // -> (Batched) -> [Output] + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc, desc}); + + for (int i = 0; i < 6; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + + testModel.addStage({InputInfo::fromPrevStage(5)}, {OutputInfo::intermediate(desc)}); + testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(0)}); + testModel.addStage({InputInfo::fromPrevStage(6)}, {OutputInfo::fromNetwork(1)}); + + RunPass(); + + const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0))); + const auto& data1 = CheckSingleConnection(data0, 0); + const auto& data2 = CheckSingleConnection(data1, 1); + const auto& data3 = CheckSingleConnection(data2, 2); + const auto& data4 = CheckSingleConnection(data3, 3); + const auto& data5 = CheckSingleConnection(data4, 4); + const auto& data6 = CheckSingleConnection(data5, 5); + const auto& data7 = singleElement(checkSingleLoopEnd(data6)); + + const auto& data8 = CheckSingleConnection(data7, 6, batchSize); + + const auto& branches = checkBranches(data8, {StageType::None, StageType::None}); + const auto& withBatch = branches[0]; + const auto& withBatch_1 = branches[1]; + + ASSERT_EQ(withBatch->producer()->attrs().get("test_ind"), 7); + ASSERT_EQ(withBatch->desc().dim(Dim::N), batchSize); + ASSERT_EQ(withBatch, testModel.getOutputs().at(0)); + + ASSERT_EQ(withBatch_1->producer()->attrs().get("test_ind"), 8); + ASSERT_EQ(withBatch_1->desc().dim(Dim::N), batchSize); + ASSERT_EQ(withBatch_1, testModel.getOutputs().at(1)); +} + +TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithSplitAndBatchItemsInTheEnd) { + // + // -> (Split) -> (Batched) -> [Output] + // [Input] -> (Split) -> (Split) -> (Split) + // -> (Split) -> [Output] + // + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc, desc}); + + for (int i = 0; i < 5; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + if (i != 3) + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + + testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)}); + testModel.setStageBatchInfo(5, {{0, BatchSupport::Split}}); + + RunPass(); + + const auto& data0 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0))); + const auto& data1 = CheckSingleConnection(data0, 0); + const auto& data2 = CheckSingleConnection(data1, 1); + const auto& data3 = CheckSingleConnection(data2, 2); + + const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd}); + const auto& branch1 = branches[0]; + const auto& branch2 = branches[1]; + + const auto& data4 = CheckSingleConnection(branch1, 3); + const auto& data7 = singleElement(checkSingleLoopEnd(data4)); + const auto& data5 = CheckSingleConnection(data7, 4, batchSize); + ASSERT_EQ(data5, testModel.getOutputs().at(0)); + const auto& data6 = CheckSingleConnection(branch2, 5); + ASSERT_EQ(data6, testModel.getOutputs().at(1)); +} + +TEST_F(VPU_AdjustDataBatchTest, DISABLED_BranchedWithBatchAndSplitItemsInTheEnd) { + // + // -> (Split) -> [Output] + // [Input] -> (Split) -> (Split) -> (Split) + // -> (Split) -> [Output] + // + const DataDesc desc{16, 16, 3, batchSize}; + + testModel.createInputs({desc}); + testModel.createOutputs({desc, desc}); + + for (int i = 0; i < 3; i++) { + if (i > 0) + testModel.addStage({InputInfo::fromPrevStage(i - 1)}, {OutputInfo::intermediate(desc)}); + else + testModel.addStage({InputInfo::fromNetwork()}, {OutputInfo::intermediate(desc)}); + testModel.setStageBatchInfo(i, {{0, BatchSupport::Split}}); + } + for (int i = 0; i < 2; i++) { + testModel.addStage({InputInfo::fromNetwork(2)}, {OutputInfo::intermediate(desc)}); + testModel.setStageBatchInfo(3 + i, {{0, BatchSupport::Split}}); + } + + testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(0)}); + testModel.setStageBatchInfo(3, {{0, BatchSupport::Split}}); + testModel.addStage({InputInfo::fromPrevStage(2)}, {OutputInfo::fromNetwork(1)}); + testModel.setStageBatchInfo(4, {{0, BatchSupport::Split}}); + + RunPass(); + + const auto& data1 = singleElement(checkSingleLoopStart(testModel.getInputs().at(0))); + const auto& data2 = CheckSingleConnection(data1, 1); + const auto& data3 = CheckSingleConnection(data2, 2); + const auto& branches = checkBranches(data3, {StageType::None, StageType::LoopEnd}); + const auto& branch1 = branches[0]; + const auto& branch2 = branches[1]; + const auto& data4 = CheckSingleConnection(branch1, 3); + const auto& data5 = CheckSingleConnection(branch2, 4); + const auto& data6 = checkSingleLoopEnd(data5); +}