Fixed tests compilation for Android ARM (#3572)
* Fixed tests compilation for Android ARM * Added check for size_t
This commit is contained in:
parent
3fb5f63573
commit
5da7e8dab8
@ -33,8 +33,13 @@ std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShap
|
||||
auto updates = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, updates_shape);
|
||||
auto axis_const = ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {axis});
|
||||
|
||||
uint64_t broadcast_len = broadcast_shape.rank().get_length();
|
||||
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{broadcast_len});
|
||||
auto broadcast_len = broadcast_shape.rank().get_length();
|
||||
if (std::numeric_limits<size_t>::max() < broadcast_len) {
|
||||
throw ngraph::ngraph_error("broadcast_len cannot be represented in size_t");
|
||||
}
|
||||
|
||||
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64,
|
||||
ngraph::Shape{static_cast<size_t>(broadcast_len)});
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(indexes, broadcast_shape_param);
|
||||
|
||||
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis_const);
|
||||
|
@ -27,8 +27,8 @@ public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<basicLstmParams> obj);
|
||||
|
||||
void Run() override;
|
||||
static std::shared_ptr<ngraph::Function> GetNetwork(uint64_t thirdDimOut,
|
||||
uint64_t hiddenSize,
|
||||
static std::shared_ptr<ngraph::Function> GetNetwork(size_t thirdDimOut,
|
||||
size_t hiddenSize,
|
||||
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
|
||||
std::vector<float>* hidden_memory_init_out = nullptr,
|
||||
std::vector<float>* cell_memory_init_out = nullptr);
|
||||
|
@ -53,8 +53,8 @@ void Basic_LSTM_S::SetUp() {
|
||||
function = GetNetwork(49, hidden_size, netPrecision, &hidden_memory_init, &cell_memory_init);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(uint64_t thirdDimOut,
|
||||
uint64_t hiddenSize,
|
||||
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
|
||||
size_t hiddenSize,
|
||||
const InferenceEngine::Precision& netPrecission,
|
||||
std::vector<float>* hidden_memory_init_out,
|
||||
std::vector<float>* cell_memory_init_out) {
|
||||
|
Loading…
Reference in New Issue
Block a user