diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/fq_mul_fusion_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/fq_mul_fusion_test.cpp index 5b5287c789d..02bb9718e83 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/fq_mul_fusion_test.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/fq_mul_fusion_test.cpp @@ -92,7 +92,9 @@ TEST_P(FQMulFusion, ExpectFusion) { manager.run_passes(m_function); ASSERT_NO_THROW(check_rt_info(m_function)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::NODES); auto res = fc.compare(m_function, m_expected_function); ASSERT_TRUE(res.valid) << res.message; }; diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/fq_reshape_fusion.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/fq_reshape_fusion.cpp index 3768f9cffda..a383e8913ff 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/fq_reshape_fusion.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/fq_reshape_fusion.cpp @@ -113,7 +113,9 @@ TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) { manager.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::NODES); auto res = fc.compare(f, ref_f); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/matmul_const_transposes_extraction.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/matmul_const_transposes_extraction.cpp index 76c2ce90aa1..39ea894dac6 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/matmul_const_transposes_extraction.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/matmul_const_transposes_extraction.cpp @@ -30,7 +30,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) { function_ref = std::make_shared(matmul, ParameterVector{data}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) { @@ -57,7 +57,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) { function_ref = std::make_shared(matmul, ParameterVector{data}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) { diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp index b38ab845172..c5da73d6edd 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/nop_elimination.cpp @@ -998,7 +998,7 @@ TEST_P(EliminateEltwiseTests, eliminate_eltwise) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); if (type == element::f32) { - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } } @@ -1088,5 +1088,5 @@ TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) { manager.register_pass(); comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/pad_fusion.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/pad_fusion.cpp index 14efa6833a9..c9ab6da6805 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/pad_fusion.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/pad_fusion.cpp @@ -120,11 +120,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) { auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 1, 1}); auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); auto pad = std::make_shared(data, pads_begin, pads_end, op::PadMode::CONSTANT); - set_tensor_name(pad, "pad"); + auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(pad, filters, Strides{1, 1}, CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1}); - set_tensor_name(conv, "conv"); + function = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); manager.register_pass(); @@ -134,7 +134,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropData) { auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(data, filters, Strides{1, 1}, CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1}); - set_tensor_name(conv, "conv"); + function_ref = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); } @@ -244,11 +244,11 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) { std::shared_ptr pad_value = opset5::Constant::create(element::f16, Shape{}, {0}); pad_value = std::make_shared(pad_value, element::f32); auto pad = std::make_shared(data, pads_begin, pads_end, pad_value, op::PadMode::CONSTANT); - set_tensor_name(pad, "pad"); + auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(pad, filters, Strides{1, 1}, CoordinateDiff{4, 4}, CoordinateDiff{3, 3}, Shape{1, 1}); - set_tensor_name(conv, "conv"); + function = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); manager.register_pass(); @@ -258,7 +258,7 @@ TEST_F(TransformationTestsF, PadFusionConvolutionBackpropDataNonConstPadValue) { auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(data, filters, Strides{1, 1}, CoordinateDiff{3, 3}, CoordinateDiff{1, 1}, Shape{1, 1}); - set_tensor_name(conv, "conv"); + function_ref = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); } @@ -422,30 +422,29 @@ TEST_F(TransformationTestsF, NegativePadFusionConvolutionBackpropDataTooSmallPad Shape data_shape{1, 3, 14, 14}; { auto data = std::make_shared(element::f32, data_shape); - set_tensor_name(data, "data"); + auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); auto pad = std::make_shared(data, pads_begin, pads_end, op::PadMode::CONSTANT); - set_tensor_name(pad, "pad"); + auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(pad, filters, Strides{1, 1}, CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1}); - set_tensor_name(conv, "conv"); + function = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); manager.register_pass(); } { auto data = std::make_shared(element::f32, data_shape); - set_tensor_name(data, "data"); + auto pads_begin = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); auto pads_end = opset5::Constant::create(element::i32, Shape{4}, {0, 0, 2, 2}); auto pad = std::make_shared(data, pads_begin, pads_end, op::PadMode::CONSTANT); - set_tensor_name(pad, "pad"); + auto filters = std::make_shared(element::f32, Shape{3, 2, 5, 5}); auto conv = std::make_shared(pad, filters, Strides{1, 1}, CoordinateDiff{1, 1}, CoordinateDiff{1, 1}, Shape{1, 1}); - set_tensor_name(conv, "conv"); function_ref = std::make_shared(NodeVector{conv}, ParameterVector{data, filters}); } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp index 700af83e1cf..4215ca2885a 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/preprocessing_fusion_tests.cpp @@ -141,7 +141,7 @@ TEST_F(TransformationTestsF, RICFusionSimple) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionHard) { @@ -198,7 +198,7 @@ TEST_F(TransformationTestsF, RICFusionHard) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionDynamic) { @@ -248,7 +248,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise1) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwise2) { @@ -273,7 +273,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise2) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwise3) { @@ -298,7 +298,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise3) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwise4) { @@ -324,7 +324,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise4) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwise5) { @@ -350,7 +350,7 @@ TEST_F(TransformationTestsF, RICFusionEltwise5) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwiseNegative) { @@ -390,7 +390,7 @@ TEST_F(TransformationTestsF, RICFusionEltwiseTwoRIC) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionEltwiseNegative3) { @@ -431,7 +431,7 @@ TEST_F(TransformationTestsF, RICFusionGroupConv) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionGroupConvNegative) { @@ -473,7 +473,7 @@ TEST_F(TransformationTestsF, RICFusionTranspose) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionFQOnTheWay) { @@ -499,7 +499,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) { @@ -538,7 +538,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay2) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) { @@ -578,7 +578,7 @@ TEST_F(TransformationTestsF, RICFusionFQOnTheWay3) { comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, RICFusionShapeOf) { @@ -749,7 +749,7 @@ TEST_F(TransformationTestsF, FuseScaleValue) { } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, FuseScaleValues) { @@ -779,5 +779,5 @@ TEST_F(TransformationTestsF, FuseScaleValues) { } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/shuffle_channels_fusion_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/shuffle_channels_fusion_test.cpp index 54201fd589c..4ea58216808 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/shuffle_channels_fusion_test.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/shuffle_channels_fusion_test.cpp @@ -90,7 +90,9 @@ protected: }; TEST_P(ShuffleChannelsFusion, CompareFunctions) { - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::NODES); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/softmax_fusion.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/softmax_fusion.cpp index 9ce13d0cae2..5ab4ccdb778 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/softmax_fusion.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/softmax_fusion.cpp @@ -55,7 +55,9 @@ TEST_P(SoftmaxFusionFixture, SoftmaxFusion) { f_ref = std::make_shared(NodeVector{softmax}, ParameterVector{data}); } - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::NODES); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } @@ -113,7 +115,9 @@ TEST_P(NegativeSoftmaxFusionFixture, NegativeSoftmaxFusion) { f_ref = std::make_shared(NodeVector{div}, ParameterVector{data}); } - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::PRECISIONS) + .enable(FunctionsComparator::NODES); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_test.cpp index 948a53db3d6..515e4ec59a6 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_test.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_sinking_test.cpp @@ -94,6 +94,7 @@ TEST_P(TransposeSinkingFQ, TransposeFQReduce) { ASSERT_NO_THROW(check_rt_info(f)); auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) .enable(FunctionsComparator::PRECISIONS) .enable(FunctionsComparator::CONST_VALUES); auto res = fc.compare(f, f_ref); @@ -180,6 +181,7 @@ TEST_P(TransposeSinking, TransposeReduction) { ASSERT_NO_THROW(check_rt_info(f)); auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) .enable(FunctionsComparator::PRECISIONS) .enable(FunctionsComparator::CONST_VALUES); diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_to_reshape_test.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_to_reshape_test.cpp index 647ea874894..292ba8e88c3 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_to_reshape_test.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/transpose_to_reshape_test.cpp @@ -109,7 +109,9 @@ TEST_P(TransposeToReshapeTests, CompareFunctions) { f->validate_nodes_and_infer_types(); ASSERT_NO_THROW(check_rt_info(f)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/common_optimizations/weights_dequantize_to_fake_quantize.cpp b/src/tests/functional/inference_engine/transformations/common_optimizations/weights_dequantize_to_fake_quantize.cpp index 68f3883befc..a71c25dd933 100644 --- a/src/tests/functional/inference_engine/transformations/common_optimizations/weights_dequantize_to_fake_quantize.cpp +++ b/src/tests/functional/inference_engine/transformations/common_optimizations/weights_dequantize_to_fake_quantize.cpp @@ -84,6 +84,7 @@ TEST_P(TranslateNewWeightFormatToOldOne, ReshapeMatMul) { ASSERT_NO_THROW(check_rt_info(f)); auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) .enable(FunctionsComparator::PRECISIONS) .enable(FunctionsComparator::CONST_VALUES); auto res = fc.compare(f, f_ref); diff --git a/src/tests/functional/inference_engine/transformations/legacy/conv_fusion_test.cpp b/src/tests/functional/inference_engine/transformations/legacy/conv_fusion_test.cpp index 80f3488619e..a7858073da2 100644 --- a/src/tests/functional/inference_engine/transformations/legacy/conv_fusion_test.cpp +++ b/src/tests/functional/inference_engine/transformations/legacy/conv_fusion_test.cpp @@ -127,7 +127,9 @@ TEST_P(ConvFusionTests, CompareFunctions) { manager.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/legacy/mul_add_conversion_test.cpp b/src/tests/functional/inference_engine/transformations/legacy/mul_add_conversion_test.cpp index 1cd881caab3..1a511c7b4de 100644 --- a/src/tests/functional/inference_engine/transformations/legacy/mul_add_conversion_test.cpp +++ b/src/tests/functional/inference_engine/transformations/legacy/mul_add_conversion_test.cpp @@ -157,7 +157,9 @@ TEST_P(MulAddConversionTests, CompareFunctions) { ASSERT_NO_THROW(manager.run_passes(f)); f->validate_nodes_and_infer_types(); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } @@ -178,7 +180,9 @@ TEST_P(MulOrAddConversionTests, CompareFunctions) { f->validate_nodes_and_infer_types(); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/offline_transformations/pruning_test.cpp b/src/tests/functional/inference_engine/transformations/offline_transformations/pruning_test.cpp index f1ebb57e03f..e120936f679 100644 --- a/src/tests/functional/inference_engine/transformations/offline_transformations/pruning_test.cpp +++ b/src/tests/functional/inference_engine/transformations/offline_transformations/pruning_test.cpp @@ -246,7 +246,7 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -311,7 +311,7 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -491,7 +491,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -643,7 +643,7 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -761,7 +761,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -899,7 +899,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf) m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1017,7 +1017,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1202,7 +1202,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1299,7 +1299,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1406,7 +1406,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1554,7 +1554,7 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1814,7 +1814,7 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -1975,7 +1975,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2050,7 +2050,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpWithShapeOf) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2149,7 +2149,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDown) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2403,7 +2403,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2495,7 +2495,7 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2679,7 +2679,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2760,7 +2760,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } @@ -2855,5 +2855,5 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) { m.run_passes(function); } disable_rt_info_check(); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } diff --git a/src/tests/functional/inference_engine/transformations/op_conversions/convert_reduce_to_pooling_test.cpp b/src/tests/functional/inference_engine/transformations/op_conversions/convert_reduce_to_pooling_test.cpp index d27e7d4f74f..16a35fdfa81 100644 --- a/src/tests/functional/inference_engine/transformations/op_conversions/convert_reduce_to_pooling_test.cpp +++ b/src/tests/functional/inference_engine/transformations/op_conversions/convert_reduce_to_pooling_test.cpp @@ -126,7 +126,9 @@ TEST_P(ConvertReduceToPoolingTests, CompareFunctions) { m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/op_conversions/convert_scatter_elements_to_scatter_test.cpp b/src/tests/functional/inference_engine/transformations/op_conversions/convert_scatter_elements_to_scatter_test.cpp index eb684415335..c626ff597b0 100644 --- a/src/tests/functional/inference_engine/transformations/op_conversions/convert_scatter_elements_to_scatter_test.cpp +++ b/src/tests/functional/inference_engine/transformations/op_conversions/convert_scatter_elements_to_scatter_test.cpp @@ -90,7 +90,9 @@ void test(std::shared_ptr f, std::shared_ptr manager.register_pass(); ASSERT_NO_THROW(manager.run_passes(f)); - auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); auto res = fc.compare(f, f_ref); ASSERT_TRUE(res.valid) << res.message; } diff --git a/src/tests/functional/inference_engine/transformations/utils/compare_functions_test.cpp b/src/tests/functional/inference_engine/transformations/utils/compare_functions_test.cpp index e380f5045ac..8555ab1658b 100644 --- a/src/tests/functional/inference_engine/transformations/utils/compare_functions_test.cpp +++ b/src/tests/functional/inference_engine/transformations/utils/compare_functions_test.cpp @@ -650,7 +650,9 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) { /// { // check precision only - const auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS); + const auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) + .enable(FunctionsComparator::PRECISIONS); const auto res = fc.compare(f1, f2); EXPECT_FALSE(res.valid); EXPECT_THAT(res.message, HasSubstr("Different element type detected")); @@ -660,6 +662,7 @@ TEST(TransformationTests, DifferentPrecisionVersusAttributes) { { // check precision and attributes const auto fc = FunctionsComparator::no_default() + .enable(FunctionsComparator::NODES) .enable(FunctionsComparator::PRECISIONS) .enable(FunctionsComparator::ATTRIBUTES); const auto res = fc.compare(f1, f2); diff --git a/src/tests/functional/inference_engine/transformations/utils/compress_quantize_weights.cpp b/src/tests/functional/inference_engine/transformations/utils/compress_quantize_weights.cpp index 3c5310a666f..4b869521c53 100644 --- a/src/tests/functional/inference_engine/transformations/utils/compress_quantize_weights.cpp +++ b/src/tests/functional/inference_engine/transformations/utils/compress_quantize_weights.cpp @@ -62,7 +62,7 @@ class CompressQuantizeWeightsTests function_ref = std::make_shared(mul, ParameterVector{}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } }; @@ -108,7 +108,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph) function_ref = std::make_shared(NodeVector{mul}, ParameterVector{}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { @@ -133,7 +133,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { function_ref = std::make_shared(NodeVector{mul}, ParameterVector{}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) { @@ -159,7 +159,7 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimiz function_ref = std::make_shared(NodeVector{mul}, ParameterVector{}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) { @@ -175,5 +175,5 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) { manager.register_pass(); comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); - enable_accuracy_check(); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } diff --git a/src/tests/ie_test_utils/common_test_utils/graph_comparator.cpp b/src/tests/ie_test_utils/common_test_utils/graph_comparator.cpp index 76a72195110..3469cbe37ee 100644 --- a/src/tests/ie_test_utils/common_test_utils/graph_comparator.cpp +++ b/src/tests/ie_test_utils/common_test_utils/graph_comparator.cpp @@ -2,7 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "graph_comparator.hpp" +#include "common_test_utils/graph_comparator.hpp" + +#include +#include "ie_common.h" #include #include @@ -20,6 +23,10 @@ #include #include +#include "common_test_utils/ov_tensor_utils.hpp" +#include "ngraph_functions/utils/ngraph_helpers.hpp" + + namespace { inline namespace tools { bool is_type_relaxed(const std::string& type) { @@ -31,12 +38,12 @@ bool compare_type_info(const ngraph::DiscreteTypeInfo& info1, const ngraph::Disc if (!is_type_relaxed(info1.name) && !is_type_relaxed(info2.name) && (info1.version != info2.version)) { return false; } - OPENVINO_SUPPRESS_DEPRECATED_END + OPENVINO_SUPPRESS_DEPRECATED_END const std::string info1Name = - is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name; + is_type_relaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name; const std::string info2Name = - is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name; + is_type_relaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name; return info1Name == info2Name; } @@ -121,9 +128,9 @@ public: explicit NodeAndInputDescription(const InputNode& input, const Parameter* parameter, const InputDescripton* description) - : m_input(input), - m_parameter(not_null(parameter)), - m_description(not_null(description)) {} + : m_input(input), + m_parameter(not_null(parameter)), + m_description(not_null(description)) {} static bool equal_descriptions(const InputDescripton* lhs, const InputDescripton* rhs) { if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) { @@ -172,7 +179,7 @@ public: } for (size_t i = 0; i != param_shape.size(); ++i) { const auto expected_axis_size = - i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i]; + i == slice_description->m_axis ? slice_description->m_part_size * num_iterations : param_shape[i]; if (input_shape[i] != expected_axis_size) { return false; } @@ -215,9 +222,9 @@ public: explicit NodeAndOutputDescription(const OutputNode& output, const Result* result, const OutputDescription* description) - : m_output(output), - m_result(not_null(result)), - m_description(not_null(description)) {} + : m_output(output), + m_result(not_null(result)), + m_description(not_null(description)) {} static bool equal_descriptions(const OutputDescription* lhs, const OutputDescription* rhs) { if (!lhs || !rhs || lhs->get_type_info() != rhs->get_type_info()) { @@ -302,8 +309,8 @@ public: using Id = uint64_t; explicit BackEdge(const Parameter* parameter, const Result* result) - : m_parameter(not_null(parameter)), - m_result(not_null(result)) {} + : m_parameter(not_null(parameter)), + m_result(not_null(result)) {} bool result_and_parameter_match() const { return equal_type_and_partial_shape(m_result->output(0), *m_parameter); @@ -539,103 +546,113 @@ Comparator::Result Comparator::compare(const std::shared_ptr& * + Check node attributes by Visitor API */ - auto f_results = f->get_results(); - auto f_ref_results = f_ref->get_results(); + if (should_compare(CmpValues::NODES)) { + auto f_results = f->get_results(); + auto f_ref_results = f_ref->get_results(); - auto cmp = less_by_name; - // In case if Result source output has more than one name so the Result may have any of this names as a friendly - // name An in case of multiple names we sort Result operation using their parent node names - if (std::any_of(f_results.begin(), - f_results.end(), - [](const std::shared_ptr& node) { - const auto& t = node->input_value(0).get_tensor_ptr(); - return t->get_names().size() > 1; - }) || - std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr& node) { - const auto& t = node->input_value(0).get_tensor_ptr(); - return t->get_names().size() > 1; - })) { - cmp = less_by_parent_name; - } + auto cmp = less_by_name; + // In case if Result source output has more than one name so the Result may have any of this names as a friendly + // name An in case of multiple names we sort Result operation using their parent node names + if (std::any_of(f_results.begin(), + f_results.end(), + [](const std::shared_ptr &node) { + const auto &t = node->input_value(0).get_tensor_ptr(); + return t->get_names().size() > 1; + }) || + std::any_of(f_ref_results.begin(), f_ref_results.end(), [](const std::shared_ptr &node) { + const auto &t = node->input_value(0).get_tensor_ptr(); + return t->get_names().size() > 1; + })) { + cmp = less_by_parent_name; + } - std::sort(f_results.begin(), f_results.end(), cmp); - std::sort(f_ref_results.begin(), f_ref_results.end(), cmp); + std::sort(f_results.begin(), f_results.end(), cmp); + std::sort(f_ref_results.begin(), f_ref_results.end(), cmp); - if (f_results.size() != f_ref_results.size()) { - return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " + - to_str(f_ref_results.size())); - } + if (f_results.size() != f_ref_results.size()) { + return Result::error("Number of results is different: " + to_str(f_results.size()) + " and " + + to_str(f_ref_results.size())); + } - const auto& f_sinks = f->get_sinks(); - const auto& f_ref_sinks = f_ref->get_sinks(); - if (f_sinks.size() != f_ref_sinks.size()) { - return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " + - to_str(f_ref_sinks.size())); - } + const auto &f_sinks = f->get_sinks(); + const auto &f_ref_sinks = f_ref->get_sinks(); + if (f_sinks.size() != f_ref_sinks.size()) { + return Result::error("Number of sinks is different: " + to_str(f_sinks.size()) + " and " + + to_str(f_ref_sinks.size())); + } - // Compare sinks - if (f_sinks.size() == 1) { - q.push({f_sinks[0].get(), f_ref_sinks[0].get()}); - used.insert(f_sinks[0].get()); - } else { - // Cast to Assign and find those that have same variable_id suffix - for (const auto& sink1 : f_sinks) { - auto assign1 = std::dynamic_pointer_cast(sink1); - if (!assign1) { - return Result::error("Sink '" + name(sink1) + - "' is not a variable - graph comparison is not supported"); - } - auto name1 = assign1->get_variable_id(); - std::shared_ptr found_sink2; - for (const auto& sink2 : f_ref_sinks) { - auto assign2 = std::dynamic_pointer_cast(sink2); - if (!assign2) { - return Result::error("Sink '" + name(sink2) + + // Compare sinks + if (f_sinks.size() == 1) { + q.push({f_sinks[0].get(), f_ref_sinks[0].get()}); + used.insert(f_sinks[0].get()); + } else { + // Cast to Assign and find those that have same variable_id suffix + for (const auto &sink1 : f_sinks) { + auto assign1 = std::dynamic_pointer_cast(sink1); + if (!assign1) { + return Result::error("Sink '" + name(sink1) + "' is not a variable - graph comparison is not supported"); } - auto name2 = assign2->get_variable_id(); - if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) { - found_sink2 = sink2; - break; + auto name1 = assign1->get_variable_id(); + std::shared_ptr found_sink2; + for (const auto &sink2 : f_ref_sinks) { + auto assign2 = std::dynamic_pointer_cast(sink2); + if (!assign2) { + return Result::error("Sink '" + name(sink2) + + "' is not a variable - graph comparison is not supported"); + } + auto name2 = assign2->get_variable_id(); + if (name2.find(name1) != std::string::npos || name1.find(name2) != std::string::npos) { + found_sink2 = sink2; + break; + } + } + if (!found_sink2) { + return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1); + } + q.push({sink1.get(), found_sink2.get()}); + used.insert(sink1.get()); + } + } + + for (size_t i = 0; i < f_results.size(); ++i) { + if (should_compare(CmpValues::NAMES)) { + if (name(f_results[i]->get_input_node_shared_ptr(0)) != + name(f_ref_results[i]->get_input_node_shared_ptr(0))) { + return Result::error( + "Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) + + " and " + + name(f_ref_results[i]->get_input_node_shared_ptr(0))); } } - if (!found_sink2) { - return Result::error("No suitable sink is found for: " + name(sink1) + ", var=" + name1); + q.push({f_results[i].get(), f_ref_results[i].get()}); + used.insert(f_results[i].get()); + } + + std::stringstream errors; + + while (!q.empty()) { + ngraph::Node *const node1 = q.front().first; + ngraph::Node *const node2 = q.front().second; + q.pop(); + + const auto result = compare(node1, node2, errors); + if (!result.valid) { + return result; } - q.push({sink1.get(), found_sink2.get()}); - used.insert(sink1.get()); + + add_nodes_inputs_to_queue(node1, node2); } + const auto msg = errors.str(); + return msg.empty() ? Result::ok() : Result::error(msg); + + } else if (should_compare(CmpValues::ACCURACY)) { + auto status = accuracy_check(f_ref, f); + return status.status ? Result::ok() : Result::error(status.message); + + } else { + return Result::error("CmpValues are invalid"); } - - for (size_t i = 0; i < f_results.size(); ++i) { - if (should_compare(CmpValues::NAMES)) { - if (name(f_results[i]->get_input_node_shared_ptr(0)) != - name(f_ref_results[i]->get_input_node_shared_ptr(0))) { - return Result::error( - "Different output node names: " + name(f_results[i]->get_input_node_shared_ptr(0)) + " and " + - name(f_ref_results[i]->get_input_node_shared_ptr(0))); - } - } - q.push({f_results[i].get(), f_ref_results[i].get()}); - used.insert(f_results[i].get()); - } - - std::stringstream errors; - - while (!q.empty()) { - ngraph::Node* const node1 = q.front().first; - ngraph::Node* const node2 = q.front().second; - q.pop(); - - const auto result = compare(node1, node2, errors); - if (!result.valid) { - return result; - } - - add_nodes_inputs_to_queue(node1, node2); - } - const auto msg = errors.str(); - return msg.empty() ? Result::ok() : Result::error(msg); } Comparator::Result Comparator::compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) { @@ -759,7 +776,7 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std:: void Comparator::compare_nodes(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) { if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(*node1, *node2, err_log)) { err_log << "Different runtime info detected\n" + name(node1) + " and " + name(node2) + - " not equal runtime info.\n"; + " not equal runtime info.\n"; } if (should_compare(CmpValues::ATTRIBUTES)) { @@ -805,123 +822,115 @@ void check_rt_info(const std::shared_ptr& f) { } } -void set_tensor_name(ngraph::Output output, const std::string& name) { - output.get_tensor_ptr()->set_names({name}); -} - -void set_tensor_names(ngraph::Output output, const std::unordered_set& names) { - output.get_tensor_ptr()->set_names(names); -} - namespace attributes { namespace detail { -void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) { - if (auto inputs = ngraph::as_type>(&adapter)) { - insert(name, inputs->get()); - } else if (auto outputs = ngraph::as_type>(&adapter)) { - insert(name, outputs->get()); - } else if (ngraph::is_type>(&adapter)) { - // drop comparison, no more info than port indexes which will be check in - // subgraph::compare_io - } else if (auto a = ngraph::as_type>>( - &adapter)) { - const auto beg = static_cast(a->get()->get_ptr()); - const auto end = beg + a->get()->size(); - insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)}); - } else if (auto framework_node_attr = - ngraph::as_type>(&adapter)) { - insert(name, framework_node_attr->get()); - } else if (auto variable_ptr = - ngraph::as_type>>(&adapter)) { - insert(name, variable_ptr->get()); - } else if (auto shape_ptr = ngraph::as_type>(&adapter)) { - insert(name, shape_ptr->get()); - } else if (auto dim_ptr = ngraph::as_type>(&adapter)) { - insert(name, dim_ptr->get()); - } else { - m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" + - adapter.get_type_info().name + "']"; + void ReadAndStoreAttributes::on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) { + if (auto inputs = ngraph::as_type>(&adapter)) { + insert(name, inputs->get()); + } else if (auto outputs = ngraph::as_type>(&adapter)) { + insert(name, outputs->get()); + } else if (ngraph::is_type>(&adapter)) { + // drop comparison, no more info than port indexes which will be check in + // subgraph::compare_io + } else if (auto a = ngraph::as_type>>( + &adapter)) { + const auto beg = static_cast(a->get()->get_ptr()); + const auto end = beg + a->get()->size(); + insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)}); + } else if (auto framework_node_attr = + ngraph::as_type>(&adapter)) { + insert(name, framework_node_attr->get()); + } else if (auto variable_ptr = + ngraph::as_type>>(&adapter)) { + insert(name, variable_ptr->get()); + } else if (auto shape_ptr = ngraph::as_type>(&adapter)) { + insert(name, shape_ptr->get()); + } else if (auto dim_ptr = ngraph::as_type>(&adapter)) { + insert(name, dim_ptr->get()); + } else { + m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" + + adapter.get_type_info().name + "']"; + } } -} -template -void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) { - if (should_return()) { - return; - } - m_visited_attributes.insert(name); - const auto ref_value = m_attr_ref.get(name); - if (!ref_value) { - m_cmp_result += "missing attribute name: '" + name + "'"; - return; + template + void ReadAndCompareAttributes::verify(const std::string& name, const AttrValue& attr_value) { + if (should_return()) { + return; + } + m_visited_attributes.insert(name); + const auto ref_value = m_attr_ref.get(name); + if (!ref_value) { + m_cmp_result += "missing attribute name: '" + name + "'"; + return; + } + + if (!equal::Equal::equal_value(*ref_value, attr_value)) { + m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get::value(*ref_value) + " vs " + + str::Get::value(attr_value); + } } - if (!equal::Equal::equal_value(*ref_value, attr_value)) { - m_cmp_result += "mismatch in value: '" + name + "' : " + str::Get::value(*ref_value) + " vs " + - str::Get::value(attr_value); - } -} + void ReadAndCompareAttributes::verify_mem_buf(const std::string& name, + const std::shared_ptr& buffer) { + if (should_return()) { + return; + } + m_visited_attributes.insert(name); + const auto ref_value = m_attr_ref.get(name); + if (!ref_value) { + m_cmp_result += "missing attribute name: '" + name + "'"; + return; + } -void ReadAndCompareAttributes::verify_mem_buf(const std::string& name, - const std::shared_ptr& buffer) { - if (should_return()) { - return; - } - m_visited_attributes.insert(name); - const auto ref_value = m_attr_ref.get(name); - if (!ref_value) { - m_cmp_result += "missing attribute name: '" + name + "'"; - return; + if (buffer->size() != ref_value->size() || + std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) { + m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer"; + return; + } } - if (buffer->size() != ref_value->size() || - std::memcmp(ref_value->data(), buffer->get_ptr(), ref_value->size()) != 0) { - m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer"; - return; + void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) { + if (should_return()) { + return; + } + m_visited_attributes.insert(name); + const auto ref_value = m_attr_ref.get>(name); + if (!ref_value) { + m_cmp_result += "missing attribute name: '" + name + "'"; + return; + } + Comparator c(m_check_flags); + const auto result = c.compare(*ref_value, adapter.get()); + if (!result.valid) { + m_cmp_result += result.message; + } } -} -void ReadAndCompareAttributes::verify_function(const std::string& name, ModelAccessor& adapter) { - if (should_return()) { - return; + void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor& adapter) { + if (auto inputs = ngraph::as_type>(&adapter)) { + verify(name, inputs->get()); + } else if (auto outputs = ngraph::as_type>(&adapter)) { + verify(name, outputs->get()); + } else if (ngraph::is_type>(&adapter)) { + // drop comparison, no more info than port indexes which will be check in + // subgraph::compare_io + } else if (auto a = ngraph::as_type>>( + &adapter)) { + verify_mem_buf(name, a->get()); + } else if (auto attrs = ngraph::as_type>(&adapter)) { + verify(name, attrs->get()); + } else if (auto variable_ptr = + ngraph::as_type>>(&adapter)) { + verify(name, variable_ptr->get()); + } else if (auto shape_ptr = ngraph::as_type>(&adapter)) { + verify(name, shape_ptr->get()); + } else if (auto dim_ptr = ngraph::as_type>(&adapter)) { + verify(name, dim_ptr->get()); + } else { + m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" + + adapter.get_type_info().name + "']"; + } } - m_visited_attributes.insert(name); - const auto ref_value = m_attr_ref.get>(name); - if (!ref_value) { - m_cmp_result += "missing attribute name: '" + name + "'"; - return; - } - Comparator c(m_check_flags); - const auto result = c.compare(*ref_value, adapter.get()); - if (!result.valid) { - m_cmp_result += result.message; - } -} - -void ReadAndCompareAttributes::verify_others(const std::string& name, ngraph::ValueAccessor& adapter) { - if (auto inputs = ngraph::as_type>(&adapter)) { - verify(name, inputs->get()); - } else if (auto outputs = ngraph::as_type>(&adapter)) { - verify(name, outputs->get()); - } else if (ngraph::is_type>(&adapter)) { - // drop comparison, no more info than port indexes which will be check in - // subgraph::compare_io - } else if (auto a = ngraph::as_type>>( - &adapter)) { - verify_mem_buf(name, a->get()); - } else if (auto attrs = ngraph::as_type>(&adapter)) { - verify(name, attrs->get()); - } else if (auto variable_ptr = - ngraph::as_type>>(&adapter)) { - verify(name, variable_ptr->get()); - } else if (auto shape_ptr = ngraph::as_type>(&adapter)) { - verify(name, shape_ptr->get()); - } else if (auto dim_ptr = ngraph::as_type>(&adapter)) { - verify(name, dim_ptr->get()); - } else { - m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" + - adapter.get_type_info().name + "']"; - } -} } // namespace detail @@ -937,3 +946,42 @@ Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator: } } // namespace attributes + +AccuracyCheckResult accuracy_check(const std::shared_ptr& ref_function, + const std::shared_ptr& cur_function) { + if (ref_function->is_dynamic() || cur_function->is_dynamic()) { + return AccuracyCheckResult{true, ""}; + } + try { + IE_ASSERT(ref_function->get_parameters().size() == cur_function->get_parameters().size()); + + std::map, ov::Tensor> ref_input_data; + std::map, ov::Tensor> cur_input_data; + for (size_t i = 0; i < ref_function->get_parameters().size(); i++) { + const auto &tensor = ov::test::utils::create_and_fill_tensor( + ref_function->get_parameters()[i]->get_element_type(), + ref_function->get_parameters()[i]->get_shape()); + ref_input_data[ref_function->get_parameters()[i]] = tensor; + cur_input_data[cur_function->get_parameters()[i]] = tensor; + } + + auto ref_outputs = ngraph::helpers::interpretFunction(ref_function, ref_input_data); + auto outputs = ngraph::helpers::interpretFunction(cur_function, cur_input_data); + + IE_ASSERT(ref_outputs.size() == outputs.size()); + + for (int i = 0; i < ref_outputs.size(); i++) { + ov::test::utils::compare(ref_outputs[i], outputs[i], + std::numeric_limits::max(), + std::numeric_limits::max()); + } + } + catch (const std::runtime_error &re) { + return AccuracyCheckResult{false, re.what()}; + } catch (const std::exception &ex) { + return AccuracyCheckResult{false, ex.what()}; + } catch (...) { + return AccuracyCheckResult{false, ""}; + } + return AccuracyCheckResult{true, ""}; +} diff --git a/src/tests/ie_test_utils/common_test_utils/graph_comparator.hpp b/src/tests/ie_test_utils/common_test_utils/graph_comparator.hpp index 9cdfe1bf3af..c96753ba65a 100644 --- a/src/tests/ie_test_utils/common_test_utils/graph_comparator.hpp +++ b/src/tests/ie_test_utils/common_test_utils/graph_comparator.hpp @@ -20,12 +20,14 @@ class FunctionsComparator { public: enum CmpValues { NONE = 0, - CONST_VALUES = 1 << 0, - NAMES = 1 << 1, - RUNTIME_KEYS = 1 << 2, - PRECISIONS = 1 << 3, - ATTRIBUTES = 1 << 4, - TENSOR_NAMES = 1 << 5, + NODES = 1 << 0, + CONST_VALUES = 1 << 1, + NAMES = 1 << 2, + RUNTIME_KEYS = 1 << 3, + PRECISIONS = 1 << 4, + ATTRIBUTES = 1 << 5, + TENSOR_NAMES = 1 << 6, + ACCURACY = 1 << 7 }; struct Result { @@ -45,6 +47,7 @@ public: } static FunctionsComparator with_default() noexcept { auto fc = FunctionsComparator::no_default(); + fc.enable(NODES); fc.enable(PRECISIONS); fc.enable(TENSOR_NAMES); return fc; @@ -55,6 +58,11 @@ public: return *this; } + FunctionsComparator& disable(CmpValues f) noexcept { + m_comparison_flags = static_cast(m_comparison_flags & ~f); + return *this; + } + bool should_compare(CmpValues f) const noexcept { return m_comparison_flags & f; } @@ -85,6 +93,7 @@ inline std::pair compare_functions(const std::shared_ptr compare_functions(const std::shared_ptr& f); -void set_tensor_name(ngraph::Output output, const std::string& name); - -void set_tensor_names(ngraph::Output output, const std::unordered_set& names); - namespace ngraph { namespace pass { class InjectionPass : public ov::pass::ModelPass { @@ -188,8 +193,8 @@ public: for (auto output : node->outputs()) { const auto& tensor_names = output.get_names(); if (std::any_of(tensor_names.begin(), tensor_names.end(), [&](const std::string& name) { - return unique_tensor_names.count(name); - })) { + return unique_tensor_names.count(name); + })) { std::stringstream ss; ss << "Node: " << node->get_type_info() << " with name " << node->get_friendly_name() << " "; ss << "has non unique tensor name."; @@ -478,9 +483,9 @@ public: void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override; #define ON_ADAPTER(TYPE) \ - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ - insert(name, adapter.get()); \ - } +void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ +insert(name, adapter.get()); \ +} ON_ADAPTER(bool) ON_ADAPTER(std::string) @@ -711,37 +716,37 @@ struct Equal> { } switch (lhs_t) { - case ngraph::element::Type_t::u1: { - const auto lhs_v = static_cast(lhs->get_data_ptr()); - const auto rhs_v = static_cast(rhs->get_data_ptr()); - const auto lhs_bit_size = shape_size(lhs->get_shape()); - const auto rhs_bit_size = shape_size(rhs->get_shape()); - return Equal::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size); - } - case ngraph::element::Type_t::bf16: { - auto lhs_v = lhs->cast_vector(); - auto rhs_v = rhs->cast_vector(); - return Equal>::equal_value(lhs_v, rhs_v); - break; - } - case ngraph::element::Type_t::f16: { - const auto& lhs_v = lhs->cast_vector(); - const auto& rhs_v = rhs->cast_vector(); - return Equal>::equal_value(lhs_v, rhs_v); - break; - } - case ngraph::element::Type_t::f32: { - const auto& lhs_v = lhs->cast_vector(); - const auto& rhs_v = rhs->cast_vector(); - return Equal>::equal_value(lhs_v, rhs_v); - break; - } - default: { - const auto& lhs_v = lhs->cast_vector(); - const auto& rhs_v = rhs->cast_vector(); - return Equal>::equal_value(lhs_v, rhs_v); - break; - } + case ngraph::element::Type_t::u1: { + const auto lhs_v = static_cast(lhs->get_data_ptr()); + const auto rhs_v = static_cast(rhs->get_data_ptr()); + const auto lhs_bit_size = shape_size(lhs->get_shape()); + const auto rhs_bit_size = shape_size(rhs->get_shape()); + return Equal::equal_value(lhs_v, rhs_v, lhs_bit_size, rhs_bit_size); + } + case ngraph::element::Type_t::bf16: { + auto lhs_v = lhs->cast_vector(); + auto rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + case ngraph::element::Type_t::f16: { + const auto& lhs_v = lhs->cast_vector(); + const auto& rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + case ngraph::element::Type_t::f32: { + const auto& lhs_v = lhs->cast_vector(); + const auto& rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } + default: { + const auto& lhs_v = lhs->cast_vector(); + const auto& rhs_v = rhs->cast_vector(); + return Equal>::equal_value(lhs_v, rhs_v); + break; + } } return false; } @@ -862,18 +867,18 @@ struct Get, void> { class ReadAndCompareAttributes : public ngraph::AttributeVisitor { public: ReadAndCompareAttributes(const ReadAndStoreAttributes& ref, Comparator::CmpValues check_flags) - : m_attr_ref(ref), - m_cmp_result{ref.read_result()}, - m_check_flags(check_flags) {} + : m_attr_ref(ref), + m_cmp_result{ref.read_result()}, + m_check_flags(check_flags) {} void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { verify_others(name, adapter); } #define ON_ADAPTER(TYPE) \ - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ - verify(name, adapter.get()); \ - } +void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ +verify(name, adapter.get()); \ +} ON_ADAPTER(bool) ON_ADAPTER(std::string) @@ -981,3 +986,11 @@ private: Comparator::Result compare(ngraph::Node* node1, ngraph::Node* node2, Comparator::CmpValues comparition_flags); } // namespace attributes + +struct AccuracyCheckResult { + bool status; + std::string message; +}; + +AccuracyCheckResult accuracy_check(const std::shared_ptr& ref_function, + const std::shared_ptr& cur_function); diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index 28e7959f65d..b2b2fea8f08 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -2,248 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include #include "ngraph_test_utils.hpp" -#include - -namespace { - -template -static void Compare(const T_NGRAPH *expected, const T_IE *actual, std::size_t size, float threshold, float abs_threshold = -1.f) { - for (std::size_t i = 0; i < size; ++i) { - const T_NGRAPH &ref = expected[i]; - const auto &res = actual[i]; - const auto absoluteDifference = CommonTestUtils::ie_abs(res - ref); - if (abs_threshold > 0.f && absoluteDifference > abs_threshold) { - IE_THROW() << "Absolute comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res) - << " at index " << i << " with absolute threshold " << abs_threshold - << " failed"; - } - if (absoluteDifference <= threshold) { - continue; - } - double max; - if (sizeof(T_IE) < sizeof(T_NGRAPH)) { - max = std::max(CommonTestUtils::ie_abs(T_NGRAPH(res)), CommonTestUtils::ie_abs(ref)); - } else { - max = std::max(CommonTestUtils::ie_abs(res), CommonTestUtils::ie_abs(T_IE(ref))); - } - double diff = static_cast(absoluteDifference) / max; - if (max == 0 || (diff > static_cast(threshold)) || - (std::isnan(static_cast(res)) ^ std::isnan(static_cast(ref)))) { - IE_THROW() << "Relative comparison of values expected: " << std::to_string(ref) << " and actual: " << std::to_string(res) - << " at index " << i << " with threshold " << threshold - << " failed"; - } - } -} - -template -void callCompare(const std::pair> &expected, - const T_IE* actualBuffer, size_t size, float threshold, float abs_threshold) { - auto expectedBuffer = expected.second.data(); - switch (expected.first) { - case ngraph::element::Type_t::i64: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::i32: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::i16: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::i8: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::u64: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::u32: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::u16: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::boolean: - case ngraph::element::Type_t::u8: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::f64: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::f32: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::f16: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; - case ngraph::element::Type_t::bf16: - Compare(reinterpret_cast(expectedBuffer), - actualBuffer, size, threshold, abs_threshold); - break; -// case ngraph::element::Type_t::i4: { -// auto expectedOut = ngraph::helpers::convertOutputPrecision( -// expected.second, -// expected.first, -// ngraph::element::Type_t::i8, -// size); -// Compare(reinterpret_cast(expectedOut.data()), -// actualBuffer, size, threshold, abs_threshold); -// break; -// } -// case ngraph::element::Type_t::u4: { -// auto expectedOut = ngraph::helpers::convertOutputPrecision( -// expected.second, -// expected.first, -// ngraph::element::Type_t::u8, -// size); -// Compare(reinterpret_cast(expectedOut.data()), -// actualBuffer, size, threshold, abs_threshold); -// break; -// } - case ngraph::element::Type_t::dynamic: - case ngraph::element::Type_t::undefined: - Compare(reinterpret_cast(expectedBuffer), actualBuffer, size, threshold, abs_threshold); - break; - default: FAIL() << "Comparator for " << expected.first << " precision isn't supported"; - } - return; -} - - -void Compare(const std::pair> &expected, - const InferenceEngine::Blob::Ptr &actual, - float threshold, - float abs_threshold) { - const auto &precision = actual->getTensorDesc().getPrecision(); - auto k = static_cast(expected.first.size()) / precision.size(); - // W/A for int4, uint4 - if (expected.first == ngraph::element::Type_t::u4 || expected.first == ngraph::element::Type_t::i4) { - k /= 2; - } else if (expected.first == ngraph::element::Type_t::undefined || expected.first == ngraph::element::Type_t::dynamic) { - k = 1; - } - ASSERT_EQ(expected.second.size(), actual->byteSize() * k); - - auto memory = InferenceEngine::as(actual); - IE_ASSERT(memory); - const auto lockedMemory = memory->wmap(); - const auto actualBuffer = lockedMemory.as(); - - const auto &size = actual->size(); - switch (precision) { - case InferenceEngine::Precision::FP32: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::I32: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::I64: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::I8: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::U16: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::I16: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::BOOL: - case InferenceEngine::Precision::U8: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::U64: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::BF16: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - case InferenceEngine::Precision::FP16: - callCompare(expected, reinterpret_cast(actualBuffer), size, threshold, abs_threshold); - break; - default: - FAIL() << "Comparator for " << precision << " precision isn't supported"; - } -} - -void Compare(const std::vector>> &expectedOutputs, - const std::vector &actualOutputs, - float threshold, float abs_threshold) { - for (std::size_t outputIndex = 0; outputIndex < expectedOutputs.size(); ++outputIndex) { - const auto &expected = expectedOutputs[outputIndex]; - const auto &actual = actualOutputs[outputIndex]; - Compare(expected, actual, threshold, abs_threshold); - } -} -} // namespace - -void TransformationTestsF::accuracy_check(std::shared_ptr ref_function, - std::shared_ptr cur_function) { - try { - if (ref_function->is_dynamic() || cur_function->is_dynamic()) { - return; - } - std::vector> input_data; - ngraph::element::TypeVector types; - for (auto param : ref_function->get_parameters()) { - types.push_back(param->get_element_type()); - - auto layout = InferenceEngine::Layout::ANY; - if (ov::is_scalar(param->get_shape())) { - layout = InferenceEngine::Layout::SCALAR; - } - InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout); - const auto &input = FuncTestUtils::createAndFillBlob(td); - const auto &input_size = input->byteSize(); - - std::vector data; - data.resize(input_size); - - auto memory = InferenceEngine::as(input); - IE_ASSERT(memory); - - const auto lockedMemory = memory->wmap(); - const auto buffer = lockedMemory.as(); - std::copy(buffer, buffer + input_size, data.data()); - - input_data.push_back(std::move(data)); - } - - auto ref_outputs = ngraph::helpers::interpreterFunction(ref_function, input_data, types); - auto outputs = ngraph::helpers::interpreterFunction(cur_function, input_data, types); - - IE_ASSERT(ref_outputs.size() == outputs.size()); - - for (size_t i = 0; i < ref_outputs.size(); ++i) { - IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size()); - auto * ref = reinterpret_cast(ref_outputs[i].second.data()); - auto * out = reinterpret_cast(outputs[i].second.data()); - size_t size = ref_outputs[i].second.size() / sizeof(float); - IE_ASSERT(size > 0); - Compare(ref, out, size, 1e-5); - } - } - catch (const std::runtime_error &re) { - GTEST_FATAL_FAILURE_(re.what()); - } catch (const std::exception &ex) { - GTEST_FATAL_FAILURE_(ex.what()); - } catch (...) { - GTEST_FATAL_FAILURE_("Unknown failure occurred."); - } -} void init_unique_names(std::shared_ptr f, const std::shared_ptr& unh) { ngraph::pass::Manager manager; diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp index 549ada20f7f..c79898e1930 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp @@ -32,6 +32,7 @@ class TransformationTestsF : public CommonTestUtils::TestsCommon { public: TransformationTestsF() : comparator(FunctionsComparator::no_default()) { m_unh = std::make_shared(); + comparator.enable(FunctionsComparator::CmpValues::NODES); comparator.enable(FunctionsComparator::CmpValues::PRECISIONS); comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); // TODO: enable attributes and constant values comparison by default XXX-68694 @@ -56,12 +57,16 @@ public: ASSERT_NO_THROW(check_rt_info(function)); } + if (comparator.should_compare(FunctionsComparator::ACCURACY)) { + auto acc_comparator = FunctionsComparator::no_default(); + acc_comparator.enable(FunctionsComparator::CmpValues::ACCURACY); + auto res = acc_comparator.compare(function, cloned_function); + ASSERT_TRUE(res.valid) << res.message; + comparator.disable(FunctionsComparator::CmpValues::ACCURACY); + } + auto res = comparator.compare(function, function_ref); ASSERT_TRUE(res.valid) << res.message; - - if (m_enable_accuracy_check) { - accuracy_check(cloned_function, function); - } } // TODO: this is temporary solution to disable rt info checks that must be applied by default @@ -74,12 +79,6 @@ public: m_soft_names_comparison = true; } - void enable_accuracy_check() { - m_enable_accuracy_check = true; - } - - void accuracy_check(std::shared_ptr ref_function, std::shared_ptr cur_function); - std::shared_ptr function, function_ref; ngraph::pass::Manager manager; FunctionsComparator comparator; @@ -88,7 +87,6 @@ private: std::shared_ptr m_unh; bool m_disable_rt_info_check{false}; bool m_soft_names_comparison{true}; - bool m_enable_accuracy_check{false}; }; void init_unique_names(std::shared_ptr f, const std::shared_ptr& unh);