moved TransformationsTestsF method's definitions from .hpp to .cpp (#11359)
* moved * fix style
This commit is contained in:
parent
16a5962698
commit
a87e8f7880
@ -4,6 +4,52 @@
|
||||
|
||||
#include "ngraph_test_utils.hpp"
|
||||
|
||||
TransformationTestsF::TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
||||
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
||||
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
// TODO: enable attributes and constant values comparison by default XXX-68694
|
||||
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
|
||||
void TransformationTestsF::SetUp() {
|
||||
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
}
|
||||
|
||||
void TransformationTestsF::TearDown() {
|
||||
auto cloned_function = ngraph::clone_function(*function);
|
||||
if (!function_ref) {
|
||||
function_ref = cloned_function;
|
||||
}
|
||||
|
||||
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
|
||||
manager.run_passes(function);
|
||||
if (!m_disable_rt_info_check) {
|
||||
ASSERT_NO_THROW(check_rt_info(function));
|
||||
}
|
||||
|
||||
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
||||
auto acc_comparator = FunctionsComparator::no_default();
|
||||
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
auto res = acc_comparator.compare(function, cloned_function);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
void TransformationTestsF::disable_rt_info_check() {
|
||||
m_disable_rt_info_check = true;
|
||||
}
|
||||
|
||||
void TransformationTestsF::enable_soft_names_comparison() {
|
||||
m_soft_names_comparison = true;
|
||||
}
|
||||
|
||||
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitUniqueNames>(unh);
|
||||
|
@ -30,54 +30,17 @@ using TransformationTests = CommonTestUtils::TestsCommon;
|
||||
|
||||
class TransformationTestsF : public CommonTestUtils::TestsCommon {
|
||||
public:
|
||||
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
|
||||
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
|
||||
comparator.enable(FunctionsComparator::CmpValues::NODES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
|
||||
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
// TODO: enable attributes and constant values comparison by default XXX-68694
|
||||
// comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
// comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
}
|
||||
TransformationTestsF();
|
||||
|
||||
void SetUp() override {
|
||||
manager.register_pass<ngraph::pass::InitUniqueNames>(m_unh);
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
}
|
||||
void SetUp() override;
|
||||
|
||||
void TearDown() override {
|
||||
auto cloned_function = ngraph::clone_function(*function);
|
||||
if (!function_ref) {
|
||||
function_ref = cloned_function;
|
||||
}
|
||||
|
||||
manager.register_pass<ngraph::pass::CheckUniqueNames>(m_unh, m_soft_names_comparison);
|
||||
manager.run_passes(function);
|
||||
if (!m_disable_rt_info_check) {
|
||||
ASSERT_NO_THROW(check_rt_info(function));
|
||||
}
|
||||
|
||||
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
|
||||
auto acc_comparator = FunctionsComparator::no_default();
|
||||
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
auto res = acc_comparator.compare(function, cloned_function);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
void TearDown() override;
|
||||
|
||||
// TODO: this is temporary solution to disable rt info checks that must be applied by default
|
||||
// first tests must be fixed then this method must be removed XXX-68696
|
||||
void disable_rt_info_check() {
|
||||
m_disable_rt_info_check = true;
|
||||
}
|
||||
void disable_rt_info_check();
|
||||
|
||||
void enable_soft_names_comparison() {
|
||||
m_soft_names_comparison = true;
|
||||
}
|
||||
void enable_soft_names_comparison();
|
||||
|
||||
std::shared_ptr<ov::Model> function, function_ref;
|
||||
ngraph::pass::Manager manager;
|
||||
@ -103,4 +66,4 @@ size_t count_ops_of_type(const std::shared_ptr<ngraph::Function>& f) {
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user