// Copyright (C) 2018-2020 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include #include using namespace ::testing; using namespace std; using namespace InferenceEngine; class LocaleTests : public ::testing::Test { std::string originalLocale; std::string _model = R"V0G0N( 2 3 5 5 2 3 5 5 2 3 5 5 2 3 5 5 2 3 5 5 2 3 5 5 )V0G0N"; std::string _model_LSTM = R"V0G0N( 1 30 1 30 1 10 1 10 1 10 1 10 1 10 1 10 1 10 1 10 1 10 1 10 1 10 )V0G0N"; protected: void SetUp() override { originalLocale = setlocale(LC_ALL, nullptr); } void TearDown() override { setlocale(LC_ALL, originalLocale.c_str()); } void testBody(bool isLSTM = false) const { InferenceEngine::Core core; // This model contains layers with float attributes. // Conversion from string may be affected by locale. std::string model = isLSTM ? _model_LSTM : _model; auto net = core.ReadNetwork(model, InferenceEngine::Blob::CPtr()); IE_SUPPRESS_DEPRECATED_START if (!isLSTM) { auto power_layer = dynamic_pointer_cast(net.getLayerByName("power")); ASSERT_EQ(power_layer->scale, 0.75f); ASSERT_EQ(power_layer->offset, 0.35f); ASSERT_EQ(power_layer->power, 0.5f); auto sum_layer = dynamic_pointer_cast(net.getLayerByName("sum")); std::vector ref_coeff{0.77f, 0.33f}; ASSERT_EQ(sum_layer->coeff, ref_coeff); auto info = net.getInputsInfo(); auto preproc = info.begin()->second->getPreProcess(); ASSERT_EQ(preproc[0]->stdScale, 0.1f); ASSERT_EQ(preproc[0]->meanValue, 104.006f); } else { InferenceEngine::NetPass::UnrollRNN_if(net, [] (const RNNCellBase& rnn) -> bool { return true; }); net.serialize("UnrollRNN_if.xml"); EXPECT_EQ(0, std::remove("UnrollRNN_if.xml")); auto lstmcell_layer = dynamic_pointer_cast(net.getLayerByName("LSTMCell:split_clip")); float ref_coeff = 0.2f; ASSERT_EQ(lstmcell_layer->min_value, -ref_coeff); ASSERT_EQ(lstmcell_layer->max_value, ref_coeff); ASSERT_EQ(lstmcell_layer->GetParamAsFloat("min"), -ref_coeff); ASSERT_EQ(lstmcell_layer->GetParamAsFloat("max"), ref_coeff); } IE_SUPPRESS_DEPRECATED_END } }; TEST_F(LocaleTests, WithRULocale) { setlocale(LC_ALL, "ru_RU.UTF-8"); testBody(); } TEST_F(LocaleTests, WithUSLocale) { setlocale(LC_ALL, "en_US.UTF-8"); testBody(); } TEST_F(LocaleTests, WithRULocaleOnLSTM) { setlocale(LC_ALL, "ru_RU.UTF-8"); testBody(true); } TEST_F(LocaleTests, WithUSLocaleOnLSTM) { setlocale(LC_ALL, "en_US.UTF-8"); testBody(true); } TEST_F(LocaleTests, DISABLED_WithRULocaleCPP) { auto prev = std::locale(); std::locale::global(std::locale("ru_RU.UTF-8")); testBody(); std::locale::global(prev); } TEST_F(LocaleTests, DISABLED_WithUSLocaleCPP) { auto prev = std::locale(); std::locale::global(std::locale("en_US.UTF-8")); testBody(); std::locale::global(prev); }