// Files
// LBPM/threadpool/test/test_atomic_list.cpp
// 2017-07-05 12:08:21 -04:00
//
// 211 lines
// 7.2 KiB
// C++

#include "threadpool/atomic_list.h"
#include "common/UnitTest.h"
#include "common/Utilities.h"

#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <random>
#include <string>
#include <thread>
#include <vector>

/**
 * Repeatedly remove entries from the list and re-insert every value that was removed.
 *
 * Each of the N_count iterations removes up to five entries: via remove_first(), via remove()
 * with an always-true predicate, and via remove() with predicates that accept a value at or
 * above a random threshold. It then inserts each removed value (anything other than the -1
 * "nothing removed" result) back into the list. So, provided every insert succeeds, the list
 * holds the same values on return as it did on entry. main() runs this both serially and from
 * several threads at once, and then verifies the contents with check_list().
 *
 * Random thresholds come from a generator owned by this call rather than from rand(). The C++
 * standard does not require rand() to be free of data races when it is called from several
 * threads. Implementations that lock it internally also serialize the very threads being timed.
 *
 * @param list  List to modify; assumed to use -1 as its "no entry" value (as main() builds it).
 */
static void modify_list( AtomicList<int,1024>& list )
{
    const int N_count = 50000;
    // One engine per call, i.e. per thread; seeded from the thread id so threads differ.
    std::minstd_rand engine( static_cast<std::minstd_rand::result_type>(
        std::hash<std::thread::id>()( std::this_thread::get_id() ) ) );
    // Same range that rand() produces: [0, RAND_MAX]
    std::uniform_int_distribution<int> draw( 0, RAND_MAX );
    // A fresh threshold is drawn every time a predicate is evaluated (matching the old rand() use).
    auto above_threshold = [&engine,&draw]( int v, int divisor ) {
        return v >= draw( engine ) / divisor;
    };
    for (int i=0; i<N_count; i++) {
        auto v1 = list.remove_first( );
        auto v2 = list.remove( [](int) { return true; } );
        auto v3 = list.remove( [&above_threshold](int v) { return above_threshold( v, 8 ); } );
        auto v4 = list.remove( [&above_threshold](int v) { return above_threshold( v, 4 ); } );
        auto v5 = list.remove( [&above_threshold](int v) { return above_threshold( v, 2 ); } );
        // Put back everything we managed to take out
        if ( v1 != -1 ) { list.insert( v1 ); }
        if ( v2 != -1 ) { list.insert( v2 ); }
        if ( v3 != -1 ) { list.insert( v3 ); }
        if ( v4 != -1 ) { list.insert( v4 ); }
        if ( v5 != -1 ) { list.insert( v5 ); }
    }
}
/**
 * Verify that the list holds exactly the values in x (in removal order), then restore it to x.
 *
 * The check drains the list with remove() and an always-true predicate, comparing each result
 * against x[i]. Whatever the outcome, the list is then emptied and refilled with x, so callers
 * can use this both as a correctness check and as a "reset to x" step.
 *
 * @param x     Expected contents, in the order remove() with an always-true predicate returns
 *              them (main() passes a sorted vector).
 * @param list  List to verify; on return it contains exactly the values of x.
 * @return      true if list.check() succeeds, the sizes match, and every value matches.
 */
static bool check_list( const std::vector<int>& x, AtomicList<int,1024>& list )
{
    bool pass = list.check();
    pass = pass && static_cast<int>( x.size() ) == list.size();
    // Stop removing at the first mismatch (same as the short-circuiting "pass && ..." form)
    for (size_t i=0; pass && i<x.size(); i++)
        pass = x[i] == list.remove( [](int) { return true; } );
    // Restore the list. Drain until empty: a counted loop "i<list.size()" re-reads a size
    // that shrinks with every removal and would leave about half of the entries behind.
    while ( !list.empty() )
        list.remove_first();
    for (size_t i=0; i<x.size(); i++)
        list.insert( x[i] );
    return pass;
}
/**
 * Remove every entry from the list.
 *
 * Loops until the list reports empty. Do not use a counted loop against list.size(): the
 * size shrinks with each removal while the index grows, so only about half of the entries
 * would be removed.
 */
static inline void clear_list( AtomicList<int,1024>& list )
{
    while ( !list.empty() )
        list.remove_first();
}
/******************************************************************
* The main program *
******************************************************************/
/**
 * Unit test and micro-benchmark for AtomicList<int,1024>.
 *
 * Checks initialization, basic insert/remove behavior, serial and multi-threaded
 * remove/re-insert (via modify_list), and overflow detection. It also prints per-item timings
 * for insert and the three remove variants.
 *
 * @return  Number of failed tests, as reported by UnitTest::NumFailGlobal().
 */
int main( int, char *[] )
{
    UnitTest ut;
    const int N_threads = 8;    // Number of threads used for the parallel get/insert test
    // Create the list; -1 is what remove() returns when nothing matches (see v3 below)
    AtomicList<int,1024> list(-1);
    if ( list.size()==0 && list.check() )
        ut.passes( "Initialize" );
    else
        ut.failure( "Initialize" );
    // Initialize the list with some random values plus two known ones
    for (int i=0; i<80; i++)
        list.insert( rand() );
    list.insert( 2 );
    list.insert( 1 );
    list.insert( rand() );
    // Try to pull off a couple of values
    int v1 = list.remove( [](int a) { return a==1; } ); // Find the entry with 1
    int v2 = list.remove( [](int) { return true; } );   // Get the first entry
    int v3 = list.remove( [](int) { return false; } );  // Fail to get an entry
    if ( v1==1 && v2==2 && v3==-1 && list.size()==81 && list.check() )
        ut.passes( "Basic sanity test" );
    else
        ut.failure( "Basic sanity test" );
    // Clear the list (the list is empty after this loop)
    while ( list.remove( [](int) { return true; } ) != -1 ) {}
    // Create a list of known values: data0 in insertion order, data sorted (removal order)
    //std::vector<int> data0(512);
    std::vector<int> data0( 5*N_threads );
    for (size_t i=0; i<data0.size(); i++)
        data0[i] = rand();
    auto data = data0;
    std::sort( data.begin(), data.end() );
    // Timing setup shared by the benchmarks below
    const int N_it = 20;
    const double N_items = static_cast<double>( N_it*data0.size() );
    std::chrono::duration<double> elapsed;
    std::chrono::time_point<std::chrono::high_resolution_clock> start, stop;
    // Test the cost to insert
    elapsed = std::chrono::duration<double>::zero();
    for (int it=0; it<N_it; it++ ) {
        clear_list( list );
        start = std::chrono::high_resolution_clock::now();
        for (size_t i=0; i<data0.size(); i++)
            list.insert( data0[i] );
        stop = std::chrono::high_resolution_clock::now();
        elapsed += ( stop - start );
    }
    printf("insert time/item = %0.0f ns\n",1e9*elapsed.count()/N_items);
    // Test the cost to remove (first); check_list() refills the list with data each iteration
    elapsed = std::chrono::duration<double>::zero();
    for (int it=0; it<N_it; it++ ) {
        check_list( data, list );
        start = std::chrono::high_resolution_clock::now();
        for (size_t i=0; i<data0.size(); i++)
            list.remove_first( );
        stop = std::chrono::high_resolution_clock::now();
        elapsed += ( stop - start );
    }
    printf("remove (first) time/item = %0.0f ns\n",1e9*elapsed.count()/N_items);
    // Test the cost to remove (in order)
    elapsed = std::chrono::duration<double>::zero();
    for (int it=0; it<N_it; it++ ) {
        check_list( data, list );
        start = std::chrono::high_resolution_clock::now();
        for (size_t i=0; i<data0.size(); i++)
            list.remove( [](int) { return true; } );
        stop = std::chrono::high_resolution_clock::now();
        elapsed += ( stop - start );
    }
    printf("remove (ordered) time/item = %0.0f ns\n",1e9*elapsed.count()/N_items);
    // Test the cost to remove (out of order): search for each value in insertion order
    elapsed = std::chrono::duration<double>::zero();
    for (int it=0; it<N_it; it++ ) {
        check_list( data, list );
        start = std::chrono::high_resolution_clock::now();
        for (size_t i=0; i<data0.size(); i++) {
            int tmp = data0[i];
            list.remove( [tmp](int v) { return v==tmp; } );
        }
        stop = std::chrono::high_resolution_clock::now();
        elapsed += ( stop - start );
    }
    printf("remove (unordered) time/item = %0.0f ns\n",1e9*elapsed.count()/N_items);
    // Read/write to the list and check the results
    int64_t N0 = list.N_remove();
    check_list( data, list );
    start = std::chrono::high_resolution_clock::now();
    modify_list( list );
    stop = std::chrono::high_resolution_clock::now();
    double time_serial = std::chrono::duration<double>(stop-start).count();
    int64_t N1 = list.N_remove();
    bool pass = check_list( data, list );
    if ( pass )
        ut.passes( "Serial get/insert" );
    else
        ut.failure( "Serial get/insert" );
    printf("serial time = %0.5f s\n",time_serial);
    printf("serial time/item = %0.0f ns\n",1e9*time_serial/(N1-N0));
    // Have multiple threads reading/writing to the list simultaneously
    std::vector<std::thread> threads( N_threads );
    start = std::chrono::high_resolution_clock::now();
    for ( int i = 0; i < N_threads; i++ )
        threads[i] = std::thread( modify_list, std::ref(list) );
    for ( int i = 0; i < N_threads; i++ )
        threads[i].join();
    stop = std::chrono::high_resolution_clock::now();
    double time_parallel = std::chrono::duration<double>(stop-start).count();
    int64_t N2 = list.N_remove();
    pass = check_list( data, list );
    if ( pass )
        ut.passes( "Parallel get/insert" );
    else
        ut.failure( "Parallel get/insert" );
    printf("parallel time = %0.5f s\n",time_parallel);
    printf("parallel time/item = %0.0f ns\n",1e9*time_parallel/(N2-N1));
    // Try to over-fill the list: filling to capacity must succeed, one more insert must throw
    while ( !list.empty() )
        list.remove_first();
    for (int i=1; i<=list.capacity(); i++)
        list.insert( i );
    try {
        list.insert( list.capacity()+1 );
        ut.failure( "List overflow" );
    } catch (const std::exception&) {
        ut.passes( "List overflow" );
    } catch(...) {
        ut.failure( "List overflow (unknown exception)" );
    }
    // Finished
    ut.report();
    int N_errors = static_cast<int>( ut.NumFailGlobal() );
    return N_errors;
}