cleanup and added support for shorts as ints

This commit is contained in:
Mark Wolters
2023-08-09 10:29:58 -04:00
parent 0f3ec7f0f0
commit 1f53f03a50
7 changed files with 47 additions and 18 deletions

View File

@@ -78,6 +78,12 @@
<version>1.0.0-M2.1</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
<dependency>
<groupId>io.jhdf</groupId>
<artifactId>jhdf</artifactId>

View File

@@ -71,6 +71,5 @@ public class HdfLoader {
logger.error(e);
System.exit(1);
}
}
}

View File

@@ -24,7 +24,8 @@ public class EmbeddingGeneratorFactory {
private static final Map<String,EmbeddingGenerator> generators = new HashMap<>();
public static EmbeddingGenerator getGenerator(String type) {
switch (type.toLowerCase()) {
String typeLower = type.equalsIgnoreCase("short") ? "int" : type.toLowerCase();
switch (typeLower) {
case "string" -> {
if (!generators.containsKey(type)) {
generators.put(type, new StringEmbeddingGenerator());
@@ -37,9 +38,9 @@ public class EmbeddingGeneratorFactory {
}
return generators.get(type);
}
case "short" -> {
case "int" -> {
if (!generators.containsKey(type)) {
generators.put(type, new ShortEmbeddingGenerator());
generators.put(type, new IntEmbeddingGenerator());
}
return generators.get(type);
}

View File

@@ -17,14 +17,14 @@
package io.nosqlbench.loader.hdf.embedding;
public class ShortEmbeddingGenerator implements EmbeddingGenerator {
public class IntEmbeddingGenerator implements EmbeddingGenerator {
@Override
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
switch (dims.length) {
case 1 -> {
float[] arr = new float[dims[0]];
for (int i = 0; i < dims[0]; i++) {
arr[i] = ((short[]) o)[i];
arr[i] = ((int[]) o)[i];
}
return new float[][]{arr};
}
@@ -32,7 +32,7 @@ public class ShortEmbeddingGenerator implements EmbeddingGenerator {
float[][] arr = new float[dims[0]][dims[1]];
for (int i = 0; i < dims[0]; i++) {
for (int j = 0; j < dims[1]; j++) {
arr[i][j] = ((short[][]) o)[i][j];
arr[i][j] = ((int[][]) o)[i][j];
}
}
return arr;
@@ -46,7 +46,7 @@ public class ShortEmbeddingGenerator implements EmbeddingGenerator {
}
private float[][] flatten(Object o, int[] dims) {
short[][][] arr = (short[][][]) o;
int[][][] arr = (int[][][]) o;
float[][] flat = new float[dims[0]][dims[1] * dims[2]];
for (int i = 0; i < dims[0]; i++) {
for (int j = 0; j < dims[1]; j++) {

View File

@@ -17,13 +17,41 @@
package io.nosqlbench.loader.hdf.embedding;
//import org.deeplearning4j.models.word2vec.Word2Vec;
//import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import java.util.Arrays;
public class StringEmbeddingGenerator implements EmbeddingGenerator {
private TokenizerFactory tokenizerFactory= new DefaultTokenizerFactory();
@Override
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
return new float[][]{{0.0f, 1.0f},{1.0f, 0.0f}}; //TODO
switch (dims.length) {
case 1 -> {
return generateWordEmbeddings((String[]) o);
}
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
}
}
private float[][] generateWordEmbeddings(String[] text) {
SentenceIterator iter = new CollectionSentenceIterator(Arrays.asList(text));
/*Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(1)
.iterations(1)
.layerSize(targetDims)
.seed(42)
.windowSize(5)
.iterate(iter)
.tokenizerFactory(tokenizerFactory)
.build();
*/
return null;
}
}

View File

@@ -88,14 +88,9 @@ public class Hdf5Reader implements HdfReader {
Future<?> future = executorService.submit(() -> {
logger.info("Processing dataset: " + ds);
Dataset dataset = hdfFile.getDatasetByPath(ds);
DataType dataType = dataset.getDataType();
int[] dims = dataset.getDimensions();
Object data = dataset.getData();
String type = dataset.getJavaType().getSimpleName();
EmbeddingGenerator generator = getGenerator(dataset.getJavaType().getSimpleName());
float[][] vectors = generator.generateEmbeddingFrom(data, dims);
float[][] vectors = generator.generateEmbeddingFrom(dataset.getData(), dims);
for (int i = 0; i < dims[0]; i++) {
try {
queue.put(vectors[i]);

View File

@@ -1,5 +1,5 @@
format: HDF5
sourceFile: /home/mwolters138/Downloads/NEONDSImagingSpectrometerData.h5 #h5ex_t_arrayatt.h5
sourceFile: /home/mwolters138/Documents/hdf5/datasets/deep-image-96-angular.hdf5
datasets:
- all
embedding: word2vec