From 580a0f30ebff39ab2bf6b6fdfdbfc6acf768c609 Mon Sep 17 00:00:00 2001 From: Yuan Hu Date: Mon, 17 Jan 2022 19:50:01 +0800 Subject: [PATCH] [AUTOPLUGIN] Add format string checker for log information (#9592) * add limit format on snprintf Signed-off-by: Hu, Yuan2 * add limit on format Signed-off-by: Hu, Yuan2 * add test case Signed-off-by: Hu, Yuan2 * fix a bug for LOG_TRACE Signed-off-by: Hu, Yuan2 * remove debug info --- src/plugins/auto/utils/log.cpp | 1 + src/plugins/auto/utils/log.hpp | 46 ++++- src/plugins/auto/utils/log_util.hpp | 2 +- src/tests/unit/auto/log_utils_format_test.cpp | 193 ++++++++++++++++++ 4 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 src/tests/unit/auto/log_utils_format_test.cpp diff --git a/src/plugins/auto/utils/log.cpp b/src/plugins/auto/utils/log.cpp index 8b2a094878b..8ed857cbd8e 100644 --- a/src/plugins/auto/utils/log.cpp +++ b/src/plugins/auto/utils/log.cpp @@ -8,4 +8,5 @@ namespace MultiDevicePlugin { uint32_t Log::defaultLogLevel = static_cast(LogLevel::LOG_NONE); +std::vector Log::validFormat = {"u", "d", "s", "ld", "lu"}; } // namespace MultiDevicePlugin diff --git a/src/plugins/auto/utils/log.hpp b/src/plugins/auto/utils/log.hpp index 76d388950f0..7297d2ed240 100644 --- a/src/plugins/auto/utils/log.hpp +++ b/src/plugins/auto/utils/log.hpp @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include #include "singleton.hpp" #include "time_utils.hpp" @@ -65,7 +68,6 @@ inline int getDebugLevel() { return parseInteger(std::getenv("OPENVINO_LOG_LEVEL")); } const int debug_level = getDebugLevel(); - enum class LogLevel : uint32_t { FREQUENT = 0x01, PROCESS = 0x02, @@ -103,6 +105,7 @@ private: friend Singleton; static std::string colorBegin(LogLevel logLevel); static std::string colorEnd(LogLevel logLevel); + void checkFormat(const char* fmt); MOCKTESTMACRO void print(std::stringstream& stream); private: @@ -113,6 +116,7 @@ private: std::string suffix; uint32_t logLevel; static uint32_t defaultLogLevel; + static std::vector validFormat; }; inline Log::Log() @@ -173,6 +177,45 @@ inline void Log::setLogLevel(LogLevel logLevel_) { inline void Log::print(std::stringstream& stream) { std::cout << stream.str() << std::endl; } + +inline void Log::checkFormat(const char* fmt) { + const char* charIter = fmt; + std::string fmtStr = ""; + bool bCollectFmtStr = false; + while (*charIter != '\0') { + if (bCollectFmtStr) { + fmtStr += *charIter; + switch (fmtStr.size()) { + case 1: + case 2: + { + auto iter = std::find(validFormat.begin(), validFormat.end(), fmtStr); + if (iter != validFormat.end()) { + bCollectFmtStr = false; + fmtStr = ""; + } + break; + } + default: + { + throw std::runtime_error("format %" + fmtStr + " is not valid in log"); + break; + } + } + charIter++; + continue; + } + + if (*charIter == '%') { + bCollectFmtStr = true; + } + charIter++; + } + if (bCollectFmtStr) { + throw std::runtime_error("format %" + fmtStr + " is not valid in log"); + } +} + template inline void Log::doLog(bool on, bool isTraceCallStack, LogLevel level, const char* levelStr, const char* file, const char* func, const long line, const char* tag, const char* fmt, Args... args) { @@ -198,6 +241,7 @@ inline void Log::doLog(bool on, bool isTraceCallStack, LogLevel level, const cha stream << '[' << tag << ']'; } char buffer[255]; + checkFormat(fmt); std::string compatibleString = "%s" + std::string(fmt); std::snprintf(&buffer[0], sizeof(buffer), compatibleString.c_str(), "", args...); stream << ' ' << buffer << suffix << colorEnd(level); diff --git a/src/plugins/auto/utils/log_util.hpp b/src/plugins/auto/utils/log_util.hpp index 42a8b717213..da7739e5134 100644 --- a/src/plugins/auto/utils/log_util.hpp +++ b/src/plugins/auto/utils/log_util.hpp @@ -27,7 +27,7 @@ // #define HFrequent(isOn, tag, ...) HLogPrint(isOn, MultiDevicePlugin::LogLevel::FREQUENT, "FREQ", tag, __VA_ARGS__) // #define HFatal(...) HLogPrint(true, false, MultiDevicePlugin::LogLevel::FATAL, "FATAL", nullptr, __VA_ARGS__) -#define LOG_TRACE(isOn, tag, ...) HLogPrint(isOn, false, MultiDevicePlugin::LogLevel::PROCESS, "PROC", tag, __VA_ARGS__) +#define LOG_TRACE(isOn, tag, ...) HLogPrint(isOn, false, MultiDevicePlugin::LogLevel::PROCESS, "TRACE", tag, __VA_ARGS__) #define LOG_DEBUG(...) HLogPrint(true, false, MultiDevicePlugin::LogLevel::DEBUG, "DEBUG", nullptr, __VA_ARGS__) #define LOG_INFO(...) HLogPrint(true, false, MultiDevicePlugin::LogLevel::INFO, "INFO", nullptr, __VA_ARGS__) #define LOG_WARNING(...) HLogPrint(true, false, MultiDevicePlugin::LogLevel::WARN, "WARN", nullptr, __VA_ARGS__) diff --git a/src/tests/unit/auto/log_utils_format_test.cpp b/src/tests/unit/auto/log_utils_format_test.cpp new file mode 100644 index 00000000000..e5fb7f1d612 --- /dev/null +++ b/src/tests/unit/auto/log_utils_format_test.cpp @@ -0,0 +1,193 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + + +#include +#include +#include "utils/log_util.hpp" +#include +using namespace MockMultiDevice; +using ::testing::_; +class LogUtilsFormatTest : public ::testing::Test { +public: + void SetUp() override { + setLogLevel("LOG_DEBUG"); + } + + void TearDown() override { + MockLog::Release(); + } +}; + +TEST_F(LogUtilsFormatTest, format_s) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%s", "DEBUG")); +} +TEST_F(LogUtilsFormatTest, format_d) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%d", -1)); +} + +TEST_F(LogUtilsFormatTest, format_ld) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%ld", -3)); +} + +TEST_F(LogUtilsFormatTest, format_u) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%u", 1)); +} + +TEST_F(LogUtilsFormatTest, format_lu) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%lu", 3)); +} + +TEST_F(LogUtilsFormatTest, format_s_d_ld_u_lu) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%s,%d,%ld,%u,%lu", "DEBUG", -1, -3, 1, 3)); +} + +TEST_F(LogUtilsFormatTest, format_s_d_ld_u_lu2) { + EXPECT_CALL(*(HLogger), print(_)).Times(1); + ASSERT_NO_THROW(LOG_DEBUG("%s%d%ld%u%lu", "DEBUG", -1, -3, 1, 3)); +} + +TEST_F(LogUtilsFormatTest, format_p) { + ASSERT_THROW(LOG_DEBUG("%p", MockLog::_mockLog), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_x) { + ASSERT_THROW(LOG_DEBUG("%x", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_X) { + ASSERT_THROW(LOG_DEBUG("%X", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_o) { + ASSERT_THROW(LOG_DEBUG("%o", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_e) { + ASSERT_THROW(LOG_DEBUG("%e", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_E) { + ASSERT_THROW(LOG_DEBUG("%E", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_f) { + ASSERT_THROW(LOG_DEBUG("%f", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_F) { + ASSERT_THROW(LOG_DEBUG("%F", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_g) { + ASSERT_THROW(LOG_DEBUG("%g", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_G) { + ASSERT_THROW(LOG_DEBUG("%G", 3), std::exception); +} + + +TEST_F(LogUtilsFormatTest, format_a) { + ASSERT_THROW(LOG_DEBUG("%a", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_A) { + ASSERT_THROW(LOG_DEBUG("%A", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_c) { + ASSERT_THROW(LOG_DEBUG("%c", 3), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_n) { + int num = 0; + ASSERT_THROW(LOG_DEBUG("%n", &num), std::exception); +} + +TEST_F(LogUtilsFormatTest, format__) { + ASSERT_THROW(LOG_DEBUG("%%"), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_s__) { + ASSERT_THROW(LOG_DEBUG("%s%%", "DEBUG"), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_dn) { + int num = 0; + ASSERT_THROW(LOG_DEBUG("%d%n", num, &num), std::exception); +} + +TEST_F(LogUtilsFormatTest, format_ccccdn) { + int num = 0; + ASSERT_THROW(LOG_DEBUG("cccc%d%n", num, &num), std::exception); +} + +TEST_F(LogUtilsFormatTest, logPrintFormat_error) { + std::string printResult = ""; + std::string pattern{"\\[[0-9]+:[0-9]+:[0-9]+\\.[0-9]+\\]ERROR\\[.+:[0-9]+\\].*"}; + std::regex regex(pattern); + ON_CALL(*(HLogger), print(_)).WillByDefault([&](std::stringstream& stream) { + printResult = stream.str(); + }); + EXPECT_CALL(*(HLogger), print(_)).Times(1); + LOG_ERROR("test"); + EXPECT_TRUE(std::regex_search(printResult, regex)); +} + +TEST_F(LogUtilsFormatTest, logPrintFormat_warning) { + std::string printResult = ""; + std::string pattern{"\\[[0-9]+:[0-9]+:[0-9]+\\.[0-9]+\\]W\\[.+:[0-9]+\\].*"}; + std::regex regex(pattern); + ON_CALL(*(HLogger), print(_)).WillByDefault([&](std::stringstream& stream) { + printResult = stream.str(); + }); + EXPECT_CALL(*(HLogger), print(_)).Times(1); + LOG_WARNING("test"); + EXPECT_TRUE(std::regex_search(printResult, regex)); +} + +TEST_F(LogUtilsFormatTest, logPrintFormat_info) { + std::string printResult = ""; + std::string pattern{"\\[[0-9]+:[0-9]+:[0-9]+\\.[0-9]+\\]I\\[.+:[0-9]+\\].*"}; + std::regex regex(pattern); + ON_CALL(*(HLogger), print(_)).WillByDefault([&](std::stringstream& stream) { + printResult = stream.str(); + }); + EXPECT_CALL(*(HLogger), print(_)).Times(1); + LOG_INFO("test"); + EXPECT_TRUE(std::regex_search(printResult, regex)); +} + +TEST_F(LogUtilsFormatTest, logPrintFormat_debug) { + std::string printResult = ""; + std::string pattern{"\\[[0-9]+:[0-9]+:[0-9]+\\.[0-9]+\\]D\\[.+:[0-9]+\\].*"}; + std::regex regex(pattern); + ON_CALL(*(HLogger), print(_)).WillByDefault([&](std::stringstream& stream) { + printResult = stream.str(); + }); + EXPECT_CALL(*(HLogger), print(_)).Times(1); + LOG_DEBUG("test"); + EXPECT_TRUE(std::regex_search(printResult, regex)); +} + +TEST_F(LogUtilsFormatTest, logPrintFormat_trace) { + setLogLevel("LOG_TRACE"); + std::string printResult = ""; + std::string pattern{"\\[[0-9]+:[0-9]+:[0-9]+\\.[0-9]+\\]T\\[.+:[0-9]+\\].*"}; + std::regex regex(pattern); + ON_CALL(*(HLogger), print(_)).WillByDefault([&](std::stringstream& stream) { + printResult = stream.str(); + }); + EXPECT_CALL(*(HLogger), print(_)).Times(1); + LOG_TRACE(true, "test", "TRACE"); + EXPECT_TRUE(std::regex_search(printResult, regex)); +} +