moved TransformationsTestsF method's definitions from .hpp to .cpp (#11359)

* moved

* fix style
This commit is contained in:
Smirnov Grigorii 2022-03-31 14:07:41 +03:00 committed by GitHub
parent 16a5962698
commit a87e8f7880
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 43 deletions

View File

@ -4,6 +4,52 @@
#include "ngraph_test_utils.hpp"
TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
comparator.enable(FunctionsComparator::CmpValues::NODES);
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
// TODO: enable attributes and constant values comparison by default XXX-68694
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
void TransformationTestsF::SetUp() {
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
manager.register_pass<ngraph::pass::InitNodeInfo>();
}
void TransformationTestsF::TearDown() {
auto cloned_function = ngraph::clone_function(*function);
if (!function_ref) {
function_ref = cloned_function;
}
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
manager.run_passes(function);
if (!m_disable_rt_info_check) {
ASSERT_NO_THROW(check_rt_info(function));
}
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
auto acc_comparator = FunctionsComparator::no_default();
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
auto res = acc_comparator.compare(function, cloned_function);
ASSERT_TRUE(res.valid) << res.message;
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
}
auto res = comparator.compare(function, function_ref);
ASSERT_TRUE(res.valid) << res.message;
}
void TransformationTestsF::disable_rt_info_check() {
m_disable_rt_info_check = true;
}
void TransformationTestsF::enable_soft_names_comparison() {
m_soft_names_comparison = true;
}
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);

View File

@ -30,54 +30,17 @@ using TransformationTests = CommonTestUtils::TestsCommon;
class TransformationTestsF : public CommonTestUtils::TestsCommon {
public:
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
comparator.enable(FunctionsComparator::CmpValues::NODES);
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
// TODO: enable attributes and constant values comparison by default XXX-68694
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
TransformationTestsF();
void SetUp() override {
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
manager.register_pass<ngraph::pass::InitNodeInfo>();
}
void SetUp() override;
void TearDown() override {
auto cloned_function = ngraph::clone_function(*function);
if (!function_ref) {
function_ref = cloned_function;
}
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
manager.run_passes(function);
if (!m_disable_rt_info_check) {
ASSERT_NO_THROW(check_rt_info(function));
}
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
auto acc_comparator = FunctionsComparator::no_default();
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
auto res = acc_comparator.compare(function, cloned_function);
ASSERT_TRUE(res.valid) << res.message;
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
}
auto res = comparator.compare(function, function_ref);
ASSERT_TRUE(res.valid) << res.message;
}
void TearDown() override;
// TODO: this is temporary solution to disable rt info checks that must be applied by default
// first tests must be fixed then this method must be removed XXX-68696
void disable_rt_info_check() {
m_disable_rt_info_check = true;
}
void disable_rt_info_check();
void enable_soft_names_comparison() {
m_soft_names_comparison = true;
}
void enable_soft_names_comparison();
std::shared_ptr<ov::Model> function, function_ref;
ngraph::pass::Manager manager;
@ -103,4 +66,4 @@ size_t count_ops_of_type(const std::shared_ptr<ngraph::Function>& f) {
}
return count;
}
}