Update Accuracy Check inside Func tests (#11314)

This commit is contained in:
Gleb Kazantaev 2022-03-30 17:44:03 +03:00 committed by GitHub
parent ee31b648d1
commit 8317493e65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 422 additions and 577 deletions

View File

@ -92,7 +92,9 @@ TEST_P(FQMulFusion, ExpectFusion) {
manager.run_passes(m_function);
ASSERT_NO_THROW(check_rt_info(m_function));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::NODES);
auto res = fc.compare(m_function, m_expected_function);
ASSERT_TRUE(res.valid) << res.message;
};

View File

@ -113,7 +113,9 @@ TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) {
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::NODES);
auto res = fc.compare(f, ref_f);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -30,7 +30,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
@ -57,7 +57,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {

View File

@ -998,7 +998,7 @@ TEST_P(EliminateEltwiseTests, eliminate_eltwise) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
if (type == element::f32) {
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
}
@ -1088,5 +1088,5 @@ TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) {
manager.register_pass<pass::NopElimination>();
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

View File

@ -120,11 +120,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
set_tensor_name(conv, "conv");
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
manager.register_pass<pass::PadFusion>();
@ -134,7 +134,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) {
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}
@ -244,11 +244,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
std::shared_ptr<Node> pad_value = opset5::Constant::create(element::f16, Shape{}, {0});
pad_value = std::make_shared<opset5::Convert>(pad_value, element::f32);
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, pad_value, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1});
set_tensor_name(conv, "conv");
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
manager.register_pass<pass::PadFusion>();
@ -258,7 +258,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) {
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(data, filters, Strides{1, 1},
CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}
@ -422,30 +422,29 @@ TEST_F(TransformationTestsF, NegativePadFusionConvolutionBackpropDataTooSmallPad
Shape data_shape{1, 3, 14, 14};
{
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
set_tensor_name(data, "data");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
function = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
manager.register_pass<pass::PadFusion>();
}
{
auto data = std::make_shared<opset5::Parameter>(element::f32, data_shape);
set_tensor_name(data, "data");
auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2});
auto pad = std::make_shared<opset5::Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
set_tensor_name(pad, "pad");
auto filters = std::make_shared<opset5::Parameter>(element::f32, Shape{3, 2, 5, 5});
auto conv = std::make_shared<opset5::ConvolutionBackpropData>(pad, filters, Strides{1, 1},
CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1});
set_tensor_name(conv, "conv");
function_ref = std::make_shared<Function>(NodeVector{conv}, ParameterVector{data, filters});
}

View File

@ -141,7 +141,7 @@ TEST_F(TransformationTestsF, RICFusionSimple) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionHard) {
@ -198,7 +198,7 @@ TEST_F(TransformationTestsF, RICFusionHard) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionDynamic) {
@ -248,7 +248,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwise2) {
@ -273,7 +273,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise2) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwise3) {
@ -298,7 +298,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise3) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwise4) {
@ -324,7 +324,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwise5) {
@ -350,7 +350,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwiseNegative) {
@ -390,7 +390,7 @@ TEST_F(TransformationTestsF, RICFusionEltwiseTwoRIC) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionEltwiseNegative3) {
@ -431,7 +431,7 @@ TEST_F(TransformationTestsF, RICFusionGroupConv) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionGroupConvNegative) {
@ -473,7 +473,7 @@ TEST_F(TransformationTestsF, RICFusionTranspose) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
@ -499,7 +499,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
@ -538,7 +538,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
@ -578,7 +578,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, RICFusionShapeOf) {
@ -749,7 +749,7 @@ TEST_F(TransformationTestsF, FuseScaleValue) {
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, FuseScaleValues) {
@ -779,5 +779,5 @@ TEST_F(TransformationTestsF, FuseScaleValues) {
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

View File

@ -90,7 +90,9 @@ protected:
};
TEST_P(ShuffleChannelsFusion, CompareFunctions) {
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::NODES);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -55,7 +55,9 @@ TEST_P(SoftmaxFusionFixture, SoftmaxFusion) {
f_ref = std::make_shared<Function>(NodeVector{softmax}, ParameterVector{data});
}
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::NODES);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -113,7 +115,9 @@ TEST_P(NegativeSoftmaxFusionFixture, NegativeSoftmaxFusion) {
f_ref = std::make_shared<Function>(NodeVector{div}, ParameterVector{data});
}
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::NODES);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -94,6 +94,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) {
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
auto res = fc.compare(f, f_ref);
@ -180,6 +181,7 @@ TEST_P(TransposeSinking, TransposeReduction) {
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);

View File

@ -109,7 +109,9 @@ TEST_P(TransposeToReshapeTests, CompareFunctions) {
f->validate_nodes_and_infer_types();
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -84,6 +84,7 @@ TEST_P(TranslateNewWeightFormatToOldOne, ReshapeMatMul) {
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
auto res = fc.compare(f, f_ref);

View File

@ -127,7 +127,9 @@ TEST_P(ConvFusionTests, CompareFunctions) {
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -157,7 +157,9 @@ TEST_P(MulAddConversionTests, CompareFunctions) {
ASSERT_NO_THROW(manager.run_passes(f));
f->validate_nodes_and_infer_types();
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -178,7 +180,9 @@ TEST_P(MulOrAddConversionTests, CompareFunctions) {
f->validate_nodes_and_infer_types();
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -246,7 +246,7 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -311,7 +311,7 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -491,7 +491,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -643,7 +643,7 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -761,7 +761,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -899,7 +899,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf)
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1017,7 +1017,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1202,7 +1202,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1299,7 +1299,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1406,7 +1406,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1554,7 +1554,7 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1814,7 +1814,7 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -1975,7 +1975,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2050,7 +2050,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2149,7 +2149,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDown) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2403,7 +2403,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2495,7 +2495,7 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2679,7 +2679,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2760,7 +2760,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
@ -2855,5 +2855,5 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) {
m.run_passes(function);
}
disable_rt_info_check();
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

View File

@ -126,7 +126,9 @@ TEST_P(ConvertReduceToPoolingTests, CompareFunctions) {
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -90,7 +90,9 @@ void test(std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function>
manager.register_pass<ngraph::pass::ConstantFolding>();
ASSERT_NO_THROW(manager.run_passes(f));
auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
auto res = fc.compare(f, f_ref);
ASSERT_TRUE(res.valid) << res.message;
}

View File

@ -650,7 +650,9 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
///
{ // check precision only
const auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
const auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS);
const auto res = fc.compare(f1, f2);
EXPECT_FALSE(res.valid);
EXPECT_THAT(res.message, HasSubstr("Different element type detected"));
@ -660,6 +662,7 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
{ // check precision and attributes
const auto fc = FunctionsComparator::no_default()
.enable(FunctionsComparator::NODES)
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::ATTRIBUTES);
const auto res = fc.compare(f1, f2);

View File

@ -62,7 +62,7 @@ class CompressQuantizeWeightsTests
function_ref = std::make_shared<Function>(mul, ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
};
@ -108,7 +108,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph)
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
@ -133,7 +133,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) {
@ -159,7 +159,7 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimiz
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
@ -175,5 +175,5 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) {
manager.register_pass<pass::ZeroPointOptimizer>();
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
enable_accuracy_check();
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}

View File

@ -2,7 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "graph_comparator.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include <gtest/gtest.h>
#include "ie_common.h"
#include <algorithm>
#include <cassert>
@ -20,6 +23,10 @@
#include <typeinfo>
#include <vector>
#include "common_test_utils/ov_tensor_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace {
inline namespace tools {
bool is_type_relaxed(const std::string& type) {
@ -31,12 +38,12 @@ bool compare_type_info(const ngraph::DiscreteTypeInfo& info1, const ngraph::Disc
if (!is_type_relaxed(info1.name) && !is_type_relaxed(info2.name) && (info1.version != info2.version)) {
return false;
}
OPENVINO_SUPPRESS_DEPRECATED_END
OPENVINO_SUPPRESS_DEPRECATED_END
const std::string info1Name =
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
const std::string info2Name =
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
return info1Name == info2Name;
}
@ -121,9 +128,9 @@ public:
explicit NodeAndInputDescription(const InputNode& input,
const Parameter* parameter,
const InputDescripton* description)
: m_input(input),
m_parameter(not_null(parameter)),
m_description(not_null(description)) {}
: m_input(input),
m_parameter(not_null(parameter)),
m_description(not_null(description)) {}
static bool equal_descriptions(const InputDescripton* lhs, const InputDescripton* rhs) {
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
@ -172,7 +179,7 @@ public:
}
for (size_t i = 0; i != param_shape.size(); ++i) {
const auto expected_axis_size =
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i];
if (input_shape[i] != expected_axis_size) {
return false;
}
@ -215,9 +222,9 @@ public:
explicit NodeAndOutputDescription(const OutputNode& output,
const Result* result,
const OutputDescription* description)
: m_output(output),
m_result(not_null(result)),
m_description(not_null(description)) {}
: m_output(output),
m_result(not_null(result)),
m_description(not_null(description)) {}
static bool equal_descriptions(const OutputDescription* lhs, const OutputDescription* rhs) {
if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) {
@ -302,8 +309,8 @@ public:
using Id = uint64_t;
explicit BackEdge(const Parameter* parameter, const Result* result)
: m_parameter(not_null(parameter)),
m_result(not_null(result)) {}
: m_parameter(not_null(parameter)),
m_result(not_null(result)) {}
bool result_and_parameter_match() const {
return equal_type_and_partial_shape(m_result->output(0), *m_parameter);
@ -539,103 +546,113 @@ Comparator::Result Comparator::compare(const std::shared_ptr<ngraph::Function>&
* + Check node attributes by Visitor API
*/
auto f_results = f->get_results();
auto f_ref_results = f_ref->get_results();
if (should_compare(CmpValues::NODES)) {
auto f_results = f->get_results();
auto f_ref_results = f_ref->get_results();
auto cmp = less_by_name;
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
// name An in case of multiple names we sort Result operation using their parent node names
if (std::any_of(f_results.begin(),
f_results.end(),
[](const std::shared_ptr<ngraph::Node>& node) {
const auto& t = node->input_value(0).get_tensor_ptr();
return t->get_names().size() > 1;
}) ||
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node>& node) {
const auto& t = node->input_value(0).get_tensor_ptr();
return t->get_names().size() > 1;
})) {
cmp = less_by_parent_name;
}
auto cmp = less_by_name;
// In case if Result source output has more than one name so the Result may have any of this names as a friendly
// name An in case of multiple names we sort Result operation using their parent node names
if (std::any_of(f_results.begin(),
f_results.end(),
[](const std::shared_ptr<ngraph::Node> &node) {
const auto &t = node->input_value(0).get_tensor_ptr();
return t->get_names().size() > 1;
}) ||
std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr<ngraph::Node> &node) {
const auto &t = node->input_value(0).get_tensor_ptr();
return t->get_names().size() > 1;
})) {
cmp = less_by_parent_name;
}
std::sort(f_results.begin(), f_results.end(), cmp);
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
std::sort(f_results.begin(), f_results.end(), cmp);
std::sort(f_ref_results.begin(), f_ref_results.end(), cmp);
if (f_results.size() != f_ref_results.size()) {
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
to_str(f_ref_results.size()));
}
if (f_results.size() != f_ref_results.size()) {
return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " +
to_str(f_ref_results.size()));
}
const auto& f_sinks = f->get_sinks();
const auto& f_ref_sinks = f_ref->get_sinks();
if (f_sinks.size() != f_ref_sinks.size()) {
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
to_str(f_ref_sinks.size()));
}
const auto &f_sinks = f->get_sinks();
const auto &f_ref_sinks = f_ref->get_sinks();
if (f_sinks.size() != f_ref_sinks.size()) {
return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " +
to_str(f_ref_sinks.size()));
}
// Compare sinks
if (f_sinks.size() == 1) {
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
used.insert(f_sinks[0].get());
} else {
// Cast to Assign and find those that have same variable_id suffix
for (const auto& sink1 : f_sinks) {
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
if (!assign1) {
return Result::error("Sink '" + name(sink1) +
"' is not a variable - graph comparison is not supported");
}
auto name1 = assign1->get_variable_id();
std::shared_ptr<ov::op::Sink> found_sink2;
for (const auto& sink2 : f_ref_sinks) {
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
if (!assign2) {
return Result::error("Sink '" + name(sink2) +
// Compare sinks
if (f_sinks.size() == 1) {
q.push({f_sinks[0].get(), f_ref_sinks[0].get()});
used.insert(f_sinks[0].get());
} else {
// Cast to Assign and find those that have same variable_id suffix
for (const auto &sink1 : f_sinks) {
auto assign1 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink1);
if (!assign1) {
return Result::error("Sink '" + name(sink1) +
"' is not a variable - graph comparison is not supported");
}
auto name2 = assign2->get_variable_id();
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
found_sink2 = sink2;
break;
auto name1 = assign1->get_variable_id();
std::shared_ptr<ov::op::Sink> found_sink2;
for (const auto &sink2 : f_ref_sinks) {
auto assign2 = std::dynamic_pointer_cast<ov::op::util::VariableExtension>(sink2);
if (!assign2) {
return Result::error("Sink '" + name(sink2) +
"' is not a variable - graph comparison is not supported");
}
auto name2 = assign2->get_variable_id();
if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) {
found_sink2 = sink2;
break;
}
}
if (!found_sink2) {
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
}
q.push({sink1.get(), found_sink2.get()});
used.insert(sink1.get());
}
}
for (size_t i = 0; i < f_results.size(); ++i) {
if (should_compare(CmpValues::NAMES)) {
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
return Result::error(
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) +
" and " +
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
}
}
if (!found_sink2) {
return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1);
q.push({f_results[i].get(), f_ref_results[i].get()});
used.insert(f_results[i].get());
}
std::stringstream errors;
while (!q.empty()) {
ngraph::Node *const node1 = q.front().first;
ngraph::Node *const node2 = q.front().second;
q.pop();
const auto result = compare(node1, node2, errors);
if (!result.valid) {
return result;
}
q.push({sink1.get(), found_sink2.get()});
used.insert(sink1.get());
add_nodes_inputs_to_queue(node1, node2);
}
const auto msg = errors.str();
return msg.empty() ? Result::ok() : Result::error(msg);
} else if (should_compare(CmpValues::ACCURACY)) {
auto status = accuracy_check(f_ref, f);
return status.status ? Result::ok() : Result::error(status.message);
} else {
return Result::error("CmpValues are invalid");
}
for (size_t i = 0; i < f_results.size(); ++i) {
if (should_compare(CmpValues::NAMES)) {
if (name(f_results[i]->get_input_node_shared_ptr(0)) !=
name(f_ref_results[i]->get_input_node_shared_ptr(0))) {
return Result::error(
"Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) + " and " +
name(f_ref_results[i]->get_input_node_shared_ptr(0)));
}
}
q.push({f_results[i].get(), f_ref_results[i].get()});
used.insert(f_results[i].get());
}
std::stringstream errors;
while (!q.empty()) {
ngraph::Node* const node1 = q.front().first;
ngraph::Node* const node2 = q.front().second;
q.pop();
const auto result = compare(node1, node2, errors);
if (!result.valid) {
return result;
}
add_nodes_inputs_to_queue(node1, node2);
}
const auto msg = errors.str();
return msg.empty() ? Result::ok() : Result::error(msg);
}
Comparator::Result Comparator::compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
@ -759,7 +776,7 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
void Comparator::compare_nodes(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(*node1, *node2, err_log)) {
err_log << "Different runtime info detected\n" + name(node1) + " and " + name(node2) +
" not equal runtime info.\n";
" not equal runtime info.\n";
}
if (should_compare(CmpValues::ATTRIBUTES)) {
@ -805,123 +822,115 @@ void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
}
}
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name) {
output.get_tensor_ptr()->set_names({name});
}
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names) {
output.get_tensor_ptr()->set_names(names);
}
namespace attributes {
namespace detail {
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
insert(name, inputs->get());
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
insert(name, outputs->get());
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
// drop comparison, no more info than port indexes which will be check in
// subgraph::compare_io
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
&adapter)) {
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
const auto end = beg + a->get()->size();
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
} else if (auto framework_node_attr =
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
insert(name, framework_node_attr->get());
} else if (auto variable_ptr =
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
insert(name, variable_ptr->get());
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
insert(name, shape_ptr->get());
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
insert(name, dim_ptr->get());
} else {
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
adapter.get_type_info().name + "']";
void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
insert(name, inputs->get());
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
insert(name, outputs->get());
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
// drop comparison, no more info than port indexes which will be check in
// subgraph::compare_io
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
&adapter)) {
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
const auto end = beg + a->get()->size();
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
} else if (auto framework_node_attr =
ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
insert(name, framework_node_attr->get());
} else if (auto variable_ptr =
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
insert(name, variable_ptr->get());
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
insert(name, shape_ptr->get());
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
insert(name, dim_ptr->get());
} else {
m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
adapter.get_type_info().name + "']";
}
}
}
template <typename AttrValue>
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
if (should_return()) {
return;
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<AttrValue>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
template <typename AttrValue>
void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) {
if (should_return()) {
return;
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<AttrValue>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
}
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
str::Get<AttrValue>::value(attr_value);
}
}
if (!equal::Equal<AttrValue>::equal_value(*ref_value, attr_value)) {
m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get<AttrValue>::value(*ref_value) + " vs " +
str::Get<AttrValue>::value(attr_value);
}
}
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
if (should_return()) {
return;
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
}
void ReadAndCompareAttributes::verify_mem_buf(const std::string& name,
const std::shared_ptr<ngraph::runtime::AlignedBuffer>& buffer) {
if (should_return()) {
return;
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
if (buffer->size() != ref_value->size() ||
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
return;
}
}
if (buffer->size() != ref_value->size() ||
std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) {
m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer";
return;
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
if (should_return()) {
return;
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
}
Comparator c(m_check_flags);
const auto result = c.compare(*ref_value, adapter.get());
if (!result.valid) {
m_cmp_result += result.message;
}
}
}
void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) {
if (should_return()) {
return;
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
verify(name, inputs->get());
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
verify(name, outputs->get());
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
// drop comparison, no more info than port indexes which will be check in
// subgraph::compare_io
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
&adapter)) {
verify_mem_buf(name, a->get());
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
verify(name, attrs->get());
} else if (auto variable_ptr =
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
verify(name, variable_ptr->get());
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
verify(name, shape_ptr->get());
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
verify(name, dim_ptr->get());
} else {
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
adapter.get_type_info().name + "']";
}
}
m_visited_attributes.insert(name);
const auto ref_value = m_attr_ref.get<std::shared_ptr<ngraph::Function>>(name);
if (!ref_value) {
m_cmp_result += "missing attribute name: '" + name + "'";
return;
}
Comparator c(m_check_flags);
const auto result = c.compare(*ref_value, adapter.get());
if (!result.valid) {
m_cmp_result += result.message;
}
}
void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
verify(name, inputs->get());
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
verify(name, outputs->get());
} else if (ngraph::is_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
// drop comparison, no more info than port indexes which will be check in
// subgraph::compare_io
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
&adapter)) {
verify_mem_buf(name, a->get());
} else if (auto attrs = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
verify(name, attrs->get());
} else if (auto variable_ptr =
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(&adapter)) {
verify(name, variable_ptr->get());
} else if (auto shape_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::PartialShape>>(&adapter)) {
verify(name, shape_ptr->get());
} else if (auto dim_ptr = ngraph::as_type<ngraph::AttributeAdapter<ov::Dimension>>(&adapter)) {
verify(name, dim_ptr->get());
} else {
m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" +
adapter.get_type_info().name + "']";
}
}
} // namespace detail
@ -937,3 +946,42 @@ Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator:
}
} // namespace attributes
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
const std::shared_ptr<ov::Model>& cur_function) {
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
return AccuracyCheckResult{true, ""};
}
try {
IE_ASSERT(ref_function->get_parameters().size() == cur_function->get_parameters().size());
std::map<std::shared_ptr<ov::Node>, ov::Tensor> ref_input_data;
std::map<std::shared_ptr<ov::Node>, ov::Tensor> cur_input_data;
for (size_t i = 0; i < ref_function->get_parameters().size(); i++) {
const auto &tensor = ov::test::utils::create_and_fill_tensor(
ref_function->get_parameters()[i]->get_element_type(),
ref_function->get_parameters()[i]->get_shape());
ref_input_data[ref_function->get_parameters()[i]] = tensor;
cur_input_data[cur_function->get_parameters()[i]] = tensor;
}
auto ref_outputs = ngraph::helpers::interpretFunction(ref_function, ref_input_data);
auto outputs = ngraph::helpers::interpretFunction(cur_function, cur_input_data);
IE_ASSERT(ref_outputs.size() == outputs.size());
for (int i = 0; i < ref_outputs.size(); i++) {
ov::test::utils::compare(ref_outputs[i], outputs[i],
std::numeric_limits<double>::max(),
std::numeric_limits<double>::max());
}
}
catch (const std::runtime_error &re) {
return AccuracyCheckResult{false, re.what()};
} catch (const std::exception &ex) {
return AccuracyCheckResult{false, ex.what()};
} catch (...) {
return AccuracyCheckResult{false, ""};
}
return AccuracyCheckResult{true, ""};
}

View File

@ -20,12 +20,14 @@ class FunctionsComparator {
public:
enum CmpValues {
NONE = 0,
CONST_VALUES = 1 << 0,
NAMES = 1 << 1,
RUNTIME_KEYS = 1 << 2,
PRECISIONS = 1 << 3,
ATTRIBUTES = 1 << 4,
TENSOR_NAMES = 1 << 5,
NODES = 1 << 0,
CONST_VALUES = 1 << 1,
NAMES = 1 << 2,
RUNTIME_KEYS = 1 << 3,
PRECISIONS = 1 << 4,
ATTRIBUTES = 1 << 5,
TENSOR_NAMES = 1 << 6,
ACCURACY = 1 << 7
};
struct Result {
@ -45,6 +47,7 @@ public:
}
static FunctionsComparator with_default() noexcept {
auto fc = FunctionsComparator::no_default();
fc.enable(NODES);
fc.enable(PRECISIONS);
fc.enable(TENSOR_NAMES);
return fc;
@ -55,6 +58,11 @@ public:
return *this;
}
FunctionsComparator& disable(CmpValues f) noexcept {
m_comparison_flags = static_cast<CmpValues>(m_comparison_flags & ~f);
return *this;
}
bool should_compare(CmpValues f) const noexcept {
return m_comparison_flags & f;
}
@ -85,6 +93,7 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
auto fc = FunctionsComparator::no_default();
using Cmp = FunctionsComparator::CmpValues;
fc.enable(Cmp::NODES);
if (compareConstValues)
fc.enable(Cmp::CONST_VALUES);
if (compareNames)
@ -104,10 +113,6 @@ inline std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngra
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
void set_tensor_name(ngraph::Output<ngraph::Node> output, const std::string& name);
void set_tensor_names(ngraph::Output<ngraph::Node> output, const std::unordered_set<std::string>& names);
namespace ngraph {
namespace pass {
class InjectionPass : public ov::pass::ModelPass {
@ -188,8 +193,8 @@ public:
for (auto output : node->outputs()) {
const auto& tensor_names = output.get_names();
if (std::any_of(tensor_names.begin(), tensor_names.end(), [&](const std::string& name) {
return unique_tensor_names.count(name);
})) {
return unique_tensor_names.count(name);
})) {
std::stringstream ss;
ss << "Node: " << node->get_type_info() << " with name " << node->get_friendly_name() << " ";
ss << "has non unique tensor name.";
@ -478,9 +483,9 @@ public:
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
#define ON_ADAPTER(TYPE) \
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
insert(name, adapter.get()); \
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
insert(name, adapter.get()); \
}
ON_ADAPTER(bool)
ON_ADAPTER(std::string)
@ -711,37 +716,37 @@ struct Equal<std::shared_ptr<Constant>> {
}
switch (lhs_t) {
case ngraph::element::Type_t::u1: {
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
const auto lhs_bit_size = shape_size(lhs->get_shape());
const auto rhs_bit_size = shape_size(rhs->get_shape());
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
}
case ngraph::element::Type_t::bf16: {
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
break;
}
case ngraph::element::Type_t::f16: {
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
break;
}
case ngraph::element::Type_t::f32: {
const auto& lhs_v = lhs->cast_vector<float>();
const auto& rhs_v = rhs->cast_vector<float>();
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
break;
}
default: {
const auto& lhs_v = lhs->cast_vector<double>();
const auto& rhs_v = rhs->cast_vector<double>();
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
break;
}
case ngraph::element::Type_t::u1: {
const auto lhs_v = static_cast<const uint8_t*>(lhs->get_data_ptr());
const auto rhs_v = static_cast<const uint8_t*>(rhs->get_data_ptr());
const auto lhs_bit_size = shape_size(lhs->get_shape());
const auto rhs_bit_size = shape_size(rhs->get_shape());
return Equal<uint8_t*>::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size);
}
case ngraph::element::Type_t::bf16: {
auto lhs_v = lhs->cast_vector<ngraph::bfloat16>();
auto rhs_v = rhs->cast_vector<ngraph::bfloat16>();
return Equal<std::vector<ngraph::bfloat16>>::equal_value(lhs_v, rhs_v);
break;
}
case ngraph::element::Type_t::f16: {
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
break;
}
case ngraph::element::Type_t::f32: {
const auto& lhs_v = lhs->cast_vector<float>();
const auto& rhs_v = rhs->cast_vector<float>();
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
break;
}
default: {
const auto& lhs_v = lhs->cast_vector<double>();
const auto& rhs_v = rhs->cast_vector<double>();
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
break;
}
}
return false;
}
@ -862,18 +867,18 @@ struct Get<std::shared_ptr<ngraph::Variable>, void> {
class ReadAndCompareAttributes : public ngraph::AttributeVisitor {
public:
ReadAndCompareAttributes(const ReadAndStoreAttributes& ref, Comparator::CmpValues check_flags)
: m_attr_ref(ref),
m_cmp_result{ref.read_result()},
m_check_flags(check_flags) {}
: m_attr_ref(ref),
m_cmp_result{ref.read_result()},
m_check_flags(check_flags) {}
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
verify_others(name, adapter);
}
#define ON_ADAPTER(TYPE) \
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
verify(name, adapter.get()); \
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<TYPE>& adapter) override { \
verify(name, adapter.get()); \
}
ON_ADAPTER(bool)
ON_ADAPTER(std::string)
@ -981,3 +986,11 @@ private:
Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags);
} // namespace attributes
struct AccuracyCheckResult {
bool status;
std::string message;
};
AccuracyCheckResult accuracy_check(const std::shared_ptr<ov::Model>& ref_function,
const std::shared_ptr<ov::Model>& cur_function);

View File

@ -2,248 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <functional_test_utils/include/functional_test_utils/blob_utils.hpp>
#include "ngraph_test_utils.hpp"
#include <ngraph_functions/utils/ngraph_helpers.hpp>
namespace {
template<class T_IE, class T_NGRAPH>
static void Compare(const T_NGRAPH *expected, const T_IE *actual, std::size_t size, float threshold, float abs_threshold = -1.f) {
for (std::size_t i = 0; i < size; ++i) {
const T_NGRAPH &ref = expected[i];
const auto &res = actual[i];
const auto absoluteDifference = CommonTestUtils::ie_abs(res - ref);
if (abs_threshold > 0.f && absoluteDifference > abs_threshold) {
IE_THROW() << "Absolute comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
<< " at index " << i << " with absolute threshold " << abs_threshold
<< " failed";
}
if (absoluteDifference <= threshold) {
continue;
}
double max;
if (sizeof(T_IE) < sizeof(T_NGRAPH)) {
max = std::max(CommonTestUtils::ie_abs(T_NGRAPH(res)), CommonTestUtils::ie_abs(ref));
} else {
max = std::max(CommonTestUtils::ie_abs(res), CommonTestUtils::ie_abs(T_IE(ref)));
}
double diff = static_cast<float>(absoluteDifference) / max;
if (max == 0 || (diff > static_cast<float>(threshold)) ||
(std::isnan(static_cast<float>(res)) ^ std::isnan(static_cast<float>(ref)))) {
IE_THROW() << "Relative comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res)
<< " at index " << i << " with threshold " << threshold
<< " failed";
}
}
}
template <typename T_IE>
void callCompare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
const T_IE* actualBuffer, size_t size, float threshold, float abs_threshold) {
auto expectedBuffer = expected.second.data();
switch (expected.first) {
case ngraph::element::Type_t::i64:
Compare<T_IE, int64_t>(reinterpret_cast<const int64_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::i32:
Compare<T_IE, int32_t>(reinterpret_cast<const int32_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::i16:
Compare<T_IE, int16_t>(reinterpret_cast<const int16_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::i8:
Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::u64:
Compare<T_IE, uint64_t>(reinterpret_cast<const uint64_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::u32:
Compare<T_IE, uint32_t>(reinterpret_cast<const uint32_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::u16:
Compare<T_IE, uint16_t>(reinterpret_cast<const uint16_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::boolean:
case ngraph::element::Type_t::u8:
Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::f64:
Compare<T_IE, double>(reinterpret_cast<const double *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::f32:
Compare<T_IE, float>(reinterpret_cast<const float *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::f16:
Compare<T_IE, ngraph::float16>(reinterpret_cast<const ngraph::float16 *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
case ngraph::element::Type_t::bf16:
Compare<T_IE, ngraph::bfloat16>(reinterpret_cast<const ngraph::bfloat16 *>(expectedBuffer),
actualBuffer, size, threshold, abs_threshold);
break;
// case ngraph::element::Type_t::i4: {
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
// expected.second,
// expected.first,
// ngraph::element::Type_t::i8,
// size);
// Compare<T_IE, int8_t>(reinterpret_cast<const int8_t *>(expectedOut.data()),
// actualBuffer, size, threshold, abs_threshold);
// break;
// }
// case ngraph::element::Type_t::u4: {
// auto expectedOut = ngraph::helpers::convertOutputPrecision(
// expected.second,
// expected.first,
// ngraph::element::Type_t::u8,
// size);
// Compare<T_IE, uint8_t>(reinterpret_cast<const uint8_t *>(expectedOut.data()),
// actualBuffer, size, threshold, abs_threshold);
// break;
// }
case ngraph::element::Type_t::dynamic:
case ngraph::element::Type_t::undefined:
Compare<T_IE, T_IE>(reinterpret_cast<const T_IE *>(expectedBuffer), actualBuffer, size, threshold, abs_threshold);
break;
default: FAIL() << "Comparator for " << expected.first << " precision isn't supported";
}
return;
}
void Compare(const std::pair<ngraph::element::Type, std::vector<std::uint8_t>> &expected,
const InferenceEngine::Blob::Ptr &actual,
float threshold,
float abs_threshold) {
const auto &precision = actual->getTensorDesc().getPrecision();
auto k = static_cast<float>(expected.first.size()) / precision.size();
// W/A for int4, uint4
if (expected.first == ngraph::element::Type_t::u4 || expected.first == ngraph::element::Type_t::i4) {
k /= 2;
} else if (expected.first == ngraph::element::Type_t::undefined || expected.first == ngraph::element::Type_t::dynamic) {
k = 1;
}
ASSERT_EQ(expected.second.size(), actual->byteSize() * k);
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(actual);
IE_ASSERT(memory);
const auto lockedMemory = memory->wmap();
const auto actualBuffer = lockedMemory.as<const std::uint8_t *>();
const auto &size = actual->size();
switch (precision) {
case InferenceEngine::Precision::FP32:
callCompare<float>(expected, reinterpret_cast<const float *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::I32:
callCompare<int32_t>(expected, reinterpret_cast<const int32_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::I64:
callCompare<int64_t>(expected, reinterpret_cast<const int64_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::I8:
callCompare<int8_t>(expected, reinterpret_cast<const int8_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::U16:
callCompare<uint16_t>(expected, reinterpret_cast<const uint16_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::I16:
callCompare<int16_t>(expected, reinterpret_cast<const int16_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::BOOL:
case InferenceEngine::Precision::U8:
callCompare<uint8_t>(expected, reinterpret_cast<const uint8_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::U64:
callCompare<uint64_t>(expected, reinterpret_cast<const uint64_t *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::BF16:
callCompare<ngraph::bfloat16>(expected, reinterpret_cast<const ngraph::bfloat16 *>(actualBuffer), size, threshold, abs_threshold);
break;
case InferenceEngine::Precision::FP16:
callCompare<ngraph::float16>(expected, reinterpret_cast<const ngraph::float16 *>(actualBuffer), size, threshold, abs_threshold);
break;
default:
FAIL() << "Comparator for " << precision << " precision isn't supported";
}
}
void Compare(const std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> &expectedOutputs,
const std::vector<InferenceEngine::Blob::Ptr> &actualOutputs,
float threshold, float abs_threshold) {
for (std::size_t outputIndex = 0; outputIndex < expectedOutputs.size(); ++outputIndex) {
const auto &expected = expectedOutputs[outputIndex];
const auto &actual = actualOutputs[outputIndex];
Compare(expected, actual, threshold, abs_threshold);
}
}
} // namespace
void TransformationTestsF::accuracy_check(std::shared_ptr<ov::Model> ref_function,
std::shared_ptr<ov::Model> cur_function) {
try {
if (ref_function->is_dynamic() || cur_function->is_dynamic()) {
return;
}
std::vector<std::vector<uint8_t>> input_data;
ngraph::element::TypeVector types;
for (auto param : ref_function->get_parameters()) {
types.push_back(param->get_element_type());
auto layout = InferenceEngine::Layout::ANY;
if (ov::is_scalar(param->get_shape())) {
layout = InferenceEngine::Layout::SCALAR;
}
InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout);
const auto &input = FuncTestUtils::createAndFillBlob(td);
const auto &input_size = input->byteSize();
std::vector<uint8_t> data;
data.resize(input_size);
auto memory = InferenceEngine::as<InferenceEngine::MemoryBlob>(input);
IE_ASSERT(memory);
const auto lockedMemory = memory->wmap();
const auto buffer = lockedMemory.as<const std::uint8_t *>();
std::copy(buffer, buffer + input_size, data.data());
input_data.push_back(std::move(data));
}
auto ref_outputs = ngraph::helpers::interpreterFunction(ref_function, input_data, types);
auto outputs = ngraph::helpers::interpreterFunction(cur_function, input_data, types);
IE_ASSERT(ref_outputs.size() == outputs.size());
for (size_t i = 0; i < ref_outputs.size(); ++i) {
IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size());
auto * ref = reinterpret_cast<float *>(ref_outputs[i].second.data());
auto * out = reinterpret_cast<float *>(outputs[i].second.data());
size_t size = ref_outputs[i].second.size() / sizeof(float);
IE_ASSERT(size > 0);
Compare<float, float>(ref, out, size, 1e-5);
}
}
catch (const std::runtime_error &re) {
GTEST_FATAL_FAILURE_(re.what());
} catch (const std::exception &ex) {
GTEST_FATAL_FAILURE_(ex.what());
} catch (...) {
GTEST_FATAL_FAILURE_("Unknown failure occurred.");
}
}
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh) {
ngraph::pass::Manager manager;

View File

@ -32,6 +32,7 @@ class TransformationTestsF : public CommonTestUtils::TestsCommon {
public:
TransformationTestsF() : comparator(FunctionsComparator::no_default()) {
m_unh = std::make_shared<ngraph::pass::UniqueNamesHolder>();
comparator.enable(FunctionsComparator::CmpValues::NODES);
comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
// TODO: enable attributes and constant values comparison by default XXX-68694
@ -56,12 +57,16 @@ public:
ASSERT_NO_THROW(check_rt_info(function));
}
if (comparator.should_compare(FunctionsComparator::ACCURACY)) {
auto acc_comparator = FunctionsComparator::no_default();
acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
auto res = acc_comparator.compare(function, cloned_function);
ASSERT_TRUE(res.valid) << res.message;
comparator.disable(FunctionsComparator::CmpValues::ACCURACY);
}
auto res = comparator.compare(function, function_ref);
ASSERT_TRUE(res.valid) << res.message;
if (m_enable_accuracy_check) {
accuracy_check(cloned_function, function);
}
}
// TODO: this is temporary solution to disable rt info checks that must be applied by default
@ -74,12 +79,6 @@ public:
m_soft_names_comparison = true;
}
void enable_accuracy_check() {
m_enable_accuracy_check = true;
}
void accuracy_check(std::shared_ptr<ov::Model> ref_function, std::shared_ptr<ov::Model> cur_function);
std::shared_ptr<ov::Model> function, function_ref;
ngraph::pass::Manager manager;
FunctionsComparator comparator;
@ -88,7 +87,6 @@ private:
std::shared_ptr<ngraph::pass::UniqueNamesHolder> m_unh;
bool m_disable_rt_info_check{false};
bool m_soft_names_comparison{true};
bool m_enable_accuracy_check{false};
};
void init_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_ptr<ngraph::pass::UniqueNamesHolder>& unh);