[GNA] Added params to LSTMCell (#8714)

* [GNA] Added params to LSTMCell
   - parametrized num of Cells

* Fixed review comments
This commit is contained in:
Andrey Noskov 2021-12-14 11:03:26 +03:00 committed by GitHub
parent d09bbb498e
commit e06d9fdbb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 8 deletions

View File

@ -27,6 +27,10 @@ const std::vector<std::pair<size_t, size_t>> size_params = {
{300, 38},
};
size_t small_num_cells = 10;
size_t big_num_cells = 49;
const std::vector<bool> decompose = { false, true };
INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM, Basic_LSTM_S,
@ -35,6 +39,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM, Basic_LSTM_S,
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(size_params),
::testing::Values(small_num_cells),
::testing::ValuesIn(decompose)),
Basic_LSTM_S::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM_big_cells_num, Basic_LSTM_S,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::Values(size_params[0]),
::testing::Values(big_num_cells),
::testing::ValuesIn(decompose)),
Basic_LSTM_S::getTestCaseName);
} // namespace

View File

@ -20,6 +20,7 @@ typedef std::tuple<
std::string, // Target Device
std::map<std::string, std::string>, // Configuration
std::pair<size_t, size_t>, // Third dimenstion and hidden size
size_t, // Number of Cells
bool // Decompose LSTMCell
> basicLstmParams;
@ -31,6 +32,7 @@ public:
void Run() override;
static std::shared_ptr<ngraph::Function> GetNetwork(size_t thirdDimOut,
size_t hiddenSize,
size_t num_cells = 10,
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
std::vector<float>* hidden_memory_init_out = nullptr,
std::vector<float>* cell_memory_init_out = nullptr);

View File

@ -15,8 +15,9 @@ std::string Basic_LSTM_S::getTestCaseName(const testing::TestParamInfo<basicLstm
std::string targetDevice;
std::map<std::string, std::string> configuration;
std::pair<size_t, size_t> size_params;
size_t num_cells;
bool decompose;
std::tie(netPrecision, targetDevice, configuration, size_params, decompose) = obj.param;
std::tie(netPrecision, targetDevice, configuration, size_params, num_cells, decompose) = obj.param;
std::ostringstream result;
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
@ -36,13 +37,14 @@ void Basic_LSTM_S::SetUp() {
InferenceEngine::Precision netPrecision;
std::pair<size_t, size_t> size_params;
size_t num_cells;
bool decompose;
std::tie(netPrecision, targetDevice, configuration, size_params, decompose) = this->GetParam();
std::tie(netPrecision, targetDevice, configuration, size_params, num_cells, decompose) = this->GetParam();
third_dim = size_params.first;
hidden_size = size_params.second;
outPrc = InferenceEngine::Precision::FP32;
function = GetNetwork(size_params.first, size_params.second, netPrecision, &hidden_memory_init, &cell_memory_init);
function = GetNetwork(size_params.first, size_params.second, num_cells, netPrecision, &hidden_memory_init, &cell_memory_init);
if (decompose) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
@ -52,17 +54,18 @@ void Basic_LSTM_S::SetUp() {
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
size_t hiddenSize,
size_t num_cells,
const InferenceEngine::Precision& netPrecission,
std::vector<float>* hidden_memory_init_out,
std::vector<float>* cell_memory_init_out) {
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecission);
auto params = ngraph::builder::makeParams(ngPrc, { {1, 10 * thirdDimOut} });
auto params = ngraph::builder::makeParams(ngPrc, { {1, num_cells * thirdDimOut} });
const size_t batch_size = 1;
//Reshape_1 [1,thirdDimOut*10] -> [1, 10, thirdDimOut]
std::vector<uint64_t> outFormShapes1 = { batch_size, 10, thirdDimOut };
//Reshape_1 [1,thirdDimOut*num_cells] -> [1, num_cells, thirdDimOut]
std::vector<uint64_t> outFormShapes1 = { batch_size, num_cells, thirdDimOut };
auto pattern1 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 3 }, outFormShapes1);
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params[0], pattern1, false);
@ -94,14 +97,14 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
auto H_o = lstm1->output(0);
auto C_o = lstm1->output(1);
//TensorIterator [1, 10, thirdDimOut] [1, 118], [1, 118] -> [1, 118]
//TensorIterator [1, num_cells, thirdDimOut] [1, 118], [1, 118] -> [1, 118]
auto body = std::make_shared<ngraph::Function>(
ngraph::OutputVector{ H_o, C_o }, ngraph::ParameterVector{ X, H_t, C_t });
auto tensor_iterator = std::make_shared<ngraph::opset1::TensorIterator>();
tensor_iterator->set_body(body);
//input tensor shape: [1, 10, thirdDimOut] chunk shape: [1, 1, thirdDimOut]
//input tensor shape: [1, num_cells, thirdDimOut] chunk shape: [1, 1, thirdDimOut]
tensor_iterator->set_sliced_input(X, reshape1, 0, 1, 1, -1, 1);
tensor_iterator->set_merged_input(H_t, H_init, H_o);
tensor_iterator->set_merged_input(C_t, C_init, C_o);