Remove old dynamism
This commit is contained in:
parent
fadf43ecb7
commit
1393c9c805
@ -155,12 +155,6 @@ protected:
|
||||
float abs_threshold;
|
||||
InferenceEngine::CNNNetwork cnnNetwork;
|
||||
std::shared_ptr<InferenceEngine::Core> core;
|
||||
// dynamic input shapes
|
||||
std::vector<ngraph::PartialShape> inputDynamicShapes;
|
||||
// index for targetStaticShape
|
||||
size_t index = 0;
|
||||
// target static input shapes which is used for reshape ngraph function & generate input blobs
|
||||
std::vector<std::vector<ngraph::Shape>> targetStaticShapes;
|
||||
|
||||
virtual void Validate();
|
||||
|
||||
@ -171,7 +165,6 @@ protected:
|
||||
InferenceEngine::InferRequest inferRequest;
|
||||
|
||||
private:
|
||||
void ResizeNgraphFunction();
|
||||
RefMode refMode = RefMode::INTERPRETER;
|
||||
};
|
||||
|
||||
|
@ -24,16 +24,6 @@ namespace LayerTestsUtils {
|
||||
LayerTestsCommon::LayerTestsCommon() : threshold(1e-2f), abs_threshold(-1.f) {
|
||||
core = PluginCache::get().ie(targetDevice);
|
||||
}
|
||||
void LayerTestsCommon::ResizeNgraphFunction() {
|
||||
auto params = function->get_parameters();
|
||||
std::map<std::string, ngraph::PartialShape> shapes;
|
||||
ASSERT_LE(params.size(), targetStaticShapes[index].size());
|
||||
for (size_t i = 0; i < params.size(); i++) {
|
||||
shapes.insert({*params[i]->get_output_tensor(0).get_names().begin(), targetStaticShapes[index][i]});
|
||||
}
|
||||
function->reshape(shapes);
|
||||
functionRefs->reshape(shapes);
|
||||
}
|
||||
|
||||
void LayerTestsCommon::Run() {
|
||||
if (functionRefs == nullptr) {
|
||||
@ -60,27 +50,10 @@ void LayerTestsCommon::Run() {
|
||||
|
||||
try {
|
||||
LoadNetwork();
|
||||
size_t i = 0;
|
||||
do {
|
||||
index = i;
|
||||
try {
|
||||
if (!inputDynamicShapes.empty()) {
|
||||
// resize ngraph function according new target shape
|
||||
ResizeNgraphFunction();
|
||||
}
|
||||
GenerateInputs();
|
||||
Infer();
|
||||
Validate();
|
||||
s.updateOPsStats(functionRefs, PassRate::Statuses::PASSED);
|
||||
} catch (const std::exception &ex) {
|
||||
std::string errorMessage;
|
||||
if (!targetStaticShapes.empty()) {
|
||||
errorMessage = "Incorrect target static shape: " + CommonTestUtils::vec2str(targetStaticShapes[i]) + "\n";
|
||||
}
|
||||
errorMessage += ex.what();
|
||||
THROW_IE_EXCEPTION << ex.what();
|
||||
}
|
||||
} while (++i < targetStaticShapes.size());
|
||||
}
|
||||
catch (const std::runtime_error &re) {
|
||||
s.updateOPsStats(functionRefs, PassRate::Statuses::FAILED);
|
||||
@ -364,22 +337,6 @@ void LayerTestsCommon::ConfigureNetwork() {
|
||||
out.second->setPrecision(outPrc);
|
||||
}
|
||||
}
|
||||
|
||||
// Reshape CNNNetwork before load to the plugin in dynamic scenario
|
||||
if (!inputDynamicShapes.empty()) {
|
||||
auto params = function->get_parameters();
|
||||
std::map<std::string, ngraph::PartialShape> inputDataMap;
|
||||
ASSERT_EQ(params.size(), inputDynamicShapes.size());
|
||||
for (size_t i = 0; i < inputDynamicShapes.size(); i++) {
|
||||
ngraph::PartialShape dynamicShape = inputDynamicShapes[i];
|
||||
if (dynamicShape.rank() == 0 && dynamicShape.is_static()) {
|
||||
continue;
|
||||
}
|
||||
std::string inputName = params[i]->get_friendly_name();
|
||||
inputDataMap.insert({inputName, dynamicShape});
|
||||
}
|
||||
cnnNetwork.reshape(inputDataMap);
|
||||
}
|
||||
}
|
||||
|
||||
void LayerTestsCommon::LoadNetwork() {
|
||||
@ -398,21 +355,7 @@ void LayerTestsCommon::GenerateInputs() {
|
||||
const auto infoIt = inputsInfo.find(param->get_friendly_name());
|
||||
GTEST_ASSERT_NE(infoIt, inputsInfo.cend());
|
||||
InferenceEngine::InputInfo::CPtr info = infoIt->second;
|
||||
InferenceEngine::Blob::Ptr blob = nullptr;
|
||||
if (!inputDynamicShapes.empty()) {
|
||||
if (inputDynamicShapes[i].rank() != 0) {
|
||||
InferenceEngine::DataPtr dataNew(
|
||||
new InferenceEngine::Data(infoIt->first, info->getTensorDesc().getPrecision(),
|
||||
targetStaticShapes[index][i],
|
||||
info->getTensorDesc().getLayout()));
|
||||
InferenceEngine::InputInfo infoNew;
|
||||
infoNew.setInputData(dataNew);
|
||||
blob = GenerateInput(infoNew);
|
||||
}
|
||||
}
|
||||
if (blob == nullptr) {
|
||||
blob = GenerateInput(*info);
|
||||
}
|
||||
InferenceEngine::Blob::Ptr blob = GenerateInput(*info);
|
||||
inputs.push_back(blob);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user