From a87e8f7880c9aea054f5bb4f7c2371de92242731 Mon Sep 17 00:00:00 2001 From: Smirnov Grigorii Date: Thu, 31 Mar 2022 14:07:41 +0300 Subject: [PATCH] moved TransformationsTestsF method's definitions from .hpp to .cpp (#11359) * moved * fix style --- .../common_test_utils/ngraph_test_utils.cpp | 46 +++++++++++++++++ .../common_test_utils/ngraph_test_utils.hpp | 49 +++---------------- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index b2b2fea8f08..b25dfbe771e 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -4,6 +4,52 @@ #include "ngraph_test_utils.hpp" +TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) { + m_unh = std::make_shared(); + comparator.enable(FunctionsComparator::CmpValues::NODES); + comparator.enable(FunctionsComparator::CmpValues::PRECISIONS); + comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); + // TODO: enable attributes and constant values comparison by default XXX-68694 + // comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); + // comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + +void TransformationTestsF::SetUp() { + manager.register_pass(m_unh); + manager.register_pass(); +} + +void TransformationTestsF::TearDown() { + auto cloned_function = ngraph::clone_function(*function); + if (!function_ref) { + function_ref = cloned_function; + } + + manager.register_pass(m_unh, m_soft_names_comparison); + manager.run_passes(function); + if (!m_disable_rt_info_check) { + ASSERT_NO_THROW(check_rt_info(function)); + } + + if (comparator.should_compare(FunctionsComparator::ACCURACY)) { + auto acc_comparator = FunctionsComparator::no_default(); + acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + auto res = acc_comparator.compare(function, cloned_function); + ASSERT_TRUE(res.valid) << res.message; + comparator.disable(FunctionsComparator::CmpValues::ACCURACY); + } + auto res = comparator.compare(function, function_ref); + ASSERT_TRUE(res.valid) << res.message; +} + +void TransformationTestsF::disable_rt_info_check() { + m_disable_rt_info_check = true; +} + +void TransformationTestsF::enable_soft_names_comparison() { + m_soft_names_comparison = true; +} + void init_unique_names(std::shared_ptr f, const std::shared_ptr& unh) { ngraph::pass::Manager manager; manager.register_pass(unh); diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp index c79898e1930..126c256a846 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp @@ -30,54 +30,17 @@ using TransformationTests = CommonTestUtils::TestsCommon; class TransformationTestsF : public CommonTestUtils::TestsCommon { public: - TransformationTestsF() : comparator(FunctionsComparator::no_default()) { - m_unh = std::make_shared(); - comparator.enable(FunctionsComparator::CmpValues::NODES); - comparator.enable(FunctionsComparator::CmpValues::PRECISIONS); - comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); - // TODO: enable attributes and constant values comparison by default XXX-68694 - // comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); - // comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - } + TransformationTestsF(); - void SetUp() override { - manager.register_pass(m_unh); - manager.register_pass(); - } + void SetUp() override; - void TearDown() override { - auto cloned_function = ngraph::clone_function(*function); - if (!function_ref) { - function_ref = cloned_function; - } - - manager.register_pass(m_unh, m_soft_names_comparison); - manager.run_passes(function); - if (!m_disable_rt_info_check) { - ASSERT_NO_THROW(check_rt_info(function)); - } - - if (comparator.should_compare(FunctionsComparator::ACCURACY)) { - auto acc_comparator = FunctionsComparator::no_default(); - acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY); - auto res = acc_comparator.compare(function, cloned_function); - ASSERT_TRUE(res.valid) << res.message; - comparator.disable(FunctionsComparator::CmpValues::ACCURACY); - } - - auto res = comparator.compare(function, function_ref); - ASSERT_TRUE(res.valid) << res.message; - } + void TearDown() override; // TODO: this is temporary solution to disable rt info checks that must be applied by default // first tests must be fixed then this method must be removed XXX-68696 - void disable_rt_info_check() { - m_disable_rt_info_check = true; - } + void disable_rt_info_check(); - void enable_soft_names_comparison() { - m_soft_names_comparison = true; - } + void enable_soft_names_comparison(); std::shared_ptr function, function_ref; ngraph::pass::Manager manager; @@ -103,4 +66,4 @@ size_t count_ops_of_type(const std::shared_ptr& f) { } return count; -} \ No newline at end of file +}