[JAVA] Use JUnit 4 features to log Java tests log (#1953)

* [JAVA] Use JUnit 4 features to log Java tests log

* [JAVA] Add device parameter for Java tests
This commit is contained in:
Dmitry Kurtaev
2020-08-27 17:10:24 +03:00
committed by GitHub
parent d3682417bb
commit 8b2c12967d
9 changed files with 99 additions and 60 deletions

View File

@@ -29,7 +29,8 @@ if(ENABLE_TESTS)
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/hamcrest-core-1.3.jar)
file(GLOB_RECURSE java_tests_src tests/*.java)
add_jar(ie_java_api_tests_jar ${java_tests_src}
add_jar(ie_java_api_tests_jar
SOURCES ${java_tests_src} samples/ArgumentParser.java
OUTPUT_NAME ie_java_api_tests
OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
INCLUDE_JARS ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/*)

View File

@@ -1,10 +1,12 @@
import org.junit.Assert;
import org.junit.Test;
import org.intel.openvino.*;
public class BlobTests extends IETest {
public void testGetBlob(){
@Test
public void testGetBlob() {
int[] dimsArr = {1, 3, 200, 200};
TensorDesc tDesc = new TensorDesc(Precision.U8, dimsArr, Layout.NHWC);
@@ -13,7 +15,8 @@ public class BlobTests extends IETest {
Assert.assertArrayEquals(blob.getTensorDesc().getDims(), dimsArr);
}
public void testGetBlobFromFloat(){
@Test
public void testGetBlobFromFloat() {
int[] dimsArr = {1, 1, 2, 2};
TensorDesc tDesc = new TensorDesc(Precision.FP32, dimsArr, Layout.NHWC);

View File

@@ -1,4 +1,5 @@
import org.junit.Assert;
import static org.junit.Assert.*;
import org.junit.Test;
import java.util.ArrayList;
import java.util.HashMap;
@@ -9,6 +10,7 @@ import org.intel.openvino.*;
public class CNNNetworkTests extends IETest {
IECore core = new IECore();
@Test
public void testInputName() {
CNNNetwork net = core.ReadNetwork(modelXml);
Map<String, InputInfo> inputsInfo = net.getInputsInfo();
@@ -17,6 +19,7 @@ public class CNNNetworkTests extends IETest {
assertEquals("Input name", "data", inputName);
}
@Test
public void testReshape() {
CNNNetwork net = core.ReadNetwork(modelXml);
@@ -27,9 +30,10 @@ public class CNNNetworkTests extends IETest {
net.reshape(input);
Map<String, int[]> res = net.getInputShapes();
Assert.assertArrayEquals(input.get("data"), res.get("data"));
assertArrayEquals(input.get("data"), res.get("data"));
}
@Test
public void testAddOutput() {
CNNNetwork net = core.ReadNetwork(modelXml);
Map<String, Data> output = net.getOutputsInfo();

View File

@@ -1,58 +1,58 @@
import static org.junit.Assert.*;
import org.junit.Test;
import org.intel.openvino.*;
import java.util.Map;
import java.util.HashMap;
public class IECoreTests extends IETest {
IECore core;
String exceptionMessage;
IECore core = new IECore();
@Override
protected void setUp() {
core = new IECore();
exceptionMessage = "";
}
public void testInitIECore(){
assertTrue(core instanceof IECore);
}
public void testReadNetwork(){
@Test
public void testReadNetwork() {
CNNNetwork net = core.ReadNetwork(modelXml, modelBin);
assertEquals("Network name", "test_model", net.getName());
}
public void testReadNetworkXmlOnly(){
@Test
public void testReadNetworkXmlOnly() {
CNNNetwork net = core.ReadNetwork(modelXml);
assertEquals("Batch size", 1, net.getBatchSize());
}
public void testReadNetworkIncorrectXmlPath(){
try{
@Test
public void testReadNetworkIncorrectXmlPath() {
String exceptionMessage = "";
try {
CNNNetwork net = core.ReadNetwork("model.xml", modelBin);
} catch (Exception e){
} catch (Exception e) {
exceptionMessage = e.getMessage();
}
assertTrue(exceptionMessage.contains("Model file model.xml cannot be opened!"));
}
public void testReadNetworkIncorrectBinPath(){
try{
@Test
public void testReadNetworkIncorrectBinPath() {
String exceptionMessage = "";
try {
CNNNetwork net = core.ReadNetwork(modelXml, "model.bin");
} catch (Exception e){
} catch (Exception e) {
exceptionMessage = e.getMessage();
}
assertTrue(exceptionMessage.contains("Weights file model.bin cannot be opened!"));
}
public void testLoadNetwork(){
@Test
public void testLoadNetwork() {
CNNNetwork net = core.ReadNetwork(modelXml, modelBin);
ExecutableNetwork executableNetwork = core.LoadNetwork(net, device);
assertTrue(executableNetwork instanceof ExecutableNetwork);
}
public void testLoadNetworDeviceConfig(){
@Test
public void testLoadNetworDeviceConfig() {
CNNNetwork net = core.ReadNetwork(modelXml, modelBin);
Map<String, String> testMap = new HashMap<String, String>();
@@ -66,11 +66,13 @@ public class IECoreTests extends IETest {
assertTrue(executableNetwork instanceof ExecutableNetwork);
}
public void testLoadNetworkWrongDevice(){
@Test
public void testLoadNetworkWrongDevice() {
String exceptionMessage = "";
CNNNetwork net = core.ReadNetwork(modelXml, modelBin);
try{
try {
core.LoadNetwork(net, "DEVISE");
} catch (Exception e){
} catch (Exception e) {
exceptionMessage = e.getMessage();
}
assertTrue(exceptionMessage.contains("Device with \"DEVISE\" name is not registered in the InferenceEngine"));

View File

@@ -1,26 +1,42 @@
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Ignore;
import org.junit.runner.Description;
import org.junit.Rule;
import org.junit.rules.TestWatcher;
import java.nio.file.Paths;
import java.lang.Class;
import java.util.List;
import org.intel.openvino.*;
public class IETest extends TestCase {
@Ignore
public class IETest {
String modelXml;
String modelBin;
String device;
static String device;
public IETest(){
public IETest() {
try {
System.loadLibrary(IECore.NATIVE_LIBRARY_NAME);
} catch (UnsatisfiedLinkError e) {
System.err.println("Failed to load Inference Engine library\n" + e);
System.exit(1);
}
modelXml = Paths.get(System.getenv("MODELS_PATH"), "models", "test_model", "test_model_fp32.xml").toString();
modelBin = Paths.get(System.getenv("MODELS_PATH"), "models", "test_model", "test_model_fp32.bin").toString();
device = "CPU";
}
@Rule
public TestWatcher watchman = new TestWatcher() {
@Override
protected void succeeded(Description description) {
System.out.println(description + " - OK");
}
@Override
protected void failed(Throwable e, Description description) {
System.out.println(description + " - FAILED");
}
};
}

View File

@@ -1,3 +1,7 @@
import static org.junit.Assert.*;
import org.junit.Test;
import org.junit.Before;
import java.util.Map;
import java.util.Vector;
import java.util.ArrayList;
@@ -12,16 +16,17 @@ public class InferRequestTests extends IETest {
InferRequest inferRequest;
boolean completionCallback;
@Override
protected void setUp() {
@Before
public void setUp() {
core = new IECore();
net = core.ReadNetwork(modelXml);
executableNetwork = core.LoadNetwork(net, "CPU");
executableNetwork = core.LoadNetwork(net, device);
inferRequest = executableNetwork.CreateInferRequest();
completionCallback = false;
}
public void testGetPerformanceCounts(){
@Test
public void testGetPerformanceCounts() {
inferRequest.Infer();
Vector<String> layer_name = new Vector<>();
@@ -53,7 +58,7 @@ public class InferRequestTests extends IETest {
assertEquals("Map size", layer_name.size(), res.size());
ArrayList<String> resKeySet = new ArrayList<String>(res.keySet());
for (int i = 0; i < res.size(); i++){
for (int i = 0; i < res.size(); i++) {
String key = resKeySet.get(i);
InferenceEngineProfileInfo resVal = res.get(key);
@@ -64,6 +69,7 @@ public class InferRequestTests extends IETest {
}
}
@Test
public void testStartAsync() {
inferRequest.StartAsync();
StatusCode statusCode = inferRequest.Wait(WaitMode.RESULT_READY);
@@ -71,8 +77,9 @@ public class InferRequestTests extends IETest {
assertEquals("StartAsync", StatusCode.OK, statusCode);
}
@Test
public void testSetCompletionCallback() {
inferRequest.SetCompletionCallback(new Runnable(){
inferRequest.SetCompletionCallback(new Runnable() {
@Override
public void run() {

View File

@@ -1,19 +1,17 @@
import static org.junit.Assert.*;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Map;
import org.intel.openvino.*;
public class InputInfoTests extends IETest {
IECore core;
CNNNetwork net;
@Override
protected void setUp() {
core = new IECore();
}
IECore core = new IECore();
@Test
public void testSetLayout() {
net = core.ReadNetwork(modelXml);
CNNNetwork net = core.ReadNetwork(modelXml);
Map<String, InputInfo> inputsInfo = net.getInputsInfo();
String inputName = new ArrayList<String>(inputsInfo.keySet()).get(0);
@@ -24,8 +22,9 @@ public class InputInfoTests extends IETest {
assertEquals("setLayout", Layout.NHWC, inputInfo.getLayout());
}
@Test
public void testSetPrecision() {
net = core.ReadNetwork(modelXml);
CNNNetwork net = core.ReadNetwork(modelXml);
Map<String, InputInfo> inputsInfo = net.getInputsInfo();
String inputName = new ArrayList<String>(inputsInfo.keySet()).get(0);

View File

@@ -3,13 +3,17 @@ import org.junit.runner.Result;
import org.junit.runner.notification.Failure;
public class OpenVinoTestRunner {
public static void main(String[] args) {
ArgumentParser parser = new ArgumentParser("");
parser.addArgument("-d", "device to test");
parser.parseArgs(args);
IETest.device = parser.get("-d", "CPU");
Result result = JUnitCore.runClasses(TestsSuite.class);
for (Failure failure : result.getFailures()) {
System.out.println(failure.toString());
}
System.out.println(result.wasSuccessful());
}
}

View File

@@ -32,11 +32,14 @@ public class TestsSuite extends IETest{
String dir = new File(TestsSuite.class.getProtectionDomain().getCodeSource().getLocation().toURI()).getPath().toString();
List<Class<?>> results = findClasses(dir);
results.forEach(result->suite.addTest(new junit.framework.JUnit4TestAdapter(result)));
for (Class<?> cl : results) {
if (cl.getName() == "ArgumentParser")
continue;
suite.addTest(new junit.framework.JUnit4TestAdapter(cl));
}
} catch (ClassNotFoundException e) {
System.out.println("ClassNotFoundException: " + e.getMessage());
} catch (URISyntaxException e){
} catch (URISyntaxException e) {
System.out.println("URISyntaxException: " + e.getMessage());
}
return suite;
@@ -53,9 +56,9 @@ public class TestsSuite extends IETest{
classes.add(Class.forName(name.substring(0, name.length() - ".class".length())));
}
}
} catch(FileNotFoundException e){
} catch(FileNotFoundException e) {
System.out.println("FileNotFoundException: " + e.getMessage());
} catch(IOException e){
} catch(IOException e) {
System.out.println("IOException: " + e.getMessage());
}
return classes;