From b649a5ba8177e0f75fac5a5fdc618e3aa3b94c02 Mon Sep 17 00:00:00 2001 From: Jonathan Shook Date: Mon, 26 Feb 2024 14:11:00 -0600 Subject: [PATCH] fixes --- .../to_cqlvector/LoadCqlVectorFromArray.java | 14 ++++++++++---- .../nosqlbench/adapter/http/JsonElementUtils.java | 8 ++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/adapter-cqld4/src/main/java/io/nosqlbench/datamappers/functions/to_cqlvector/LoadCqlVectorFromArray.java b/adapter-cqld4/src/main/java/io/nosqlbench/datamappers/functions/to_cqlvector/LoadCqlVectorFromArray.java index cd8076c11..00a095cda 100644 --- a/adapter-cqld4/src/main/java/io/nosqlbench/datamappers/functions/to_cqlvector/LoadCqlVectorFromArray.java +++ b/adapter-cqld4/src/main/java/io/nosqlbench/datamappers/functions/to_cqlvector/LoadCqlVectorFromArray.java @@ -36,8 +36,9 @@ public class LoadCqlVectorFromArray implements LongFunction { private final Function nameFunc; private final CqlVector[] defaultValue; private final int len; + private final int batchsize; - public LoadCqlVectorFromArray(String name, int len) { + public LoadCqlVectorFromArray(String name, int len, int batchsize) { this.name = name; this.nameFunc = null; Float[] ary = new Float[len]; @@ -46,23 +47,28 @@ public class LoadCqlVectorFromArray implements LongFunction { } this.defaultValue = new CqlVector[]{CqlVector.newInstance(ary)}; this.len = len; + this.batchsize = batchsize; } @Override public CqlVector apply(long cycle) { - int offset = (int) (cycle % len); + int offset = (int) (cycle % batchsize); HashMap map = SharedState.tl_ObjectMap.get(); String varname = (nameFunc != null) ? String.valueOf(nameFunc.apply(cycle)) : name; Object object = map.getOrDefault(varname, defaultValue); if (object.getClass().isArray()) { object = Array.get(object,offset); + } else if (object instanceof double[][] dary) { + object = dary[offset]; + } else if (object instanceof float[][] fary) { + object = fary[offset]; } else if (object instanceof Double[][] dary) { object = dary[offset]; } else if (object instanceof Float[][] fary) { object = fary[offset]; - } else if (object instanceof CqlVector[] cary) { + } else if (object instanceof CqlVector[] cary) { object = cary[offset]; - } else if (object instanceof List list) { + } else if (object instanceof List list) { object = list.get(offset); } else { throw new RuntimeException("Unrecognized type for ary of ary:" + object.getClass().getCanonicalName()); diff --git a/adapter-http/src/main/java/io/nosqlbench/adapter/http/JsonElementUtils.java b/adapter-http/src/main/java/io/nosqlbench/adapter/http/JsonElementUtils.java index ae36af224..530812dd3 100644 --- a/adapter-http/src/main/java/io/nosqlbench/adapter/http/JsonElementUtils.java +++ b/adapter-http/src/main/java/io/nosqlbench/adapter/http/JsonElementUtils.java @@ -103,10 +103,10 @@ public class JsonElementUtils { JsonElement element0 = dary.get(vector_idx); JsonObject eobj1 = element0.getAsJsonObject(); JsonElement embedding = eobj1.get("embedding"); - JsonArray ary = embedding.getAsJsonArray(); - float[] newV = new float[ary.size()]; - for (int component_idx = 0; component_idx < floats2dary.length; component_idx++) { - newV[component_idx]=ary.get(component_idx).getAsFloat(); + JsonArray vectorAry = embedding.getAsJsonArray(); + float[] newV = new float[vectorAry.size()]; + for (int component_idx = 0; component_idx < vectorAry.size(); component_idx++) { + newV[component_idx]=vectorAry.get(component_idx).getAsFloat(); } floats2dary[vector_idx]=newV; }