This commit is contained in:
Jonathan Shook 2024-02-26 14:11:00 -06:00
parent 6116cb8d2e
commit b2e1905793
2 changed files with 14 additions and 8 deletions

View File

@ -36,8 +36,9 @@ public class LoadCqlVectorFromArray implements LongFunction<CqlVector> {
private final Function<Object, Object> nameFunc;
private final CqlVector[] defaultValue;
private final int len;
private final int batchsize;
public LoadCqlVectorFromArray(String name, int len) {
public LoadCqlVectorFromArray(String name, int len, int batchsize) {
this.name = name;
this.nameFunc = null;
Float[] ary = new Float[len];
@ -46,23 +47,28 @@ public class LoadCqlVectorFromArray implements LongFunction<CqlVector> {
}
this.defaultValue = new CqlVector[]{CqlVector.newInstance(ary)};
this.len = len;
this.batchsize = batchsize;
}
@Override
public CqlVector apply(long cycle) {
int offset = (int) (cycle % len);
int offset = (int) (cycle % batchsize);
HashMap<String, Object> map = SharedState.tl_ObjectMap.get();
String varname = (nameFunc != null) ? String.valueOf(nameFunc.apply(cycle)) : name;
Object object = map.getOrDefault(varname, defaultValue);
if (object.getClass().isArray()) {
object = Array.get(object,offset);
} else if (object instanceof double[][] dary) {
object = dary[offset];
} else if (object instanceof float[][] fary) {
object = fary[offset];
} else if (object instanceof Double[][] dary) {
object = dary[offset];
} else if (object instanceof Float[][] fary) {
object = fary[offset];
} else if (object instanceof CqlVector[] cary) {
} else if (object instanceof CqlVector<?>[] cary) {
object = cary[offset];
} else if (object instanceof List list) {
} else if (object instanceof List<?> list) {
object = list.get(offset);
} else {
throw new RuntimeException("Unrecognized type for ary of ary:" + object.getClass().getCanonicalName());

View File

@ -103,10 +103,10 @@ public class JsonElementUtils {
JsonElement element0 = dary.get(vector_idx);
JsonObject eobj1 = element0.getAsJsonObject();
JsonElement embedding = eobj1.get("embedding");
JsonArray ary = embedding.getAsJsonArray();
float[] newV = new float[ary.size()];
for (int component_idx = 0; component_idx < floats2dary.length; component_idx++) {
newV[component_idx]=ary.get(component_idx).getAsFloat();
JsonArray vectorAry = embedding.getAsJsonArray();
float[] newV = new float[vectorAry.size()];
for (int component_idx = 0; component_idx < vectorAry.size(); component_idx++) {
newV[component_idx]=vectorAry.get(component_idx).getAsFloat();
}
floats2dary[vector_idx]=newV;
}