Merge pull request #1825 from nosqlbench/mwolters/osknnsearch_enhancements

Mwolters/osknnsearch enhancements
This commit is contained in:
Madhavan 2024-02-22 19:29:26 -05:00 committed by GitHub
commit 7d620ed136
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 240 additions and 6 deletions

View File

@ -20,18 +20,29 @@ import io.nosqlbench.adapter.opensearch.OpenSearchAdapter;
import io.nosqlbench.adapter.opensearch.ops.KnnSearchOp;
import io.nosqlbench.adapter.opensearch.pojos.Doc;
import io.nosqlbench.adapters.api.templating.ParsedOp;
import org.opensearch.client.json.JsonData;
import org.opensearch.client.opensearch.OpenSearchClient;
import org.opensearch.client.opensearch._types.FieldValue;
import org.opensearch.client.opensearch._types.query_dsl.KnnQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch.core.SearchRequest;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.LongFunction;
public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
private Class<?> schemaClass;
/**
 * Creates a dispenser for knn_search ops and resolves the result schema class.
 *
 * <p>The optional static op field {@code schema} names the class to load; when it is
 * absent, {@code io.nosqlbench.adapter.opensearch.pojos.Doc} is used. The class is
 * resolved once here, at construction time, and stored in {@code schemaClass}.
 *
 * @param adapter the OpenSearch adapter, passed through to the base dispenser
 * @param op      the parsed op template, read for the {@code schema} setting
 * @param targetF per-cycle function passed through to the base dispenser
 * @throws RuntimeException if the named schema class cannot be loaded
 */
public KnnSearchOpDispenser(OpenSearchAdapter adapter, ParsedOp op, LongFunction<String> targetF) {
    super(adapter, op, targetF);

    // Fall back to the bundled Doc pojo when the op template names no schema.
    final String schemaClassName =
        op.getStaticConfigOr("schema", "io.nosqlbench.adapter.opensearch.pojos.Doc");

    try {
        this.schemaClass = Class.forName(schemaClassName);
    } catch (Exception cause) {
        // Keep the original cause so the failing class lookup is still diagnosable.
        throw new RuntimeException("Unable to load schema class: " + schemaClassName, cause);
    }
}
@Override
@ -41,16 +52,49 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
knnfunc = op.enhanceFuncOptionally(knnfunc, "vector", List.class, this::convertVector);
knnfunc = op.enhanceFuncOptionally(knnfunc, "field",String.class, KnnQuery.Builder::field);
//TODO: Implement the filter query builder here
//knnfunc = op.enhanceFuncOptionally(knnfunc, "filter",Query.class, KnnQuery.Builder::filter);
Optional<LongFunction<Map>> filterFunction = op.getAsOptionalFunction("filter", Map.class);
if (filterFunction.isPresent()) {
LongFunction<KnnQuery.Builder> finalFunc = knnfunc;
LongFunction<Query> builtFilter = buildFilterQuery(filterFunction.get());
knnfunc = l -> finalFunc.apply(l).filter(builtFilter.apply(l));
}
LongFunction<KnnQuery.Builder> finalKnnfunc = knnfunc;
LongFunction<SearchRequest.Builder> bfunc =
l -> new SearchRequest.Builder().size(100)
l -> new SearchRequest.Builder().size(op.getStaticValueOr("size", 100))
.index(targetF.apply(l))
.query(new Query.Builder().knn(finalKnnfunc.apply(l).build()).build());
return (long l) -> new KnnSearchOp(clientF.apply(l), bfunc.apply(l).build(), Doc.class);
return (long l) -> new KnnSearchOp(clientF.apply(l), bfunc.apply(l).build(), schemaClass);
}
/**
 * Wraps a per-cycle filter map in a function that builds the matching filter query.
 *
 * <p>The map must provide three entries:
 * <ul>
 *   <li>{@code field} - name of the document field to filter on</li>
 *   <li>{@code comparator} - one of {@code gte}, {@code lte} or {@code eq}</li>
 *   <li>{@code value} - the comparison value; parsed as an integer for {@code gte}
 *       and {@code lte}, used as a string term for {@code eq}</li>
 * </ul>
 * Entries are read as objects and converted with {@code toString()}, so a value that
 * arrives as a non-String (for example a number) does not fail with a
 * {@link ClassCastException}.
 *
 * @param mapLongFunction per-cycle source of the filter definition
 * @return a per-cycle function producing a bool/must query around a range or term query
 * @throws RuntimeException (when the returned function is applied) if an entry is
 *         missing or the comparator is not supported
 * @throws NumberFormatException (when the returned function is applied) if a
 *         {@code gte}/{@code lte} value is not a valid integer
 */
private LongFunction<Query> buildFilterQuery(LongFunction<Map> mapLongFunction) {
    return l -> {
        Map<?, ?> filterFields = mapLongFunction.apply(l);
        String field = requireFilterEntry(filterFields, "field");
        String comparator = requireFilterEntry(filterFields, "comparator");
        String value = requireFilterEntry(filterFields, "value");
        return switch (comparator) {
            case "gte" -> Query.of(f -> f
                .bool(b -> b
                    .must(m -> m
                        .range(r -> r
                            .field(field)
                            .gte(JsonData.of(Integer.valueOf(value)))))));
            case "lte" -> Query.of(f -> f
                .bool(b -> b
                    .must(m -> m
                        .range(r -> r
                            .field(field)
                            .lte(JsonData.of(Integer.valueOf(value)))))));
            case "eq" -> Query.of(f -> f
                .bool(b -> b
                    .must(m -> m
                        .term(t -> t
                            .field(field)
                            .value(FieldValue.of(value))))));
            default -> throw new RuntimeException(
                "Invalid comparator specified: '" + comparator + "' (expected one of gte, lte, eq)");
        };
    };
}

/**
 * Reads one required entry of a filter definition as a string.
 *
 * @param filterFields the filter definition; may be null
 * @param key          the entry to read
 * @return the entry converted with {@code toString()}
 * @throws RuntimeException if the map is null or has no value for {@code key}
 */
private static String requireFilterEntry(Map<?, ?> filterFields, String key) {
    Object entry = (filterFields == null) ? null : filterFields.get(key);
    if (entry == null) {
        // Fail with a named key instead of a bare NullPointerException from the switch.
        throw new RuntimeException("The knn_search filter is missing the required '" + key + "' entry");
    }
    return entry.toString();
}
private KnnQuery.Builder convertVector(KnnQuery.Builder builder, List list) {
@ -60,5 +104,4 @@ public class KnnSearchOpDispenser extends BaseOpenSearchOpDispenser {
}
return builder.vector(vector);
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright (c) 2024 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.adapter.opensearch.pojos;
/**
 * Example document schema with a vector, a record key and a type label.
 *
 * <p>A mutable bean: it has a no-argument constructor plus a getter and setter for
 * every property. NOTE(review): presumably this is meant to be named in the
 * {@code schema} field of a knn_search op so results are mapped onto it - confirm
 * against KnnSearchOp.
 */
public class UserDefinedSchema {
    private float[] vectorValues;
    private String recordKey;
    private String type;

    /** Creates an empty instance; all properties start out null. */
    public UserDefinedSchema() {
    }

    /**
     * Creates a fully populated instance.
     *
     * @param value the vector values (stored by reference, not copied)
     * @param key   the record key
     * @param type  the type label
     */
    public UserDefinedSchema(float[] value, String key, String type) {
        this.vectorValues = value;
        this.recordKey = key;
        this.type = type;
    }

    /** @return the type label, or null if unset */
    public String getType() {
        return type;
    }

    /** @param type the type label to store */
    public void setType(String type) {
        this.type = type;
    }

    /** @return the stored vector array itself (not a copy), or null if unset */
    public float[] getVectorValues() {
        return vectorValues;
    }

    /** @param vectorValues the vector array to store (by reference) */
    public void setVectorValues(float[] vectorValues) {
        this.vectorValues = vectorValues;
    }

    /** @return the record key, or null if unset */
    public String getRecordKey() {
        return recordKey;
    }

    /** @param recordKey the record key to store */
    public void setRecordKey(String recordKey) {
        this.recordKey = recordKey;
    }

    /**
     * Renders the vector contents, e.g. <code>{values=[1.0, 2.5]}</code>.
     *
     * <p>Concatenating the array directly would print its identity string
     * (such as {@code [F@1b2c3d}) rather than its elements, so the elements are
     * formatted explicitly. A null vector renders as <code>{values=null}</code>.
     */
    @Override
    public String toString() {
        return "{" + "values=" + java.util.Arrays.toString(vectorValues) + "}";
    }
}

View File

@ -0,0 +1,130 @@
description: |
Advanced options for k-NN search in OpenSearch.
https://www.elastic.co/guide/en/elasticsearch/reference/current/rest-apis.html
template vars:
TEMPLATE(indexname,vectors_index)
TEMPLATE(dimensions,25)
TEMPLATE(search_cycles,1M)
TEMPLATE(rampup_cycles,TEMPLATE(trainsize))
TEMPLATE(size,100)
params:
driver: opensearch
instrument: true
# Named scenarios. Each step is a `run` command whose tags= filter names one of the
# blocks defined under `blocks:` in this file.
# NOTE(review): TEMPLATE(trainsize) and TEMPLATE(testsize) are used without defaults
# in this file - presumably they must be supplied by the caller; confirm.
scenarios:
# vectors_brief: loads data through the bulkrampup block, then runs the
# search_and_verify and search_specify_schema blocks. It has no schema step.
vectors_brief:
bulkrampup: >-
run tags='block:bulkrampup' labels='target:opensearch'
threads=TEMPLATE(rampup_threads,10) cycles=TEMPLATE(trainsize)
errors=count,warn
# search: run tags='block:search' labels='target:opensearch' threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
#rampup: >-
# run tags='block:rampup' labels='target:opensearch'
# threads=TEMPLATE(rampup_threads,10) cycles=TEMPLATE(trainsize)
# errors=count,warn
## search: run tags='block:search' labels='target:opensearch' threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
search_and_verify: >-
run tags='block:search_and_verify' labels='target:opensearch'
threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
errors=count,warn
search_specify_schema: >-
run tags='block:search_specify_schema' labels='target:opensearch'
threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
errors=count,warn
# vectors: runs the schema block, loads data through the (non-bulk) rampup block,
# then runs search_and_verify. The drop step is commented out.
vectors:
# drop: run tags='block:drop' labels='target:opensearch' threads===1 cycles===UNDEF
schema: run tags='block:schema' labels='target:opensearch' threads===1 cycles===UNDEF
rampup: >-
run tags='block:rampup' labels='target:opensearch'
threads=TEMPLATE(rampup_threads,10) cycles=TEMPLATE(trainsize)
errors=count,warn
# search: run tags='block:search' labels='target:opensearch' threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
search_and_verify: >-
run tags='block:search_and_verify' labels='target:opensearch'
threads=TEMPLATE(search_threads,10) cycles=TEMPLATE(testsize)
errors=count,warn
# errors=counter,warn,log
# Bindings shared by the ops below. All HDF5-backed bindings read from
# testdata/<dataset>.hdf5, where <dataset> comes from TEMPLATE(dataset) (no default here).
bindings:
# Used as the document key in the bulkrampup and rampup blocks.
id: ToString()
# Read from the "/test" path; used as the query vector in every knn_search op.
test_floatlist: HdfFileToFloatList("testdata/TEMPLATE(dataset).hdf5", "/test");
# Read from the "/neighbors" path; passed to relevancy.accept() in the search_and_verify verifier.
relevant_indices: HdfFileToIntArray("testdata/TEMPLATE(dataset).hdf5", "/neighbors")
# Read from the "/distance" path; not referenced by any op in this file.
distance_floatlist: HdfFileToFloatList("testdata/TEMPLATE(dataset).hdf5", "/distance")
# Read from the "/train" path; used as the document value in the bulkrampup and rampup blocks.
train_floatlist: HdfFileToFloatList("testdata/TEMPLATE(dataset).hdf5", "/train");
# Op blocks selected by the scenarios via tags='block:<name>':
#   drop                  - delete_index on the target index
#   schema                - create_index on the target index
#   search                - plain knn_search (k=100, field "value")
#   search_specify_schema - knn_search with an explicit schema class, size and filter
#   search_and_verify     - knn_search plus a verifier that feeds relevancy measures at k=100
#   bulkrampup            - bulk indexing of {key, value} documents
#   rampup                - single-document indexing of {key, value} documents
blocks:
drop:
ops:
drop_index:
delete_index: TEMPLATE(indexname,vectors_index)
schema:
ops:
create_index:
create_index: TEMPLATE(indexname)
# NOTE(review): "m1: v1" looks like a placeholder mapping - confirm what the adapter does with it.
mappings:
m1: v1
search:
ops:
search:
knn_search: TEMPLATE(indexname,vectors_index)
k: 100
vector: "{test_floatlist}"
field: value
search_specify_schema:
ops:
search:
knn_search: TEMPLATE(indexname,vectors_index)
k: 100
vector: "{test_floatlist}"
field: value
# Fully qualified class name loaded by KnnSearchOpDispenser; the dispenser defaults to
# io.nosqlbench.adapter.opensearch.pojos.Doc when this field is absent.
schema: io.nosqlbench.adapter.opensearch.pojos.UserDefinedSchema
# Passed to the search request's size(); the dispenser defaults to 100 when absent.
size: 100
# Filter definition with field / comparator / value entries; the dispenser accepts the
# comparators gte, lte and eq.
# NOTE(review): the rampup blocks in this file write only "key" and "value" fields, so a
# filter on "type" presumably targets data loaded some other way - confirm.
filter:
field: "type"
comparator: "eq"
value: "experimental"
search_and_verify:
ops:
select_ann_limit_TEMPLATE(k,100):
knn_search: TEMPLATE(indexname,vectors_index)
k: 100
vector: "{test_floatlist}"
field: value
verifier-init: |
relevancy=new io.nosqlbench.nb.api.engine.metrics.wrappers.RelevancyMeasures(_parsed_op)
for (int k in List.of(100)) {
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.recall("recall",k));
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.precision("precision",k));
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.F1("F1",k));
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.reciprocal_rank("RR",k));
relevancy.addFunction(io.nosqlbench.engine.extensions.computefunctions.RelevancyFunctions.average_precision("AP",k));
}
verifier: |
// driver-specific function
actual_indices=io.nosqlbench.adapter.opensearch.Utils.DocHitsToIntIndicesArray(result)
// driver-agnostic function
relevancy.accept({relevant_indices},actual_indices);
return true;
bulkrampup:
ops:
bulk_index:
bulk: TEMPLATE(indexname)
op_template:
repeat: TEMPLATE(bulk_repeat,100)
index: TEMPLATE(indexname)
document:
key: "{id}"
value: "{train_floatlist}"
rampup:
ops:
index:
index: TEMPLATE(indexname)
document:
key: "{id}"
value: "{train_floatlist}"