working on writers

This commit is contained in:
Mark Wolters
2023-08-07 14:15:43 -04:00
parent a78ef4b2a3
commit 9af0abbd04
11 changed files with 188 additions and 11 deletions

View File

@@ -63,6 +63,13 @@
<artifactId>jhdf</artifactId>
<version>0.6.10</version>
</dependency>
<dependency>
<groupId>io.nosqlbench</groupId>
<artifactId>nb-api</artifactId>
<version>5.17.3-SNAPSHOT</version>
<scope>compile</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,22 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.nosqlbench.loader.hdf.embedding;

public interface EmbeddingGenerator {
    float[][] generateEmbeddingFrom(Object o);
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.nosqlbench.loader.hdf.embedding;

import java.util.HashMap;
import java.util.Map;

public class EmbeddingGeneratorFactory {
    private static final Map<String, EmbeddingGenerator> generators = new HashMap<>();

    public static EmbeddingGenerator getGenerator(String type) {
        // Normalize the key so names from Class.getSimpleName() (e.g. "String") match the lowercase cases.
        String key = type.toLowerCase();
        switch (key) {
            case "string" -> {
                if (!generators.containsKey(key)) {
                    generators.put(key, new StringEmbeddingGenerator());
                }
                return generators.get(key);
            }
            case "float" -> {
                if (!generators.containsKey(key)) {
                    generators.put(key, new FloatEmbeddingGenerator());
                }
                return generators.get(key);
            }
            default -> throw new RuntimeException("Unknown embedding type: " + type);
        }
    }
}
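As a rough usage sketch (not part of this commit), the factory is keyed by the dataset's Java element type, matching how Hdf5Reader calls it further down:

// Illustration only: resolving and using a generator for a float dataset.
EmbeddingGenerator gen = EmbeddingGeneratorFactory.getGenerator("float");
float[][] vectors = gen.generateEmbeddingFrom(new float[][] {{0.1f, 0.2f}, {0.3f, 0.4f}});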

View File

@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.nosqlbench.loader.hdf.embedding;

public class FloatEmbeddingGenerator implements EmbeddingGenerator {
    @Override
    public float[][] generateEmbeddingFrom(Object o) {
        // Float datasets are already materialized by jhdf as a 2-D Java array, so pass them through.
        return (float[][]) o;
    }
}

View File

@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package io.nosqlbench.loader.hdf.embedding;

public class StringEmbeddingGenerator implements EmbeddingGenerator {
    @Override
    public float[][] generateEmbeddingFrom(Object o) {
        // Stub: string datasets (e.g. the word2vec embedding named in the config) are not handled yet.
        return null;
    }
}

View File

@@ -23,6 +23,8 @@ import io.jhdf.api.Group;
import io.jhdf.api.Node;
import io.jhdf.object.datatype.DataType;
import io.nosqlbench.loader.hdf.config.LoaderConfig;
import io.nosqlbench.loader.hdf.embedding.EmbeddingGenerator;
import io.nosqlbench.loader.hdf.embedding.EmbeddingGeneratorFactory;
import io.nosqlbench.loader.hdf.writers.VectorWriter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -34,6 +36,8 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import static io.nosqlbench.loader.hdf.embedding.EmbeddingGeneratorFactory.*;
public class Hdf5Reader implements HdfReader {
private static final Logger logger = LogManager.getLogger(Hdf5Reader.class);
public static final String ALL = "all";
@@ -83,12 +87,24 @@ public class Hdf5Reader implements HdfReader {
            logger.info("Processing dataset: " + ds);
            Dataset dataset = hdfFile.getDatasetByPath(ds);
            DataType dataType = dataset.getDataType();
            int[] dims = dataset.getDimensions();
            Object data = dataset.getData();
            String type = dataset.getJavaType().getSimpleName();
            EmbeddingGenerator generator = getGenerator(type);
            float[][] vectors = generator.generateEmbeddingFrom(data);
            for (int i = 0; i < dims[0]; i++) {
                try {
                    queue.put(vectors[i]);
                } catch (InterruptedException e) {
                    logger.error(e.getMessage(), e);
                }
            }
        }
        hdfFile.close();
        writer.shutdown();
    }
}

View File

@@ -21,6 +21,7 @@ import java.util.concurrent.LinkedBlockingQueue;
public abstract class AbstractVectorWriter implements VectorWriter {
    protected LinkedBlockingQueue<float[]> queue;
    protected boolean shutdown = false;

    public void setQueue(LinkedBlockingQueue<float[]> queue) {
        this.queue = queue;
@@ -28,7 +29,7 @@ public abstract class AbstractVectorWriter implements VectorWriter {
    @Override
    public void run() {
        while (!shutdown || !queue.isEmpty()) {
            try {
                float[] vector = queue.take();
                if (vector.length == 0) {
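The hunk ends mid-method; a minimal sketch of how the complete consumer loop could look, assuming the empty vector checked above acts as an end-of-stream marker and that writeVector() is the abstract hook the concrete writers implement:

// Sketch under the assumptions above, not the committed code.
@Override
public void run() {
    while (!shutdown || !queue.isEmpty()) {
        try {
            float[] vector = queue.take();
            if (vector.length == 0) {
                break;               // assumed end-of-stream sentinel
            }
            writeVector(vector);     // provided by FileVectorWriter / AstraVectorWriter
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            break;
        }
    }
}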

View File

@@ -17,14 +17,42 @@
package io.nosqlbench.loader.hdf.writers;
import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import io.nosqlbench.loader.hdf.config.LoaderConfig;
import java.nio.file.Paths;
import java.util.Map;
public class AstraVectorWriter extends AbstractVectorWriter {
    private CqlSession session;
    PreparedStatement insert_vector;

    public AstraVectorWriter(LoaderConfig config) {
        Map<String,String> astraParams = config.getAstra();
        session = CqlSession.builder()
            .withCloudSecureConnectBundle(Paths.get(astraParams.get("scb")))
            .withAuthCredentials(astraParams.get("clientId"), astraParams.get("clientSecret"))
            .withKeyspace(astraParams.get("keyspace"))
            .build();
        insert_vector = session.prepare(astraParams.get("query"));
    }

    @Override
    protected void writeVector(float[] vector) {
        session.execute(insert_vector.bind(getPartitionValue(vector), vector));
    }

    private String getPartitionValue(float[] vector) {
        float sum = 0;
        for (float f : vector) {
            sum += f;
        }
        return String.valueOf(sum);
    }

    @Override
    public void shutdown() {
        shutdown = true;
    }
}
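For context, a writer is just a Runnable that drains the shared queue; a hedged wiring sketch follows (the LoaderConfig constructor and the end-of-stream sentinel are assumptions, the rest mirrors this commit's interfaces):

// Hypothetical wiring; offer() is used to avoid checked-exception handling in a fragment.
LoaderConfig config = new LoaderConfig("config.yaml");    // constructor signature assumed
LinkedBlockingQueue<float[]> queue = new LinkedBlockingQueue<>();
VectorWriter writer = new AstraVectorWriter(config);
writer.setQueue(queue);
ExecutorService pool = Executors.newSingleThreadExecutor();
pool.submit(writer);                                      // writer thread drains the queue
queue.offer(new float[] {0.1f, 0.2f, 0.3f});              // reader side pushes vectors
queue.offer(new float[0]);                                // assumed end-of-stream marker
writer.shutdown();
pool.shutdown();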

View File

@@ -47,4 +47,9 @@ public class FileVectorWriter extends AbstractVectorWriter {
            logger.error(e.getMessage(), e);
        }
    }

    @Override
    public void shutdown() {
        shutdown = true;
    }
}

View File

@@ -21,4 +21,6 @@ import java.util.concurrent.LinkedBlockingQueue;
public interface VectorWriter extends Runnable {
    void setQueue(LinkedBlockingQueue<float[]> queue);

    void shutdown();
}

View File

@@ -1,12 +1,13 @@
format: HDF5
sourceFile: /home/mwolters138/Downloads/embeddings.h5 #NEONDSImagingSpectrometerData.h5 #h5ex_t_arrayatt.h5
datasets:
  - all
embedding: word2vec
writer: astra #filewriter
astra:
  scb: /home/mwolters138/Dev/testing/secure-connect-vector-correctness.zip
  clientId: IvpdaZejwNuvWeupsIkWTHeL
  clientSecret: .bxut2-OQL,dWunZeQbjZC0vMHd88UWXKS.xT,nl95zQC0B0xU9FzSWK3HSUGO11o_7pr7wG7+EMaZqegkKlr4fZ54__furPMtWPGiPp,2cZ1q15vrWwc9_-AcgeCbuf
  keyspace: baselines768dot
  query: INSERT INTO vectors(key, value) VALUES (?,?)
targetFile: /home/mwolters138/vectors.txt
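The astra block above is what AstraVectorWriter reads through config.getAstra(); a minimal loading sketch assuming plain SnakeYAML (LoaderConfig's actual parsing may differ):

// Assumption: SnakeYAML; run inside a method that may throw IOException.
Map<String, Object> root = new org.yaml.snakeyaml.Yaml()
        .load(java.nio.file.Files.newInputStream(java.nio.file.Path.of("config.yaml")));
Map<String, String> astra = (Map<String, String>) root.get("astra");
String scb = astra.get("scb");       // secure connect bundle path
String query = astra.get("query");   // statement prepared by AstraVectorWriter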