Fixed tests compilation for Android ARM (#3572)
* Fixed tests compilation for Android ARM * Added check for size_t
This commit is contained in:
parent
3fb5f63573
commit
5da7e8dab8
@ -33,8 +33,13 @@ std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShap
|
|||||||
auto updates = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, updates_shape);
|
auto updates = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, updates_shape);
|
||||||
auto axis_const = ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {axis});
|
auto axis_const = ngraph::opset3::Constant::create(ngraph::element::i64, {1}, {axis});
|
||||||
|
|
||||||
uint64_t broadcast_len = broadcast_shape.rank().get_length();
|
auto broadcast_len = broadcast_shape.rank().get_length();
|
||||||
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{broadcast_len});
|
if (std::numeric_limits<size_t>::max() < broadcast_len) {
|
||||||
|
throw ngraph::ngraph_error("broadcast_len cannot be represented in size_t");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto broadcast_shape_param = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64,
|
||||||
|
ngraph::Shape{static_cast<size_t>(broadcast_len)});
|
||||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(indexes, broadcast_shape_param);
|
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(indexes, broadcast_shape_param);
|
||||||
|
|
||||||
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis_const);
|
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis_const);
|
||||||
|
@ -27,8 +27,8 @@ public:
|
|||||||
static std::string getTestCaseName(testing::TestParamInfo<basicLstmParams> obj);
|
static std::string getTestCaseName(testing::TestParamInfo<basicLstmParams> obj);
|
||||||
|
|
||||||
void Run() override;
|
void Run() override;
|
||||||
static std::shared_ptr<ngraph::Function> GetNetwork(uint64_t thirdDimOut,
|
static std::shared_ptr<ngraph::Function> GetNetwork(size_t thirdDimOut,
|
||||||
uint64_t hiddenSize,
|
size_t hiddenSize,
|
||||||
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
|
const InferenceEngine::Precision& netPrecission = InferenceEngine::Precision::FP32,
|
||||||
std::vector<float>* hidden_memory_init_out = nullptr,
|
std::vector<float>* hidden_memory_init_out = nullptr,
|
||||||
std::vector<float>* cell_memory_init_out = nullptr);
|
std::vector<float>* cell_memory_init_out = nullptr);
|
||||||
|
@ -53,8 +53,8 @@ void Basic_LSTM_S::SetUp() {
|
|||||||
function = GetNetwork(49, hidden_size, netPrecision, &hidden_memory_init, &cell_memory_init);
|
function = GetNetwork(49, hidden_size, netPrecision, &hidden_memory_init, &cell_memory_init);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(uint64_t thirdDimOut,
|
std::shared_ptr<ngraph::Function> Basic_LSTM_S::GetNetwork(size_t thirdDimOut,
|
||||||
uint64_t hiddenSize,
|
size_t hiddenSize,
|
||||||
const InferenceEngine::Precision& netPrecission,
|
const InferenceEngine::Precision& netPrecission,
|
||||||
std::vector<float>* hidden_memory_init_out,
|
std::vector<float>* hidden_memory_init_out,
|
||||||
std::vector<float>* cell_memory_init_out) {
|
std::vector<float>* cell_memory_init_out) {
|
||||||
|
Loading…
Reference in New Issue
Block a user