[CPU] New operations: IsFinite, IsInf, IsNaN. (#14314)

This commit is contained in:
Nikolay Shchegolev 2022-12-01 16:34:39 +04:00 committed by GitHub
parent 64391cdb3f
commit add3b11880
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 307 additions and 41 deletions

View File

@ -34,8 +34,6 @@ xfail_issue_33488 = xfail_test(reason="RuntimeError: OV does not support the fol
"MaxUnpool")
skip_issue_38084 = pytest.mark.skip(reason="Aborted (core dumped) Assertion "
"`(layer->get_output_partial_shape(i).is_static())' failed.")
xfail_issue_33589 = xfail_test(reason="OV does not support the following ONNX operations: "
"IsNaN and isInf")
xfail_issue_33595 = xfail_test(reason="RuntimeError: OV does not support the following ONNX operations: "
"Unique")
xfail_issue_33596 = xfail_test(reason="RuntimeError: OV does not support different sequence operations: "

View File

@ -10,7 +10,6 @@ from tests import (
skip_rng_tests,
xfail_issue_33488,
xfail_issue_33581,
xfail_issue_33589,
xfail_issue_33595,
xfail_issue_33596,
xfail_issue_33606,
@ -223,13 +222,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_maxunpool_export_with_output_shape_cpu",
"OnnxBackendNodeModelTest.test_maxunpool_export_without_output_shape_cpu",
),
(
xfail_issue_33589,
"OnnxBackendNodeModelTest.test_isnan_cpu",
"OnnxBackendNodeModelTest.test_isinf_positive_cpu",
"OnnxBackendNodeModelTest.test_isinf_negative_cpu",
"OnnxBackendNodeModelTest.test_isinf_cpu",
),
(xfail_issue_38724, "OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_cpu"),
(
xfail_issue_33606,

View File

@ -31,8 +31,6 @@ xfail_issue_33488 = xfail_test(reason="RuntimeError: nGraph does not support the
"MaxUnpool")
skip_issue_38084 = pytest.mark.skip(reason="Aborted (core dumped) Assertion "
"`(layer->get_output_partial_shape(i).is_static())' failed.")
xfail_issue_33589 = xfail_test(reason="nGraph does not support the following ONNX operations: "
"IsNaN and isInf")
xfail_issue_33595 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
"Unique")
xfail_issue_33596 = xfail_test(reason="RuntimeError: nGraph does not support different sequence operations: "

View File

@ -10,7 +10,6 @@ from tests_compatibility import (
xfail_unsupported_by_legacy_api,
xfail_issue_33488,
xfail_issue_33581,
xfail_issue_33589,
xfail_issue_33595,
xfail_issue_33596,
xfail_issue_33606,
@ -224,13 +223,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_maxunpool_export_with_output_shape_cpu",
"OnnxBackendNodeModelTest.test_maxunpool_export_without_output_shape_cpu",
),
(
xfail_issue_33589,
"OnnxBackendNodeModelTest.test_isnan_cpu",
"OnnxBackendNodeModelTest.test_isinf_positive_cpu",
"OnnxBackendNodeModelTest.test_isinf_negative_cpu",
"OnnxBackendNodeModelTest.test_isinf_cpu",
),
(xfail_issue_38724, "OnnxBackendNodeModelTest.test_resize_tf_crop_and_resize_cpu"),
(
xfail_issue_33606,

View File

@ -425,10 +425,3 @@ IE_CPU.onnx_softmax_crossentropy_loss_mean
# Cannot find blob with name: Y
IE_CPU.onnx_bool_init_and
IE_CPU.onnx_is_finite
IE_CPU.onnx_is_inf_default
IE_CPU.onnx_is_inf_negative_only
IE_CPU.onnx_is_inf_positive_only
IE_CPU.onnx_is_inf_detect_none
IE_CPU.onnx_is_nan

View File

@ -26,6 +26,9 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "AdaptiveMaxPool", Type::AdaptivePooling},
{ "AdaptiveAvgPool", Type::AdaptivePooling},
{ "Add", Type::Eltwise },
{ "IsFinite", Type::Eltwise },
{ "IsInf", Type::Eltwise },
{ "IsNaN", Type::Eltwise },
{ "Subtract", Type::Eltwise },
{ "Multiply", Type::Eltwise },
{ "Divide", Type::Eltwise },
@ -416,6 +419,9 @@ std::string algToString(const Algorithm alg) {
CASE(DeconvolutionCommon);
CASE(DeconvolutionGrouped);
CASE(EltwiseAdd);
CASE(EltwiseIsFinite);
CASE(EltwiseIsInf);
CASE(EltwiseIsNaN);
CASE(EltwiseMultiply);
CASE(EltwiseSubtract);
CASE(EltwiseDivide);

View File

@ -134,6 +134,9 @@ enum class Algorithm {
// Elementwise algorithms
EltwiseAdd,
EltwiseIsFinite,
EltwiseIsInf,
EltwiseIsNaN,
EltwiseMultiply,
EltwiseSubtract,
EltwiseDivide,

View File

@ -940,6 +940,18 @@ const std::map<const ngraph::DiscreteTypeInfo, Eltwise::Initializer> Eltwise::in
{ngraph::op::v1::NotEqual::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseNotEqual;
}},
{ov::op::v10::IsFinite::get_type_info_static(), [](const std::shared_ptr<ov::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseIsFinite;
}},
{ov::op::v10::IsInf::get_type_info_static(), [](const std::shared_ptr<ov::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseIsInf;
const auto& attributes = ov::as_type_ptr<ov::op::v10::IsInf>(op)->get_attributes();
node.alpha = attributes.detect_negative;
node.beta = attributes.detect_positive;
}},
{ov::op::v10::IsNaN::get_type_info_static(), [](const std::shared_ptr<ov::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseIsNaN;
}},
{ngraph::op::v1::Greater::get_type_info_static(), [](const std::shared_ptr<ngraph::Node>& op, Eltwise& node) {
node.algorithm = Algorithm::EltwiseGreater;
}},
@ -1562,6 +1574,12 @@ public:
case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : src_f[0] * src_f[1]; break;
case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break;
case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break;
case Algorithm::EltwiseIsFinite: *dst_ptr_f = std::isfinite(src_f[0]); break;
case Algorithm::EltwiseIsInf:
*dst_ptr_f = _opData.alpha && (src_f[0] == -std::numeric_limits<float>::infinity()) ||
_opData.beta && (src_f[0] == std::numeric_limits<float>::infinity());
break;
case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break;
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
}
}
@ -1646,6 +1664,9 @@ Eltwise::Eltwise(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& en
size_t Eltwise::getOpInputsNum() const {
switch (getAlgorithm()) {
case Algorithm::EltwiseIsFinite:
case Algorithm::EltwiseIsInf:
case Algorithm::EltwiseIsNaN:
case Algorithm::EltwiseRelu:
case Algorithm::EltwiseGelu:
case Algorithm::EltwiseElu:
@ -1729,6 +1750,8 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
// if dim rank is greater than the maximum possible, we should use the reference execution
canUseOptimizedImpl = mayiuse(x64::sse41) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK;
// 98206 to add JIT implementation.
canUseOptimizedImpl &= !one_of(getAlgorithm(), Algorithm::EltwiseIsFinite, Algorithm::EltwiseIsInf, Algorithm::EltwiseIsNaN);
if (!canUseOptimizedImpl && !fusedWith.empty()) {
IE_THROW(Unexpected) << "Eltwise node with name '" << getName() << "' uses reference impl, but unexpectedly fused with other ops";

View File

@ -55,4 +55,22 @@ const auto ComparisonTestParams = ::testing::Combine(
INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs, ComparisonLayerTest, ComparisonTestParams, ComparisonLayerTest::getTestCaseName);
std::vector<ngraph::helpers::ComparisonTypes> comparisonOpTypesIs = {
ngraph::helpers::ComparisonTypes::IS_FINITE,
ngraph::helpers::ComparisonTypes::IS_INF,
ngraph::helpers::ComparisonTypes::IS_NAN
};
const auto ComparisonTestParamsIs = ::testing::Combine(
::testing::ValuesIn(CommonTestUtils::combineParams(inputShapes)),
::testing::Values(InferenceEngine::Precision::FP32),
::testing::ValuesIn(comparisonOpTypesIs),
::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(additional_config));
INSTANTIATE_TEST_SUITE_P(smoke_IsOp, ComparisonLayerTest, ComparisonTestParamsIs, ComparisonLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,73 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/is_inf.hpp"
using namespace ov::test;
using namespace ov::test::subgraph;
namespace {
std::vector<std::vector<InputShape>> inShapesStatic = {
{ {{}, {{2}}} },
{ {{}, {{2, 200}}} },
{ {{}, {{10, 200}}} },
{ {{}, {{1, 10, 100}}} },
{ {{}, {{4, 4, 16}}} },
{ {{}, {{1, 1, 1, 3}}} },
{ {{}, {{2, 17, 5, 4}}} },
{ {{}, {{2, 17, 5, 1}}} },
{ {{}, {{1, 2, 4}}} },
{ {{}, {{1, 4, 4}}} },
{ {{}, {{1, 4, 4, 1}}} },
{ {{}, {{16, 16, 16, 16, 16}}} },
{ {{}, {{16, 16, 16, 16, 1}}} },
{ {{}, {{16, 16, 16, 1, 16}}} },
{ {{}, {{16, 32, 1, 1, 1}}} },
{ {{}, {{1, 1, 1, 1, 1, 1, 3}}} },
{ {{}, {{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}} }
};
std::vector<std::vector<InputShape>> inShapesDynamic = {
{{{ngraph::Dimension(1, 10), 200}, {{2, 200}, {1, 200}}}}
};
std::vector<ElementType> netPrecisions = {
ov::element::f32
};
std::vector<bool> detectNegative = {
true, false
};
std::vector<bool> detectPositive = {
true, false
};
std::map<std::string, std::string> additional_config = {};
const auto isInfParams = ::testing::Combine(
::testing::ValuesIn(inShapesStatic),
::testing::ValuesIn(detectNegative),
::testing::ValuesIn(detectPositive),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(additional_config));
const auto isInfParamsDyn = ::testing::Combine(
::testing::ValuesIn(inShapesDynamic),
::testing::ValuesIn(detectNegative),
::testing::ValuesIn(detectPositive),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(additional_config));
TEST_P(IsInfLayerTest, CompareWithRefs) {
run();
}
INSTANTIATE_TEST_SUITE_P(smoke_static, IsInfLayerTest, isInfParams, IsInfLayerTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, IsInfLayerTest, isInfParamsDyn, IsInfLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,37 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <tuple>
#include "shared_test_classes/base/ov_subgraph.hpp"
namespace ov {
namespace test {
namespace subgraph {
using IsInfParams = std::tuple<
std::vector<InputShape>, // Data shape
bool, // Detect negative
bool, // Detect positive
ElementType, // Data precision
std::string, // Device name
std::map<std::string, std::string> // Additional config
>;
class IsInfLayerTest : public testing::WithParamInterface<IsInfParams>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<IsInfParams>& obj);
protected:
void SetUp() override;
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
};
} // namespace subgraph
} // namespace test
} // namespace ov

View File

@ -0,0 +1,91 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "single_layer_tests/is_inf.hpp"
#include "ngraph_functions/builders.hpp"
#include "ie_test_utils/common_test_utils/ov_tensor_utils.hpp"
#include "ie_plugin_config.hpp"
using namespace ov::test::subgraph;
std::string IsInfLayerTest::getTestCaseName(const testing::TestParamInfo<IsInfParams>& obj) {
std::vector<InputShape> inputShapes;
ElementType dataPrc;
bool detectNegative, detectPositive;
std::string targetName;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, detectNegative, detectPositive, dataPrc, targetName, additionalConfig) = obj.param;
std::ostringstream result;
result << "IS=(";
for (size_t i = 0lu; i < inputShapes.size(); i++) {
result << CommonTestUtils::partialShape2str({inputShapes[i].first}) << (i < inputShapes.size() - 1lu ? "_" : "");
}
result << ")_TS=";
for (size_t i = 0lu; i < inputShapes.front().second.size(); i++) {
result << "{";
for (size_t j = 0lu; j < inputShapes.size(); j++) {
result << CommonTestUtils::vec2str(inputShapes[j].second[i]) << (j < inputShapes.size() - 1lu ? "_" : "");
}
result << "}_";
}
result << ")_detectNegative=" << (detectNegative ? "True" : "False") << "_";
result << "detectPositive=" << (detectPositive ? "True" : "False") << "_";
result << "dataPrc=" << dataPrc << "_";
result << "trgDev=" << targetName;
if (!additionalConfig.empty()) {
result << "_PluginConf";
for (auto &item : additionalConfig) {
if (item.second == InferenceEngine::PluginConfigParams::YES)
result << "_" << item.first << "=" << item.second;
}
}
return result.str();
}
void IsInfLayerTest::SetUp() {
std::vector<InputShape> shapes;
ElementType dataPrc;
bool detectNegative, detectPositive;
std::string targetName;
std::map<std::string, std::string> additionalConfig;
std::tie(shapes, detectNegative, detectPositive, dataPrc, targetDevice, additionalConfig) = this->GetParam();
init_input_shapes(shapes);
configuration.insert(additionalConfig.begin(), additionalConfig.end());
auto parameters = ngraph::builder::makeDynamicParams(dataPrc, inputDynamicShapes);
parameters[0]->set_friendly_name("Data");
auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ov::op::v0::Parameter>(parameters));
ov::op::v10::IsInf::Attributes attributes {detectNegative, detectPositive};
auto isInf = std::make_shared<ov::op::v10::IsInf>(paramOuts[0], attributes);
ov::ResultVector results;
for (int i = 0; i < isInf->get_output_size(); i++) {
results.push_back(std::make_shared<ov::op::v0::Result>(isInf->output(i)));
}
function = std::make_shared<ov::Model>(results, parameters, "IsInf");
}
void IsInfLayerTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
const auto& funcInputs = function->inputs();
const auto& input = funcInputs[0];
int32_t range = std::accumulate(targetInputStaticShapes[0].begin(), targetInputStaticShapes[0].end(), 1u, std::multiplies<uint32_t>());
auto tensor = utils::create_and_fill_tensor(
input.get_element_type(), targetInputStaticShapes[0], range, -range / 2, 1);
auto pointer = tensor.data<element_type_traits<ov::element::Type_t::f32>::value_type>();
testing::internal::Random random(1);
for (size_t i = 0; i < range / 2; i++) {
pointer[random.Generate(range)] = i % 2 == 0 ? std::numeric_limits<float>::infinity() : -std::numeric_limits<float>::infinity();
}
inputs.insert({input.get_node_shared_ptr(), tensor});
}

View File

@ -32,10 +32,12 @@ typedef std::tuple<
class ComparisonLayerTest : public testing::WithParamInterface<ComparisonTestParams>,
virtual public LayerTestsUtils::LayerTestsCommon {
ngraph::helpers::ComparisonTypes comparisonOpType;
protected:
void SetUp() override;
public:
static std::string getTestCaseName(const testing::TestParamInfo<ComparisonTestParams> &obj);
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &inputInfo) const override;
};
} // namespace LayerTestsDefinitions

View File

@ -6,13 +6,14 @@
#include "shared_test_classes/single_layer/comparison.hpp"
using namespace LayerTestsDefinitions::ComparisonParams;
using namespace ngraph::helpers;
namespace LayerTestsDefinitions {
std::string ComparisonLayerTest::getTestCaseName(const testing::TestParamInfo<ComparisonTestParams> &obj) {
InputShapesTuple inputShapes;
InferenceEngine::Precision ngInputsPrecision;
ngraph::helpers::ComparisonTypes comparisonOpType;
ngraph::helpers::InputLayerType secondInputType;
ComparisonTypes comparisonOpType;
InputLayerType secondInputType;
InferenceEngine::Precision ieInPrecision;
InferenceEngine::Precision ieOutPrecision;
std::string targetName;
@ -45,8 +46,7 @@ std::string ComparisonLayerTest::getTestCaseName(const testing::TestParamInfo<Co
void ComparisonLayerTest::SetUp() {
InputShapesTuple inputShapes;
InferenceEngine::Precision ngInputsPrecision;
ngraph::helpers::ComparisonTypes comparisonOpType;
ngraph::helpers::InputLayerType secondInputType;
InputLayerType secondInputType;
InferenceEngine::Precision ieInPrecision;
InferenceEngine::Precision ieOutPrecision;
std::string targetName;
@ -69,11 +69,36 @@ void ComparisonLayerTest::SetUp() {
auto inputs = ngraph::builder::makeParams(ngInputsPrc, {inputShapes.first});
auto secondInput = ngraph::builder::makeInputLayer(ngInputsPrc, secondInputType, inputShapes.second);
if (secondInputType == ngraph::helpers::InputLayerType::PARAMETER) {
inputs.push_back(std::dynamic_pointer_cast<ngraph::opset3::Parameter>(secondInput));
if (secondInputType == InputLayerType::PARAMETER) {
inputs.push_back(std::dynamic_pointer_cast<ov::op::v0::Parameter>(secondInput));
}
auto comparisonNode = ngraph::builder::makeComparison(inputs[0], secondInput, comparisonOpType);
function = std::make_shared<ngraph::Function>(comparisonNode, inputs, "Comparison");
function = std::make_shared<ov::Model>(comparisonNode, inputs, "Comparison");
}
InferenceEngine::Blob::Ptr ComparisonLayerTest::GenerateInput(const InferenceEngine::InputInfo &inputInfo) const {
auto blob = LayerTestsUtils::LayerTestsCommon::GenerateInput(inputInfo);
if (comparisonOpType == ComparisonTypes::IS_FINITE || comparisonOpType == ComparisonTypes::IS_NAN) {
auto *dataPtr = blob->buffer().as<float*>();
auto range = blob->size();
testing::internal::Random random(1);
if (comparisonOpType == ComparisonTypes::IS_FINITE) {
for (size_t i = 0; i < range / 2; i++) {
dataPtr[random.Generate(range)] =
i % 3 == 0 ? std::numeric_limits<float>::infinity() : i % 3 == 1 ? -std::numeric_limits<float>::infinity() :
std::numeric_limits<double>::quiet_NaN();
}
} else {
for (size_t i = 0; i < range / 2; i++) {
dataPtr[random.Generate(range)] = std::numeric_limits<double>::quiet_NaN();
}
}
}
return blob;
}
} // namespace LayerTestsDefinitions

View File

@ -142,6 +142,9 @@ enum EltwiseTypes {
enum ComparisonTypes {
EQUAL,
NOT_EQUAL,
IS_FINITE,
IS_INF,
IS_NAN,
LESS,
LESS_EQUAL,
GREATER,

View File

@ -9,9 +9,9 @@
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeComparison(const ngraph::Output<Node> &in0,
const ngraph::Output<Node> &in1,
ngraph::helpers::ComparisonTypes comparisonType) {
std::shared_ptr<ov::Node> makeComparison(const ov::Output<Node> &in0,
const ov::Output<Node> &in1,
ngraph::helpers::ComparisonTypes comparisonType) {
switch (comparisonType) {
case ngraph::helpers::ComparisonTypes::EQUAL:
return std::make_shared<ngraph::opset3::Equal>(in0, in1);
@ -21,6 +21,12 @@ std::shared_ptr<ngraph::Node> makeComparison(const ngraph::Output<Node> &in0,
return std::make_shared<ngraph::opset3::Greater>(in0, in1);
case ngraph::helpers::ComparisonTypes::GREATER_EQUAL:
return std::make_shared<ngraph::opset3::GreaterEqual>(in0, in1);
case ngraph::helpers::ComparisonTypes::IS_FINITE:
return std::make_shared<ov::op::v10::IsFinite>(in0);
case ngraph::helpers::ComparisonTypes::IS_INF:
return std::make_shared<ov::op::v10::IsInf>(in0);
case ngraph::helpers::ComparisonTypes::IS_NAN:
return std::make_shared<ov::op::v10::IsNaN>(in0);
case ngraph::helpers::ComparisonTypes::LESS:
return std::make_shared<ngraph::opset3::Less>(in0, in1);
case ngraph::helpers::ComparisonTypes::LESS_EQUAL:

View File

@ -722,6 +722,15 @@ std::ostream& operator<<(std::ostream & os, ngraph::helpers::ComparisonTypes typ
case ngraph::helpers::ComparisonTypes::GREATER_EQUAL:
os << "GreaterEqual";
break;
case ngraph::helpers::ComparisonTypes::IS_FINITE:
os << "IsFinite";
break;
case ngraph::helpers::ComparisonTypes::IS_INF:
os << "IsInf";
break;
case ngraph::helpers::ComparisonTypes::IS_NAN:
os << "IsNaN";
break;
case ngraph::helpers::ComparisonTypes::LESS:
os << "Less";
break;

View File

@ -13,7 +13,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
def mix_array_with_value(input_array, value):
input_shape = input_array
input_shape = input_array.shape
mask = np.random.randint(0, 2, input_shape).astype(bool)
return np.where(mask, input_array, value)

View File

@ -37,7 +37,6 @@ class TestIsFinite(CommonTFLayerTest):
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.xfail(reason="94741")
@pytest.mark.precommit_tf_fe
def test_is_finite_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):

View File

@ -35,7 +35,6 @@ class TestIsInf(CommonTFLayerTest):
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.xfail(reason="94753")
@pytest.mark.precommit_tf_fe
def test_is_inf_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):

View File

@ -37,7 +37,6 @@ class TestIsNan(CommonTFLayerTest):
]
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.xfail(reason="94721")
@pytest.mark.precommit_tf_fe
def test_is_nan_basic(self, params, ie_device, precision, ir_version, temp_dir,
use_new_frontend, use_old_api):