Merge pull request #1703 from nosqlbench/nosqlbench-1691-dnn

Nosqlbench 1691 dnn
This commit is contained in:
Jonathan Shook 2023-12-08 14:07:09 -06:00 committed by GitHub
commit 22c50b02cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 415 additions and 1 deletions

View File

@ -39,7 +39,7 @@ public class ToDouble implements LongToDoubleFunction {
private final LongToDoubleFunction func; private final LongToDoubleFunction func;
ToDouble(Object func) { public ToDouble(Object func) {
if (func instanceof Number number) { if (func instanceof Number number) {
final double aDouble = number.doubleValue(); final double aDouble = number.doubleValue();
this.func = l -> aDouble; this.func = l -> aDouble;

View File

@ -0,0 +1,104 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import java.util.function.IntFunction;
/**
* Compute the indices of the neighbors of a given v using DNN mapping.
* To avoid ambiguity on equidistant neighbors, odd neighborhood sizes are preferred.
*/
@ThreadSafeMapper
@Categories(Category.experimental)
public class DNN_euclidean_neighbors implements IntFunction<int[]> {
private final int D;
private final int N;
private final int k;
/**
* @param k
* The size of neighborhood
* @param N
* The number of total vectors, necessary for boundary conditions of defined vector
* @param D
* Number of dimensions in each vector
*/
public DNN_euclidean_neighbors(int k, int N, int D) {
this.D = D;
this.N = N;
this.k = k;
}
/**
* <P>Compute neighbor indices with a (hopefully) fast implementation. There are surely some simplifications to be
* made in the functions below, but even in the current form it avoids a significant number of branches.</P>
*
* <P>This code is not as simple as it could be. It was built more for speed than simplicity since it will be a hot
* spot for testing. The unit tests for this are essential.</P>
*
* <P>The method is thus:
* <OL>
* <LI>Determine the sections of the neighborhood which aren't subject to boundary conditions,
* starting at the central vector (the index of the query vector).</LI>
* <LI>Layer these in rank order using closed-form index functions.</LI>
* <LI>Layer in any zero-boundary values which were deferred from above.</LI>
* <LI>Layer in an N-boundary values which were deferred from above.</LI>
* </OL>
* </P>
*
* <P>The boundary conditions for zero and N are mutually exclusive. Even though there is some amount of
* ranging and book keeping in this approach, it should make the general case more stable, especially
* when there are many dimensions and many neighbors.
* </P>
*
* @param value
* the function argument, or the index of the query vector for the DNN addressing scheme
* @return A ranked neighborhood of vector indices, using the DNN addressing scheme
*/
@Override
public int[] apply(int value) {
value = Math.min(Math.max(0,value),N-1);
int[] indices = new int[k];
int leftBoundary = (value << 1) + 1;
int rightBoundary = ((N - (value + 1)) << 1) + 1;
int insideNeighbors = Math.min(k, Math.min(leftBoundary, rightBoundary));
for (int i = 0; i < insideNeighbors; i++) {
// Leave this here as an explainer, please
// int sign = ((((i + 1) & 1) << 1) - 1); // this gives us -1 or +1 depending on odd or even, and is inverted
// int offset = ((i + 1)>>1); // half rounded down, shifted by 1
// offset *= sign;
// int v = value + (((((i + 1) & 1) << 1) - 1) * ((i + 1) >> 1));
indices[i] = value + (((((i + 1) & 1) << 1) - 1) * ((i + 1) >> 1));
}
int leftFill = Math.max(0, k - leftBoundary);
// TODO: Evaluate optimization from Dave2Wave for reducing additions
for (int i = 0; i < leftFill; i++) {
indices[insideNeighbors + i] = insideNeighbors + i;
}
int rightFill = Math.max(0, k - rightBoundary);
for (int i = 0; i < rightFill; i++) {
indices[insideNeighbors + i] = (N - 1) - (insideNeighbors + i);
}
return indices;
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import java.util.function.LongFunction;
@ThreadSafeMapper
@Categories(Category.experimental)
public class DNN_euclidean_v implements LongFunction<float[]> {
private final int D;
private final long N;
private final double scale;
public DNN_euclidean_v(int D, long N) {
this(D,N,1.0d);
}
public DNN_euclidean_v(int D, long N, double scale) {
this.D = D;
this.N = N;
this.scale = scale;
}
@Override
public float[] apply(long value) {
if (value>= N) {
throw new RuntimeException("You can't generate a vector for ordinal " + value + " when your population is " + this.N);
}
float[] vector = new float[D];
for (int idx = 0; idx < vector.length; idx++) {
vector[idx]= (float)(value+(idx*scale));
}
return vector;
}
}

View File

@ -0,0 +1,57 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import java.util.Arrays;
import java.util.function.LongFunction;
@ThreadSafeMapper
@Categories(Category.experimental)
public class DNN_euclidean_v_series implements LongFunction<float[][]> {
private final int dimensions;
private final long population;
private final int k;
public DNN_euclidean_v_series(int dimensions, long population, int k) {
this.dimensions = dimensions;
this.population = population;
this.k = k;
}
@Override
public float[][] apply(long value) {
long nextInterval = value + k;
if (nextInterval > population) {
throw new RuntimeException("You can't generate a vector for ordinal " + value + " when your population is " + this.population);
}
int capacity = dimensions + k;
float[] image = new float[capacity];
for (int imgidx = 0; imgidx < capacity; imgidx++) {
image[imgidx]=imgidx+value;
}
float[][] vectorSeq = new float[k][dimensions];
for (int i = 0; i < vectorSeq.length; i++) {
vectorSeq[i]=Arrays.copyOfRange(image,i,i+dimensions);
}
return vectorSeq;
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import java.util.function.LongFunction;
@ThreadSafeMapper
@Categories(Category.experimental)
public class DNN_euclidean_v_wrap implements LongFunction<float[]> {
private final int D;
private final long N;
private final double scale;
public DNN_euclidean_v_wrap(int D, long N, double scale) {
this.D = D;
this.N = N;
this.scale = scale;
}
public DNN_euclidean_v_wrap(int D, long N) {
this(D,N,1.0d);
}
@Override
public float[] apply(long value) {
value = value % N;
float[] vector = new float[D];
for (int idx = 0; idx < vector.length; idx++) {
vector[idx]= (float)(value+(idx*scale));
}
return vector;
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This is an experimental package based on the DNN or "Das/Direct Nearest Neighbor" method.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;

View File

@ -0,0 +1,64 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class DNNEuclideanNeighborsTest {
@Test
public void test_DNN_K3_N7_D5() {
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(3, 7, 5);
assertThat(idxF.apply(0)).isEqualTo(new int[]{0,1,2});
assertThat(idxF.apply(1)).isEqualTo(new int[]{1,0,2});
assertThat(idxF.apply(2)).isEqualTo(new int[]{2,1,3});
assertThat(idxF.apply(3)).isEqualTo(new int[]{3,2,4});
assertThat(idxF.apply(4)).isEqualTo(new int[]{4,3,5});
assertThat(idxF.apply(5)).isEqualTo(new int[]{5,4,6});
assertThat(idxF.apply(6)).isEqualTo(new int[]{6,5,4});
}
@Test
public void test_DNN_k4_n7_d5() {
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(4, 7, 5);
assertThat(idxF.apply(0)).isEqualTo(new int[]{0,1,2,3});
assertThat(idxF.apply(1)).isEqualTo(new int[]{1,0,2,3});
assertThat(idxF.apply(2)).isEqualTo(new int[]{2,1,3,0});
assertThat(idxF.apply(3)).isEqualTo(new int[]{3,2,4,1});
assertThat(idxF.apply(4)).isEqualTo(new int[]{4,3,5,2});
assertThat(idxF.apply(5)).isEqualTo(new int[]{5,4,6,3});
assertThat(idxF.apply(6)).isEqualTo(new int[]{6,5,4,3});
}
@Test
public void test_DNN_k6_n100_d10() {
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(6, 100, 10);
assertThat(idxF.apply(99)).isEqualTo(new int[]{99,98,97,96,95,94});
}
@Test
public void test_DNN_K6_N101_D10() {
DNN_euclidean_neighbors idxF = new DNN_euclidean_neighbors(6, 101, 10);
assertThat(idxF.apply(101)).isEqualTo(new int[]{100,99,98,97,96,95});
assertThat(idxF.apply(100)).isEqualTo(new int[]{100,99,98,97,96,95});
assertThat(idxF.apply(99)).isEqualTo(new int[]{99,98,100,97,96,95});
assertThat(idxF.apply(98)).isEqualTo(new int[]{98,97,99,96,100,95});
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
class DNNEuclideanVTest {
@Test
public void testBasicVectors() {
DNN_euclidean_v vf = new DNN_euclidean_v(5, 7);
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,4f,5f,6f,7f});
assertThrows(RuntimeException.class, () -> vf.apply(7));
}
@Test
public void testBasicVectorsScaled() {
DNN_euclidean_v vf = new DNN_euclidean_v(5, 7, 3.0);
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,6f,9f,12f,15f});
assertThrows(RuntimeException.class, () -> vf.apply(7));
}
@Test
public void testWrappingVectors() {
DNN_euclidean_v_wrap vf = new DNN_euclidean_v_wrap(5, 7);
assertThat(vf.apply(3L)).isEqualTo(new float[]{3f,4f,5f,6f,7f});
assertThat(vf.apply(0L)).isEqualTo(new float[]{0f,1f,2f,3f,4f});
assertThat(vf.apply(7L)).isEqualTo(new float[]{0f,1f,2f,3f,4f});
}
@Test
public void testContiguousVectors() {
DNN_euclidean_v_series vf = new DNN_euclidean_v_series(4,10,2);
assertThat(vf.apply(7L)).isEqualTo(
new float[][] {
{7f,8f,9f,10f},
{8f,9f,10f,11f}
}
);
assertThrows(RuntimeException.class, () -> vf.apply(10));
}
}