added support for doubles

This commit is contained in:
Mark Wolters 2023-08-15 12:08:07 -04:00
parent 2502a970b6
commit f53870d58b
4 changed files with 77 additions and 9 deletions

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.nosqlbench.loader.hdf.embedding;
public class DoubleEmbeddingGenerator implements EmbeddingGenerator {
@Override
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
return switch (dims.length) {
case 1 -> new float[][]{convertToFloat((double[]) o)};
case 2 -> convertToFloats((double[][]) o);
case 3 -> flatten(o, dims);
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
};
}
private float[][] convertToFloats(double[][] o) {
float[][] floats = new float[o.length][];
for (int i = 0; i < o.length; i++) {
floats[i] = convertToFloat(o[i]);
}
return floats;
}
public float[] convertToFloat(double[] doubleArray) {
if (doubleArray == null) {
return null;
}
float[] floatArray = new float[doubleArray.length];
for (int i = 0; i < doubleArray.length; i++) {
floatArray[i] = (float) doubleArray[i];
}
return floatArray;
}
private float[][] flatten(Object o, int[] dims) {
double[][][] arr = (double[][][]) o;
float[][] flat = new float[dims[0]][dims[1] * dims[2]];
for (int i = 0; i < dims[0]; i++) {
for (int j = 0; j < dims[1]; j++) {
for (int k = 0; k < dims[2]; k++) {
flat[i][j * dims[2] + k] = (float)arr[i][j][k];
}
}
}
return flat;
}
}

View File

@ -25,6 +25,7 @@ public class EmbeddingGeneratorFactory {
public static EmbeddingGenerator getGenerator(String type) {
String typeLower = type.equalsIgnoreCase("short") ? "int" : type.toLowerCase();
if (typeLower.equals("integer")) typeLower = "int";
switch (typeLower) {
case "string" -> {
if (!generators.containsKey(type)) {
@ -38,6 +39,12 @@ public class EmbeddingGeneratorFactory {
}
return generators.get(type);
}
case "double" -> {
if (!generators.containsKey(type)) {
generators.put(type, new DoubleEmbeddingGenerator());
}
return generators.get(type);
}
case "int" -> {
if (!generators.containsKey(type)) {
generators.put(type, new IntEmbeddingGenerator());

View File

@ -21,14 +21,12 @@ public class FloatEmbeddingGenerator implements EmbeddingGenerator {
@Override
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
switch (dims.length) {
case 1:
return new float[][]{new float[]{(float) o}};
case 2: return (float[][]) o;
case 3: return flatten(o, dims);
default:
throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
}
return switch (dims.length) {
case 1 -> new float[][]{(float[]) o};
case 2 -> (float[][]) o;
case 3 -> flatten(o, dims);
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
};
}
private float[][] flatten(Object o, int[] dims) {

View File

@ -1,5 +1,5 @@
format: HDF5
sourceFile: /home/mwolters138/Documents/hdf5/datasets/pass/glove-25-angular.hdf5
sourceFile: /home/mwolters138/Downloads/h5ex_t_float.h5 #/home/mwolters138/Documents/hdf5/datasets/pass/glove-25-angular.hdf5
datasets:
- all
embedding: word2vec