This commit is contained in:
Jonathan Shook 2024-02-26 14:11:00 -06:00
parent 74135bf05e
commit b649a5ba81
2 changed files with 14 additions and 8 deletions

View File

@ -36,8 +36,9 @@ public class LoadCqlVectorFromArray implements LongFunction<CqlVector> {
private final Function<Object, Object> nameFunc; private final Function<Object, Object> nameFunc;
private final CqlVector[] defaultValue; private final CqlVector[] defaultValue;
private final int len; private final int len;
private final int batchsize;
public LoadCqlVectorFromArray(String name, int len) { public LoadCqlVectorFromArray(String name, int len, int batchsize) {
this.name = name; this.name = name;
this.nameFunc = null; this.nameFunc = null;
Float[] ary = new Float[len]; Float[] ary = new Float[len];
@ -46,23 +47,28 @@ public class LoadCqlVectorFromArray implements LongFunction<CqlVector> {
} }
this.defaultValue = new CqlVector[]{CqlVector.newInstance(ary)}; this.defaultValue = new CqlVector[]{CqlVector.newInstance(ary)};
this.len = len; this.len = len;
this.batchsize = batchsize;
} }
@Override @Override
public CqlVector apply(long cycle) { public CqlVector apply(long cycle) {
int offset = (int) (cycle % len); int offset = (int) (cycle % batchsize);
HashMap<String, Object> map = SharedState.tl_ObjectMap.get(); HashMap<String, Object> map = SharedState.tl_ObjectMap.get();
String varname = (nameFunc != null) ? String.valueOf(nameFunc.apply(cycle)) : name; String varname = (nameFunc != null) ? String.valueOf(nameFunc.apply(cycle)) : name;
Object object = map.getOrDefault(varname, defaultValue); Object object = map.getOrDefault(varname, defaultValue);
if (object.getClass().isArray()) { if (object.getClass().isArray()) {
object = Array.get(object,offset); object = Array.get(object,offset);
} else if (object instanceof double[][] dary) {
object = dary[offset];
} else if (object instanceof float[][] fary) {
object = fary[offset];
} else if (object instanceof Double[][] dary) { } else if (object instanceof Double[][] dary) {
object = dary[offset]; object = dary[offset];
} else if (object instanceof Float[][] fary) { } else if (object instanceof Float[][] fary) {
object = fary[offset]; object = fary[offset];
} else if (object instanceof CqlVector[] cary) { } else if (object instanceof CqlVector<?>[] cary) {
object = cary[offset]; object = cary[offset];
} else if (object instanceof List list) { } else if (object instanceof List<?> list) {
object = list.get(offset); object = list.get(offset);
} else { } else {
throw new RuntimeException("Unrecognized type for ary of ary:" + object.getClass().getCanonicalName()); throw new RuntimeException("Unrecognized type for ary of ary:" + object.getClass().getCanonicalName());

View File

@ -103,10 +103,10 @@ public class JsonElementUtils {
JsonElement element0 = dary.get(vector_idx); JsonElement element0 = dary.get(vector_idx);
JsonObject eobj1 = element0.getAsJsonObject(); JsonObject eobj1 = element0.getAsJsonObject();
JsonElement embedding = eobj1.get("embedding"); JsonElement embedding = eobj1.get("embedding");
JsonArray ary = embedding.getAsJsonArray(); JsonArray vectorAry = embedding.getAsJsonArray();
float[] newV = new float[ary.size()]; float[] newV = new float[vectorAry.size()];
for (int component_idx = 0; component_idx < floats2dary.length; component_idx++) { for (int component_idx = 0; component_idx < vectorAry.size(); component_idx++) {
newV[component_idx]=ary.get(component_idx).getAsFloat(); newV[component_idx]=vectorAry.get(component_idx).getAsFloat();
} }
floats2dary[vector_idx]=newV; floats2dary[vector_idx]=newV;
} }