diff --git a/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/dispensers/KnnSearchOpDispenser.java b/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/dispensers/KnnSearchOpDispenser.java index 532539863..6e77ef557 100644 --- a/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/dispensers/KnnSearchOpDispenser.java +++ b/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/dispensers/KnnSearchOpDispenser.java @@ -20,12 +20,16 @@ import io.nosqlbench.adapter.opensearch.OpenSearchAdapter; import io.nosqlbench.adapter.opensearch.ops.KnnSearchOp; import io.nosqlbench.adapter.opensearch.pojos.Doc; import io.nosqlbench.adapters.api.templating.ParsedOp; +import org.opensearch.client.json.JsonData; import org.opensearch.client.opensearch.OpenSearchClient; +import org.opensearch.client.opensearch._types.FieldValue; import org.opensearch.client.opensearch._types.query_dsl.KnnQuery; import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch.core.SearchRequest; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.function.LongFunction; public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser { @@ -48,9 +52,12 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser { knnfunc = op.enhanceFuncOptionally(knnfunc, "vector", List.class, this::convertVector); knnfunc = op.enhanceFuncOptionally(knnfunc, "field",String.class, KnnQuery.Builder::field); - //TODO: Implement the filter query builder here - //knnfunc = op.enhanceFuncOptionally(knnfunc, "filter",Query.class, KnnQuery.Builder::filter); - + Optional> filterFunction = op.getAsOptionalFunction("filter", Map.class); + if (filterFunction.isPresent()) { + LongFunction finalFunc = knnfunc; + LongFunction builtFilter = buildFilterQuery(filterFunction.get()); + knnfunc = l -> finalFunc.apply(l).filter(builtFilter.apply(l)); + } LongFunction finalKnnfunc = knnfunc; LongFunction bfunc = l -> new SearchRequest.Builder().size(op.getStaticValueOr("size", 100)) @@ -60,6 +67,36 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser { return (long l) -> new KnnSearchOp(clientF.apply(l), bfunc.apply(l).build(), schemaClass); } + private LongFunction buildFilterQuery(LongFunction mapLongFunction) { + return l -> { + Map filterFields = mapLongFunction.apply(l); + String field = filterFields.get("field"); + String comparator = filterFields.get("comparator"); + String value = filterFields.get("value"); + return switch (comparator) { + case "gte" -> Query.of(f -> f + .bool(b -> b + .must(m -> m + .range(r -> r + .field(field) + .gte(JsonData.of(Integer.valueOf(value))))))); + case "lte" -> Query.of(f -> f + .bool(b -> b + .must(m -> m + .range(r -> r + .field(field) + .lte(JsonData.of(Integer.valueOf(value))))))); + case "eq" -> Query.of(f -> f + .bool(b -> b + .must(m -> m + .term(t -> t + .field(field) + .value(FieldValue.of(value)))))); + default -> throw new RuntimeException("Invalid comparator specified"); + }; + }; + } + private KnnQuery.Builder convertVector(KnnQuery.Builder builder, List list) { float[] vector = new float[list.size()]; for (int i = 0; i < list.size(); i++) { diff --git a/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/pojos/UserDefinedSchema.java b/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/pojos/UserDefinedSchema.java index 55d09e8bb..d97b1c4be 100644 --- a/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/pojos/UserDefinedSchema.java +++ b/adapter-aws-opensearch/src/main/java/io/nosqlbench/adapter/opensearch/pojos/UserDefinedSchema.java @@ -19,6 +19,15 @@ package io.nosqlbench.adapter.opensearch.pojos; public class UserDefinedSchema { private float[] vectorValues; private String recordKey; + private String type; + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } public UserDefinedSchema() { } diff --git a/adapter-aws-opensearch/src/main/resources/activities/osvectors_advancedsearch.yaml b/adapter-aws-opensearch/src/main/resources/activities/osvectors_advancedsearch.yaml index f2be003a8..bca79b306 100644 --- a/adapter-aws-opensearch/src/main/resources/activities/osvectors_advancedsearch.yaml +++ b/adapter-aws-opensearch/src/main/resources/activities/osvectors_advancedsearch.yaml @@ -84,6 +84,10 @@ blocks: field: value schema: io.nosqlbench.adapter.opensearch.pojos.UserDefinedSchema size: 100 + filter: + field: "age" + comparator: "gt" + value: "30" search_and_verify: ops: select_ann_limit_TEMPLATE(k,100):