mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2025-02-25 18:55:28 -06:00
cleanup and added support for shorts as ints
This commit is contained in:
@@ -78,6 +78,12 @@
|
||||
<version>1.0.0-M2.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-nlp</artifactId>
|
||||
<version>1.0.0-M2.1</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.jhdf</groupId>
|
||||
<artifactId>jhdf</artifactId>
|
||||
|
||||
@@ -71,6 +71,5 @@ public class HdfLoader {
|
||||
logger.error(e);
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,8 @@ public class EmbeddingGeneratorFactory {
|
||||
private static final Map<String,EmbeddingGenerator> generators = new HashMap<>();
|
||||
|
||||
public static EmbeddingGenerator getGenerator(String type) {
|
||||
switch (type.toLowerCase()) {
|
||||
String typeLower = type.equalsIgnoreCase("short") ? "int" : type.toLowerCase();
|
||||
switch (typeLower) {
|
||||
case "string" -> {
|
||||
if (!generators.containsKey(type)) {
|
||||
generators.put(type, new StringEmbeddingGenerator());
|
||||
@@ -37,9 +38,9 @@ public class EmbeddingGeneratorFactory {
|
||||
}
|
||||
return generators.get(type);
|
||||
}
|
||||
case "short" -> {
|
||||
case "int" -> {
|
||||
if (!generators.containsKey(type)) {
|
||||
generators.put(type, new ShortEmbeddingGenerator());
|
||||
generators.put(type, new IntEmbeddingGenerator());
|
||||
}
|
||||
return generators.get(type);
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@
|
||||
|
||||
package io.nosqlbench.loader.hdf.embedding;
|
||||
|
||||
public class ShortEmbeddingGenerator implements EmbeddingGenerator {
|
||||
public class IntEmbeddingGenerator implements EmbeddingGenerator {
|
||||
@Override
|
||||
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
|
||||
switch (dims.length) {
|
||||
case 1 -> {
|
||||
float[] arr = new float[dims[0]];
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
arr[i] = ((short[]) o)[i];
|
||||
arr[i] = ((int[]) o)[i];
|
||||
}
|
||||
return new float[][]{arr};
|
||||
}
|
||||
@@ -32,7 +32,7 @@ public class ShortEmbeddingGenerator implements EmbeddingGenerator {
|
||||
float[][] arr = new float[dims[0]][dims[1]];
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
for (int j = 0; j < dims[1]; j++) {
|
||||
arr[i][j] = ((short[][]) o)[i][j];
|
||||
arr[i][j] = ((int[][]) o)[i][j];
|
||||
}
|
||||
}
|
||||
return arr;
|
||||
@@ -46,7 +46,7 @@ public class ShortEmbeddingGenerator implements EmbeddingGenerator {
|
||||
}
|
||||
|
||||
private float[][] flatten(Object o, int[] dims) {
|
||||
short[][][] arr = (short[][][]) o;
|
||||
int[][][] arr = (int[][][]) o;
|
||||
float[][] flat = new float[dims[0]][dims[1] * dims[2]];
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
for (int j = 0; j < dims[1]; j++) {
|
||||
@@ -17,13 +17,41 @@
|
||||
|
||||
package io.nosqlbench.loader.hdf.embedding;
|
||||
|
||||
//import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
//import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class StringEmbeddingGenerator implements EmbeddingGenerator {
|
||||
private TokenizerFactory tokenizerFactory= new DefaultTokenizerFactory();
|
||||
|
||||
@Override
|
||||
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
|
||||
return new float[][]{{0.0f, 1.0f},{1.0f, 0.0f}}; //TODO
|
||||
switch (dims.length) {
|
||||
case 1 -> {
|
||||
return generateWordEmbeddings((String[]) o);
|
||||
}
|
||||
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private float[][] generateWordEmbeddings(String[] text) {
|
||||
SentenceIterator iter = new CollectionSentenceIterator(Arrays.asList(text));
|
||||
/*Word2Vec vec = new Word2Vec.Builder()
|
||||
.minWordFrequency(1)
|
||||
.iterations(1)
|
||||
.layerSize(targetDims)
|
||||
.seed(42)
|
||||
.windowSize(5)
|
||||
.iterate(iter)
|
||||
.tokenizerFactory(tokenizerFactory)
|
||||
.build();
|
||||
*/
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,14 +88,9 @@ public class Hdf5Reader implements HdfReader {
|
||||
Future<?> future = executorService.submit(() -> {
|
||||
logger.info("Processing dataset: " + ds);
|
||||
Dataset dataset = hdfFile.getDatasetByPath(ds);
|
||||
DataType dataType = dataset.getDataType();
|
||||
|
||||
int[] dims = dataset.getDimensions();
|
||||
Object data = dataset.getData();
|
||||
|
||||
String type = dataset.getJavaType().getSimpleName();
|
||||
EmbeddingGenerator generator = getGenerator(dataset.getJavaType().getSimpleName());
|
||||
float[][] vectors = generator.generateEmbeddingFrom(data, dims);
|
||||
float[][] vectors = generator.generateEmbeddingFrom(dataset.getData(), dims);
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
try {
|
||||
queue.put(vectors[i]);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
format: HDF5
|
||||
sourceFile: /home/mwolters138/Downloads/NEONDSImagingSpectrometerData.h5 #h5ex_t_arrayatt.h5
|
||||
sourceFile: /home/mwolters138/Documents/hdf5/datasets/deep-image-96-angular.hdf5
|
||||
datasets:
|
||||
- all
|
||||
embedding: word2vec
|
||||
|
||||
Reference in New Issue
Block a user