Fixed tests compilation for Android ARM (#3572)

* Fixed tests compilation for Android ARM

* Added check for size_t
This commit is contained in:
Ilya Lavrenov 2020-12-15 11:51:17 +03:00 committed by GitHub
parent 3fb5f63573
commit 5da7e8dab8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 6 deletions

View File

@ -33,8 +33,13 @@ std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShap
auto updates = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, updates_shape);
auto axis_const = ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {axis});
uint64_t broadcast_len = broadcast_shape.rank().get_length();
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{broadcast_len});
auto broadcast_len = broadcast_shape.rank().get_length();
if (std::numeric_limits<size_t>::max() < broadcast_len) {
throw ngraph::ngraph_error("broadcast_len cannot be represented in size_t");
}
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64,
ngraph::Shape{static_cast<size_t>(broadcast_len)});
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(indexes, broadcast_shape_param);
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis_const);

View File

@ -27,8 +27,8 @@ public:
static std::string getTestCaseName(testing::TestParamInfo<basicLstmParams> obj);
void Run() override;
static std::shared_ptr<ngraph::Function> GetNetwork(uint64_t thirdDimOut,
uint64_t hiddenSize,
static std::shared_ptr<ngraph::Function> GetNetwork(size_t thirdDimOut,
size_t hiddenSize,
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
std::vector<float>* hidden_memory_init_out = nullptr,
std::vector<float>* cell_memory_init_out = nullptr);

View File

@ -53,8 +53,8 @@ void Basic_LSTM_S::SetUp() {
function = GetNetwork(49, hidden_size, netPrecision, &hidden_memory_init, &cell_memory_init);
}
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(uint64_t thirdDimOut,
uint64_t hiddenSize,
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
size_t hiddenSize,
const InferenceEngine::Precision& netPrecission,
std::vector<float>* hidden_memory_init_out,
std::vector<float>* cell_memory_init_out) {