mirror of
https://github.com/nosqlbench/nosqlbench.git
synced 2024-12-24 16:00:09 -06:00
added support for doubles
This commit is contained in:
parent
2502a970b6
commit
f53870d58b
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright (c) 2023 nosqlbench
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package io.nosqlbench.loader.hdf.embedding;
|
||||
|
||||
public class DoubleEmbeddingGenerator implements EmbeddingGenerator {
|
||||
|
||||
@Override
|
||||
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
|
||||
return switch (dims.length) {
|
||||
case 1 -> new float[][]{convertToFloat((double[]) o)};
|
||||
case 2 -> convertToFloats((double[][]) o);
|
||||
case 3 -> flatten(o, dims);
|
||||
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
|
||||
};
|
||||
}
|
||||
|
||||
private float[][] convertToFloats(double[][] o) {
|
||||
float[][] floats = new float[o.length][];
|
||||
for (int i = 0; i < o.length; i++) {
|
||||
floats[i] = convertToFloat(o[i]);
|
||||
}
|
||||
return floats;
|
||||
}
|
||||
|
||||
public float[] convertToFloat(double[] doubleArray) {
|
||||
if (doubleArray == null) {
|
||||
return null;
|
||||
}
|
||||
float[] floatArray = new float[doubleArray.length];
|
||||
for (int i = 0; i < doubleArray.length; i++) {
|
||||
floatArray[i] = (float) doubleArray[i];
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
|
||||
private float[][] flatten(Object o, int[] dims) {
|
||||
double[][][] arr = (double[][][]) o;
|
||||
float[][] flat = new float[dims[0]][dims[1] * dims[2]];
|
||||
for (int i = 0; i < dims[0]; i++) {
|
||||
for (int j = 0; j < dims[1]; j++) {
|
||||
for (int k = 0; k < dims[2]; k++) {
|
||||
flat[i][j * dims[2] + k] = (float)arr[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return flat;
|
||||
}
|
||||
}
|
@ -25,6 +25,7 @@ public class EmbeddingGeneratorFactory {
|
||||
|
||||
public static EmbeddingGenerator getGenerator(String type) {
|
||||
String typeLower = type.equalsIgnoreCase("short") ? "int" : type.toLowerCase();
|
||||
if (typeLower.equals("integer")) typeLower = "int";
|
||||
switch (typeLower) {
|
||||
case "string" -> {
|
||||
if (!generators.containsKey(type)) {
|
||||
@ -38,6 +39,12 @@ public class EmbeddingGeneratorFactory {
|
||||
}
|
||||
return generators.get(type);
|
||||
}
|
||||
case "double" -> {
|
||||
if (!generators.containsKey(type)) {
|
||||
generators.put(type, new DoubleEmbeddingGenerator());
|
||||
}
|
||||
return generators.get(type);
|
||||
}
|
||||
case "int" -> {
|
||||
if (!generators.containsKey(type)) {
|
||||
generators.put(type, new IntEmbeddingGenerator());
|
||||
|
@ -21,14 +21,12 @@ public class FloatEmbeddingGenerator implements EmbeddingGenerator {
|
||||
|
||||
@Override
|
||||
public float[][] generateEmbeddingFrom(Object o, int[] dims) {
|
||||
switch (dims.length) {
|
||||
case 1:
|
||||
return new float[][]{new float[]{(float) o}};
|
||||
case 2: return (float[][]) o;
|
||||
case 3: return flatten(o, dims);
|
||||
default:
|
||||
throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
|
||||
}
|
||||
return switch (dims.length) {
|
||||
case 1 -> new float[][]{(float[]) o};
|
||||
case 2 -> (float[][]) o;
|
||||
case 3 -> flatten(o, dims);
|
||||
default -> throw new RuntimeException("unsupported embedding dimensionality: " + dims.length);
|
||||
};
|
||||
}
|
||||
|
||||
private float[][] flatten(Object o, int[] dims) {
|
||||
|
@ -1,5 +1,5 @@
|
||||
format: HDF5
|
||||
sourceFile: /home/mwolters138/Documents/hdf5/datasets/pass/glove-25-angular.hdf5
|
||||
sourceFile: /home/mwolters138/Downloads/h5ex_t_float.h5 #/home/mwolters138/Documents/hdf5/datasets/pass/glove-25-angular.hdf5
|
||||
datasets:
|
||||
- all
|
||||
embedding: word2vec
|
||||
|
Loading…
Reference in New Issue
Block a user