additional work on DNN angualr1 methods

This commit is contained in:
Jonathan Shook 2023-12-20 17:16:42 -06:00
parent 274b5b119b
commit 53b3c780bc
2 changed files with 138 additions and 0 deletions

View File

@ -0,0 +1,55 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import io.nosqlbench.virtdata.api.annotations.Categories;
import io.nosqlbench.virtdata.api.annotations.Category;
import io.nosqlbench.virtdata.api.annotations.ThreadSafeMapper;
import java.util.Arrays;
import java.util.function.LongFunction;
@ThreadSafeMapper
@Categories(Category.experimental)
public class DNN_angular1_v implements LongFunction<float[]> {
private final int d;
private final long n;
private final long m;
/**
* @param D
* Dimensions in each vector
* @param N
* The number of vectors in the training set
* @param M
* The modulo which is used to construct equivalence classes
*/
public DNN_angular1_v(int D, long N, long M) {
d = D;
n = N;
m = M;
}
@Override
public float[] apply(long i) {
float[] vector = new float[d];
Arrays.fill(vector, i + 1);
vector[vector.length - 1] = (i + 1) * (i % m);
return vector;
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright (c) 2023 nosqlbench
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.nosqlbench.virtdata.library.basics.shared.vectors.dnn;
import org.jetbrains.annotations.TestOnly;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
class DNN_angular1_vTest {
@Test
public void testCosineSimilarity() {
assertThat(cosine_similarity(new float[]{1,2,3,4,5,6,7},new float[]{7,6,5,4,3,2,1})).isEqualTo(0.6);
assertThat(cosine_similarity(new float[]{1,2,3,4,5,6,7},new float[]{1,2,3,4,5,6,7})).isEqualTo(1.0);
}
@Test
public void testSimpleGeneration() {
DNN_angular1_v vs = new DNN_angular1_v(2,100,3);
assertThat(vs.apply(0)).isEqualTo(new float[]{1,0});
assertThat(vs.apply(1)).isEqualTo(new float[]{2,2});
assertThat(vs.apply(2)).isEqualTo(new float[]{3,6});
assertThat(vs.apply(3)).isEqualTo(new float[]{4,0});
assertThat(vs.apply(4)).isEqualTo(new float[]{5,5});
assertThat(vs.apply(5)).isEqualTo(new float[]{6,12});
assertThat(vs.apply(6)).isEqualTo(new float[]{7,0});
}
@Test
public void testBasicAngularVectors() {
DNN_angular1_v vf = new DNN_angular1_v(10, 100, 7);
float[][] vectors = new float[100][];
for (int i = 0; i < 100; i++) {
vectors[i] = vf.apply(i);
}
int[] same = new int[100];
Arrays.fill(same,-1);
for (int vidx = 0; vidx < same.length; vidx++) {
for (int compare_to = 0; compare_to < vidx; compare_to++) {
if (cosine_similarity(vectors[vidx],vectors[compare_to])==1.0) {
same[vidx]=compare_to;
}
}
}
for (int sameas = 0; sameas < same.length; sameas++) {
assertThat(same[sameas]==sameas%7);
}
}
private double cosine_similarity(float[] a, float[] b) {
double dp = 0.0d;
double as = 0.0d;
double bs = 0.0d;
for (int i = 0; i < a.length; i++) {
dp += (a[i] * b[i]);
as += (a[i] * a[i]);
bs += (b[i] * b[i]);
}
double similarity = dp / (Math.sqrt(as) * Math.sqrt(bs));
return similarity;
}
}