adding filtering to os queries

This commit is contained in:
Mark Wolters 2024-02-13 11:30:07 -04:00
parent b93d85eae1
commit 2c24681c4d
3 changed files with 53 additions and 3 deletions

View File

@ -20,12 +20,16 @@ import io.nosqlbench.adapter.opensearch.OpenSearchAdapter;
import io.nosqlbench.adapter.opensearch.ops.KnnSearchOp;
import io.nosqlbench.adapter.opensearch.pojos.Doc;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch.core.SearchRequest;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.LongFunction;
public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
@ -48,9 +52,12 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
knnfunc = op.enhanceFuncOptionally(knnfunc, "vector", List.class, this::convertVector);
knnfunc = op.enhanceFuncOptionally(knnfunc, "field",String.class, KnnQuery.Builder::field);
//TODO: Implement the filter query builder here
//knnfunc = op.enhanceFuncOptionally(knnfunc, "filter",Query.class, KnnQuery.Builder::filter);
Optional<LongFunction<Map>> filterFunction = op.getAsOptionalFunction("filter", Map.class);
if (filterFunction.isPresent()) {
LongFunction<KnnQuery.Builder> finalFunc = knnfunc;
LongFunction<Query> builtFilter = buildFilterQuery(filterFunction.get());
knnfunc = l -> finalFunc.apply(l).filter(builtFilter.apply(l));
}
LongFunction<KnnQuery.Builder> finalKnnfunc = knnfunc;
LongFunction<SearchRequest.Builder> bfunc =
l -> new SearchRequest.Builder().size(op.getStaticValueOr("size", 100))
@ -60,6 +67,36 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
return (long l) -> new KnnSearchOp(clientF.apply(l), bfunc.apply(l).build(), schemaClass);
}
private LongFunction<Query> buildFilterQuery(LongFunction<Map> mapLongFunction) {
return l -> {
Map<String,String> filterFields = mapLongFunction.apply(l);
String field = filterFields.get("field");
String comparator = filterFields.get("comparator");
String value = filterFields.get("value");
return switch (comparator) {
case "gte" -> Query.of(f -> f
.bool(b -> b
.must(m -> m
.range(r -> r
.field(field)
.gte(JsonData.of(Integer.valueOf(value)))))));
case "lte" -> Query.of(f -> f
.bool(b -> b
.must(m -> m
.range(r -> r
.field(field)
.lte(JsonData.of(Integer.valueOf(value)))))));
case "eq" -> Query.of(f -> f
.bool(b -> b
.must(m -> m
.term(t -> t
.field(field)
.value(FieldValue.of(value))))));
default -> throw new RuntimeException("Invalid comparator specified");
};
};
}
private KnnQuery.Builder convertVector(KnnQuery.Builder builder, List list) {
float[] vector = new float[list.size()];
for (int i = 0; i < list.size(); i++) {

View File

@ -19,6 +19,15 @@ package io.nosqlbench.adapter.opensearch.pojos;
public class UserDefinedSchema {
private float[] vectorValues;
private String recordKey;
private String type;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public UserDefinedSchema() {
}

View File

@ -84,6 +84,10 @@ blocks:
field: value
schema: io.nosqlbench.adapter.opensearch.pojos.UserDefinedSchema
size: 100
filter:
field: "age"
comparator: "gt"
value: "30"
search_and_verify:
ops:
select_ann_limit_TEMPLATE(k,100):