[GNA] Added params to LSTMCell (#8714)
* [GNA] Added params to LSTMCell - parametrized num of Cells * Fixed review comments
This commit is contained in:
parent
d09bbb498e
commit
e06d9fdbb9
@ -27,6 +27,10 @@ const std::vector<std::pair<size_t, size_t>> size_params = {
|
||||
{300, 38},
|
||||
};
|
||||
|
||||
size_t small_num_cells = 10;
|
||||
|
||||
size_t big_num_cells = 49;
|
||||
|
||||
const std::vector<bool> decompose = { false, true };
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM, Basic_LSTM_S,
|
||||
@ -35,6 +39,17 @@ INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM, Basic_LSTM_S,
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(size_params),
|
||||
::testing::Values(small_num_cells),
|
||||
::testing::ValuesIn(decompose)),
|
||||
Basic_LSTM_S::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BasicLSTM_big_cells_num, Basic_LSTM_S,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::Values(size_params[0]),
|
||||
::testing::Values(big_num_cells),
|
||||
::testing::ValuesIn(decompose)),
|
||||
Basic_LSTM_S::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -20,6 +20,7 @@ typedef std::tuple<
|
||||
std::string, // Target Device
|
||||
std::map<std::string, std::string>, // Configuration
|
||||
std::pair<size_t, size_t>, // Third dimenstion and hidden size
|
||||
size_t, // Number of Cells
|
||||
bool // Decompose LSTMCell
|
||||
> basicLstmParams;
|
||||
|
||||
@ -31,6 +32,7 @@ public:
|
||||
void Run() override;
|
||||
static std::shared_ptr<ngraph::Function> GetNetwork(size_t thirdDimOut,
|
||||
size_t hiddenSize,
|
||||
size_t num_cells = 10,
|
||||
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
|
||||
std::vector<float>* hidden_memory_init_out = nullptr,
|
||||
std::vector<float>* cell_memory_init_out = nullptr);
|
||||
|
@ -15,8 +15,9 @@ std::string Basic_LSTM_S::getTestCaseName(const testing::TestParamInfo<basicLstm
|
||||
std::string targetDevice;
|
||||
std::map<std::string, std::string> configuration;
|
||||
std::pair<size_t, size_t> size_params;
|
||||
size_t num_cells;
|
||||
bool decompose;
|
||||
std::tie(netPrecision, targetDevice, configuration, size_params, decompose) = obj.param;
|
||||
std::tie(netPrecision, targetDevice, configuration, size_params, num_cells, decompose) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
|
||||
@ -36,13 +37,14 @@ void Basic_LSTM_S::SetUp() {
|
||||
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::pair<size_t, size_t> size_params;
|
||||
size_t num_cells;
|
||||
bool decompose;
|
||||
std::tie(netPrecision, targetDevice, configuration, size_params, decompose) = this->GetParam();
|
||||
std::tie(netPrecision, targetDevice, configuration, size_params, num_cells, decompose) = this->GetParam();
|
||||
third_dim = size_params.first;
|
||||
hidden_size = size_params.second;
|
||||
outPrc = InferenceEngine::Precision::FP32;
|
||||
|
||||
function = GetNetwork(size_params.first, size_params.second, netPrecision, &hidden_memory_init, &cell_memory_init);
|
||||
function = GetNetwork(size_params.first, size_params.second, num_cells, netPrecision, &hidden_memory_init, &cell_memory_init);
|
||||
if (decompose) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
|
||||
@ -52,17 +54,18 @@ void Basic_LSTM_S::SetUp() {
|
||||
|
||||
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
|
||||
size_t hiddenSize,
|
||||
size_t num_cells,
|
||||
const InferenceEngine::Precision& netPrecission,
|
||||
std::vector<float>* hidden_memory_init_out,
|
||||
std::vector<float>* cell_memory_init_out) {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecission);
|
||||
|
||||
auto params = ngraph::builder::makeParams(ngPrc, { {1, 10 * thirdDimOut} });
|
||||
auto params = ngraph::builder::makeParams(ngPrc, { {1, num_cells * thirdDimOut} });
|
||||
|
||||
const size_t batch_size = 1;
|
||||
|
||||
//Reshape_1 [1,thirdDimOut*10] -> [1, 10, thirdDimOut]
|
||||
std::vector<uint64_t> outFormShapes1 = { batch_size, 10, thirdDimOut };
|
||||
//Reshape_1 [1,thirdDimOut*num_cells] -> [1, num_cells, thirdDimOut]
|
||||
std::vector<uint64_t> outFormShapes1 = { batch_size, num_cells, thirdDimOut };
|
||||
auto pattern1 = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 3 }, outFormShapes1);
|
||||
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params[0], pattern1, false);
|
||||
|
||||
@ -94,14 +97,14 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
|
||||
auto H_o = lstm1->output(0);
|
||||
auto C_o = lstm1->output(1);
|
||||
|
||||
//TensorIterator [1, 10, thirdDimOut] [1, 118], [1, 118] -> [1, 118]
|
||||
//TensorIterator [1, num_cells, thirdDimOut] [1, 118], [1, 118] -> [1, 118]
|
||||
auto body = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector{ H_o, C_o }, ngraph::ParameterVector{ X, H_t, C_t });
|
||||
|
||||
auto tensor_iterator = std::make_shared<ngraph::opset1::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
//input tensor shape: [1, 10, thirdDimOut] chunk shape: [1, 1, thirdDimOut]
|
||||
//input tensor shape: [1, num_cells, thirdDimOut] chunk shape: [1, 1, thirdDimOut]
|
||||
tensor_iterator->set_sliced_input(X, reshape1, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->set_merged_input(H_t, H_init, H_o);
|
||||
tensor_iterator->set_merged_input(C_t, C_init, C_o);
|
||||
|
Loading…
Reference in New Issue
Block a user