[IE Tests] Aplly some comments for dynamic shapes (#7829)
* [IE Tests] Aplly some comments for dynamic shapes * fix tss * fix postfix * remove cout * Fix crash * fix ci
This commit is contained in:
parent
f762751968
commit
bd2b346c62
@ -40,9 +40,9 @@ public:
|
||||
std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> CalculateRefs() override {
|
||||
// Convert the second input constant precision to i64 to run the reference function
|
||||
if (ngraph::element::Type_t::i8 == secondConstantType) {
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_function(function);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_function(functionRefs);
|
||||
} else if (ngraph::element::Type_t::bf16 == secondConstantType) {
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_function(function);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_function(functionRefs);
|
||||
}
|
||||
return LayerTestsUtils::LayerTestsCommon::CalculateRefs();
|
||||
}
|
||||
|
@ -69,6 +69,8 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*Behavior.*(Multi|Auto).*InferRequestSetBlobByType.*Batched.*)",
|
||||
R"(.*(Multi|Auto).*Behavior.*InferRequestIOBBlobTest.*canProcessDeallocatedOutputBlobAfterGetAndSetBlob.*)",
|
||||
// TODO: until issue is xxx-59670 is resolved
|
||||
R"(.*Gather8LayerTest.*)"
|
||||
R"(.*Gather8LayerTest.*)",
|
||||
// TODO: Issue 66516
|
||||
R"(.*smoke_PrePostProcess_GPU.*convert_element_type_and_mean.*)"
|
||||
};
|
||||
}
|
@ -54,6 +54,7 @@ void DetectNetworkBatch::SetUp() {
|
||||
void DetectNetworkBatch::LoadNetwork() {
|
||||
cnnNetwork = InferenceEngine::CNNNetwork{function};
|
||||
cnnNetwork.setBatchSize(m_batchSize);
|
||||
functionRefs = ngraph::clone_function(*cnnNetwork.getFunction());
|
||||
ConfigureNetwork();
|
||||
executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration);
|
||||
}
|
||||
@ -61,21 +62,18 @@ void DetectNetworkBatch::LoadNetwork() {
|
||||
TEST_P(DetectNetworkBatch, InferWithOneInput) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
function = ngraph::builder::subgraph::makeSplitConvConcat();
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
Run();
|
||||
};
|
||||
|
||||
TEST_P(DetectNetworkBatch, InferWithMultipleInputs_DiffDims) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
function = makeNNWithMultipleInputsDiffDims();
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
Run();
|
||||
};
|
||||
|
||||
TEST_P(DetectNetworkBatch, InferWithMultipleInputs_SameDims) {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
function = makeNNWithMultipleInputsSameDims();
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
Run();
|
||||
};
|
||||
|
||||
|
@ -155,9 +155,11 @@ protected:
|
||||
float abs_threshold;
|
||||
InferenceEngine::CNNNetwork cnnNetwork;
|
||||
std::shared_ptr<InferenceEngine::Core> core;
|
||||
// dynamic input shapes
|
||||
std::vector<ngraph::PartialShape> inputDynamicShapes;
|
||||
// index for targetStaticShape
|
||||
size_t index = 0;
|
||||
// target static input shapes which is used for reshape ngraph function & generate input blobs
|
||||
std::vector<std::vector<ngraph::Shape>> targetStaticShapes;
|
||||
|
||||
virtual void Validate();
|
||||
|
@ -26,16 +26,16 @@ LayerTestsCommon::LayerTestsCommon() : threshold(1e-2f), abs_threshold(-1.f) {
|
||||
}
|
||||
void LayerTestsCommon::ResizeNgraphFunction() {
|
||||
auto params = function->get_parameters();
|
||||
std::map<std::string, ngraph::PartialShape> shapes;
|
||||
ASSERT_LE(params.size(), targetStaticShapes[index].size());
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
params[i]->set_partial_shape(targetStaticShapes[index][i]);
|
||||
shapes.insert({*params[i]->get_output_tensor(0).get_names().begin(), targetStaticShapes[index][i]});
|
||||
}
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
functionRefs->set_friendly_name("FunctionRefs");
|
||||
function->reshape(shapes);
|
||||
functionRefs->reshape(shapes);
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Run() {
|
||||
//TODO: w/a: to identify gaps with functionRefs and init it
|
||||
if (functionRefs == nullptr) {
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
functionRefs->set_friendly_name("refFunction");
|
||||
@ -60,7 +60,8 @@ void LayerTestsCommon::Run() {
|
||||
|
||||
try {
|
||||
LoadNetwork();
|
||||
for (size_t i = 0; i < targetStaticShapes.size(); i++) {
|
||||
size_t i = 0;
|
||||
do {
|
||||
index = i;
|
||||
try {
|
||||
if (!inputDynamicShapes.empty()) {
|
||||
@ -72,9 +73,14 @@ void LayerTestsCommon::Run() {
|
||||
Validate();
|
||||
s.updateOPsStats(functionRefs, PassRate::Statuses::PASSED);
|
||||
} catch (const std::exception &ex) {
|
||||
THROW_IE_EXCEPTION << "Incorrect target static shape: " << CommonTestUtils::vec2str(targetStaticShapes[i]) << std::endl << ex.what();
|
||||
std::string errorMessage;
|
||||
if (!targetStaticShapes.empty()) {
|
||||
errorMessage = "Incorrect target static shape: " + CommonTestUtils::vec2str(targetStaticShapes[i]) + "\n";
|
||||
}
|
||||
errorMessage += ex.what();
|
||||
THROW_IE_EXCEPTION << ex.what();
|
||||
}
|
||||
}
|
||||
} while (++i < targetStaticShapes.size());
|
||||
}
|
||||
catch (const std::runtime_error &re) {
|
||||
s.updateOPsStats(functionRefs, PassRate::Statuses::FAILED);
|
||||
@ -362,7 +368,7 @@ void LayerTestsCommon::ConfigureNetwork() {
|
||||
ASSERT_EQ(params.size(), inputDynamicShapes.size());
|
||||
for (size_t i = 0; i < inputDynamicShapes.size(); i++) {
|
||||
ngraph::PartialShape dynamicShape = inputDynamicShapes[i];
|
||||
if (dynamicShape.rank() == 0) {
|
||||
if (dynamicShape.rank() == 0 && dynamicShape.is_static()) {
|
||||
continue;
|
||||
}
|
||||
std::string inputName = params[i]->get_friendly_name();
|
||||
|
@ -30,7 +30,6 @@ void PrePostProcessTest::SetUp() {
|
||||
std::tie(func, targetDevice) = GetParam();
|
||||
function = (std::get<0>(func))();
|
||||
threshold = std::get<2>(func);
|
||||
functionRefs = ngraph::clone_function(*function);
|
||||
}
|
||||
|
||||
TEST_P(PrePostProcessTest, CompareWithRefs) {
|
||||
|
Loading…
Reference in New Issue
Block a user