[CONFORMANCE] Fix run of Conformance tests (#11225)

This commit is contained in:
Irina Efode 2022-03-28 12:29:27 +03:00 committed by GitHub
parent 10698abc29
commit 76e2f2697f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -849,7 +849,7 @@ std::shared_ptr<ov::Model> generate(const std::shared_ptr<ov::op::v8::RandomUnif
const auto params = ngraph::builder::makeDynamicParams(ov::element::i32, {{3}});
const auto min_value = ngraph::builder::makeConstant<float>(ov::element::f16, {}, {0.f});
const auto max_value = ngraph::builder::makeConstant<float>(ov::element::f16, {}, {1.f});
auto Node = std::make_shared<ov::op::v8::RandomUniform>(params.at(0), min_value, max_value, ov::element::f32, 10, 10);
auto Node = std::make_shared<ov::op::v8::RandomUniform>(params.at(0), min_value, max_value, ov::element::f16, 10, 10);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(Node)};
return std::make_shared<ov::Model>(results, params, "RandomUniformGraph");
}
@ -1603,7 +1603,7 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
RNNCellBaseNode = std::make_shared<ov::op::v3::GRUCell>(params.at(0), params.at(1),
W, R, B, 3);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(RNNCellBaseNode)};
return std::make_shared<ov::Model>(results, params, "RNNCellBaseGraph");
return std::make_shared<ov::Model>(results, params, "GRUCell3BaseGraph");
} else if (ov::is_type<ov::op::v0::LSTMCell>(node)) {
const auto params = ngraph::builder::makeDynamicParams(ov::element::f16, {{2, 3}, {2, 3}, {2, 3}});
const auto W = ngraph::builder::makeConstant<float16>(ov::element::f16, {12, 3}, {}, true);
@ -1614,7 +1614,7 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
W, R, B, P, 3);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(0)),
std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(1))};
return std::make_shared<ov::Model>(results, params, "RNNCellBaseGraph");
return std::make_shared<ov::Model>(results, params, "LSTMCell1BaseGraph");
} else if (ov::is_type<ov::op::v4::LSTMCell>(node)) {
const auto params = ngraph::builder::makeDynamicParams(ov::element::f16, {{2, 3}, {2, 3}, {2, 3}});
const auto W = ngraph::builder::makeConstant<float16>(ov::element::f16, {12, 3}, {}, true);
@ -1624,7 +1624,7 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
W, R, B, 3);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(0)),
std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(1))};;
return std::make_shared<ov::Model>(results, params, "RNNCellBaseGraph");
return std::make_shared<ov::Model>(results, params, "LSTMCell4BaseGraph");
} else if (ov::is_type<ov::op::v5::LSTMSequence>(node)) {
const auto params = ngraph::builder::makeDynamicParams({ov::element::f16, ov::element::f16, ov::element::f16, ov::element::i64},
{{5, 10, 10}, {5, 1, 10}, {5, 1, 10}, {5}});
@ -1636,7 +1636,7 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(0)),
std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(1)),
std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(2))};
return std::make_shared<ov::Model>(results, params, "RNNCellBaseGraph");
return std::make_shared<ov::Model>(results, params, "LSTMSeqBaseGraph");
} else if (ov::is_type<ov::op::v0::RNNCell>(node)) {
const auto params = ngraph::builder::makeDynamicParams(ov::element::f16, {{2, 3}, {2, 3}});
const auto W = ngraph::builder::makeConstant<float16>(ov::element::f16, {3, 3}, {}, true);
@ -1656,7 +1656,7 @@ std::shared_ptr<ov::Model> generateRNNCellBase(const std::shared_ptr<ov::op::Op>
W, R, B, 3, ov::op::RecurrentSequenceDirection::FORWARD);
ov::ResultVector results{std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(0)),
std::make_shared<ov::op::v0::Result>(RNNCellBaseNode->output(1))};
return std::make_shared<ov::Model>(results, params, "RNNCellBaseGraph");
return std::make_shared<ov::Model>(results, params, "RNNSeqBaseGraph");
} else {
return nullptr;
}