Floats comparison implementation in IE_Engines (#4342)
This commit is contained in:
@@ -1583,15 +1583,10 @@ IE_GPU.onnx_model_gather_elements_int32_axis_0
|
||||
IE_GPU.onnx_model_gather_elements_int8_axis_1
|
||||
IE_GPU.onnx_model_gather_elements_float_3D_axis_2
|
||||
|
||||
IE_CPU.evaluate_ctc_greedy_decoder_seq_len
|
||||
IE_GPU.evaluate_ctc_greedy_decoder_seq_len
|
||||
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_f16
|
||||
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_f16
|
||||
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_merge
|
||||
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_merge
|
||||
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
|
||||
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches
|
||||
IE_CPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
|
||||
IE_GPU.evaluate_ctc_greedy_decoder_seq_len_multiple_batches2
|
||||
|
||||
# incorrect result for Minimum if u16 type is used
|
||||
|
||||
@@ -22,6 +22,7 @@ set (SRC
|
||||
all_close_f.cpp
|
||||
engine/ie_engines.cpp
|
||||
engine/interpreter_engine.cpp
|
||||
engine/shared_utils.cpp
|
||||
float_util.cpp
|
||||
test_tools.cpp
|
||||
test_control.cpp
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "pass/opset1_upgrade.hpp"
|
||||
#include "shared_utils.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@@ -121,7 +122,7 @@ namespace
|
||||
default: THROW_IE_EXCEPTION << "Not implemented yet";
|
||||
}
|
||||
}
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
namespace
|
||||
{
|
||||
@@ -156,7 +157,7 @@ namespace
|
||||
#endif
|
||||
throw std::runtime_error("unsupported type");
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
test::IE_Engine::IE_Engine(const std::shared_ptr<Function> function, const char* device)
|
||||
: m_function{function}
|
||||
@@ -244,6 +245,36 @@ testing::AssertionResult
|
||||
test::IE_Engine::compare_results_with_tolerance_as_fp(const float tolerance)
|
||||
{
|
||||
auto comparison_result = testing::AssertionSuccess();
|
||||
|
||||
for (const auto& output : m_network_outputs)
|
||||
{
|
||||
if (comparison_result == testing::AssertionFailure())
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
InferenceEngine::MemoryBlob::CPtr computed_output_blob =
|
||||
InferenceEngine::as<InferenceEngine::MemoryBlob>(m_inference_req.GetBlob(output.first));
|
||||
|
||||
const auto& expected_output_blob = m_expected_outputs[output.first];
|
||||
|
||||
switch (expected_output_blob->getTensorDesc().getPrecision())
|
||||
{
|
||||
case InferenceEngine::Precision::FP32:
|
||||
{
|
||||
const auto test_results =
|
||||
extract_test_results<float>(computed_output_blob, expected_output_blob);
|
||||
comparison_result =
|
||||
test::compare_with_tolerance(test_results.first, test_results.second, tolerance);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
comparison_result = testing::AssertionFailure()
|
||||
<< "Unsupported data type encountered in "
|
||||
"'compare_results_with_tolerance_as_fp' method";
|
||||
}
|
||||
}
|
||||
|
||||
return comparison_result;
|
||||
}
|
||||
|
||||
@@ -308,4 +339,4 @@ namespace InferenceEngine
|
||||
template class TBlob<ngraph::float16>;
|
||||
template class TBlob<char>;
|
||||
#endif
|
||||
}
|
||||
} // namespace InferenceEngine
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "interpreter_engine.hpp"
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
#include "interpreter_engine.hpp"
|
||||
#include "shared_utils.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace
|
||||
@@ -45,32 +47,7 @@ namespace
|
||||
const auto expected = expected_results->get_vector<float>();
|
||||
const auto result = read_vector<float>(results);
|
||||
|
||||
Shape out_shape = expected_results->get_shape();
|
||||
|
||||
size_t num_of_elems = shape_size(out_shape);
|
||||
std::stringstream msg;
|
||||
|
||||
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
|
||||
|
||||
bool rc = true;
|
||||
|
||||
for (std::size_t j = 0; j < num_of_elems; ++j)
|
||||
{
|
||||
float diff = std::abs(result[j] - expected[j]);
|
||||
if (diff > tolerance)
|
||||
{
|
||||
msg << expected[j] << " is not close to " << result[j] << " at index " << j << "\n";
|
||||
rc = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rc)
|
||||
{
|
||||
comparison_result = testing::AssertionFailure();
|
||||
}
|
||||
|
||||
comparison_result << msg.str();
|
||||
return comparison_result;
|
||||
return ngraph::test::compare_with_tolerance(expected, result, tolerance);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
50
ngraph/test/util/engine/shared_utils.cpp
Normal file
50
ngraph/test/util/engine/shared_utils.cpp
Normal file
@@ -0,0 +1,50 @@
|
||||
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <cmath>
|
||||
#include <sstream>
|
||||
|
||||
#include "shared_utils.hpp"
|
||||
|
||||
testing::AssertionResult ngraph::test::compare_with_tolerance(const std::vector<float>& expected,
|
||||
const std::vector<float>& results,
|
||||
const float tolerance)
|
||||
{
|
||||
auto comparison_result = testing::AssertionSuccess();
|
||||
|
||||
std::stringstream msg;
|
||||
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1);
|
||||
|
||||
bool rc = true;
|
||||
|
||||
for (std::size_t j = 0; j < expected.size(); ++j)
|
||||
{
|
||||
float diff = std::fabs(results[j] - expected[j]);
|
||||
if (diff > tolerance)
|
||||
{
|
||||
msg << expected[j] << " is not close to " << results[j] << " at index " << j << "\n";
|
||||
rc = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!rc)
|
||||
{
|
||||
comparison_result = testing::AssertionFailure();
|
||||
comparison_result << msg.str();
|
||||
}
|
||||
|
||||
return comparison_result;
|
||||
}
|
||||
30
ngraph/test/util/engine/shared_utils.hpp
Normal file
30
ngraph/test/util/engine/shared_utils.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace test
|
||||
{
|
||||
testing::AssertionResult compare_with_tolerance(const std::vector<float>& expected_results,
|
||||
const std::vector<float>& results,
|
||||
const float tolerance);
|
||||
}
|
||||
} // namespace ngraph
|
||||
Reference in New Issue
Block a user