211 lines
7.2 KiB
C++
211 lines
7.2 KiB
C++
#include "threadpool/atomic_list.h"
|
|
#include "common/UnitTest.h"
|
|
#include "common/Utilities.h"
|
|
#include <iostream>
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <thread>
|
|
#include <chrono>
|
|
#include <functional>
|
|
#include <atomic>
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
/*
 * Repeatedly pull entries off the list and put them back.
 *
 * Each pass performs five removals (the first entry, the first entry that
 * matches an always-true predicate, and three removals against randomly
 * chosen thresholds) and then re-inserts every value that was actually
 * obtained.  The function only re-inserts values it removed itself.
 *
 * NOTE(review): this function is also run from several std::thread
 * instances at once and calls rand(); whether rand() is thread-safe is
 * implementation-defined -- confirm this is acceptable on the target
 * platforms.
 */
static void modify_list( AtomicList<int,1024>& list )
{
    // The list in this file is constructed with -1 as its default value,
    // and the sanity test expects -1 back from a remove that finds nothing.
    const int not_found = -1;
    const int N_count = 50000;

    // Predicates handed to list.remove()
    auto take_any      = []( int ) { return true; };
    auto above_eighth  = []( int value ) { return value >= ( rand() / 8 ); };
    auto above_quarter = []( int value ) { return value >= ( rand() / 4 ); };
    auto above_half    = []( int value ) { return value >= ( rand() / 2 ); };

    // Put a value back only if the removal actually produced one
    auto put_back = [&list, not_found]( int value ) {
        if ( value != not_found )
            list.insert( value );
    };

    for ( int pass = N_count; pass > 0; --pass ) {
        // All removals happen before any re-insertion
        auto v1 = list.remove_first();
        auto v2 = list.remove( take_any );
        auto v3 = list.remove( above_eighth );
        auto v4 = list.remove( above_quarter );
        auto v5 = list.remove( above_half );
        put_back( v1 );
        put_back( v2 );
        put_back( v3 );
        put_back( v4 );
        put_back( v5 );
    }
}
|
|
|
|
|
|
static bool check_list( const std::vector<int>& x, AtomicList<int,1024>& list )
|
|
{
|
|
bool pass = list.check();
|
|
pass = pass && (int) x.size() == list.size();
|
|
if ( pass ) {
|
|
for (size_t i=0; i<x.size(); i++)
|
|
pass = pass && x[i] == list.remove( [](int) { return true; } );
|
|
}
|
|
// Restore the list
|
|
for (int i=0; i<list.size(); i++)
|
|
list.remove_first();
|
|
for (size_t i=0; i<x.size(); i++)
|
|
list.insert( x[i] );
|
|
return pass;
|
|
}
|
|
|
|
|
|
/*
 * Remove every entry from the list.
 *
 * Entries are removed until the list reports empty.  A counted loop of the
 * form `for (i=0; i<list.size(); i++)` must not be used here: the size
 * shrinks by one per removal while the index grows, so such a loop stops
 * after removing only about half of the entries.
 */
static inline void clear_list(AtomicList<int,1024>& list )
{
    while ( !list.empty() )
        list.remove_first();
}
|
|
|
|
|
|
|
|
/******************************************************************
* The main program                                                *
******************************************************************/
|
|
int main( int, char *[] )
|
|
{
|
|
UnitTest ut;
|
|
|
|
int N_threads = 8; // Number of threads
|
|
|
|
// Create the list
|
|
AtomicList<int,1024> list(-1);
|
|
if ( list.size()==0 && list.check() )
|
|
ut.passes( "Initialize" );
|
|
else
|
|
ut.failure( "Initialize" );
|
|
|
|
// Initialize the list with some empty values
|
|
for (int i=0; i<80; i++)
|
|
list.insert( rand() );
|
|
list.insert( 2 );
|
|
list.insert( 1 );
|
|
list.insert( rand() );
|
|
|
|
// Try to pull off a couple of values
|
|
int v1 = list.remove( [](int a) { return a==1; } ); // Find the entry with 1
|
|
int v2 = list.remove( [](int) { return true; } ); // Get the first entry
|
|
int v3 = list.remove( [](int) { return false; } ); // Fail to get an entry
|
|
if ( v1==1 && v2==2 && v3==-1 && list.size()==81 && list.check() )
|
|
ut.passes( "Basic sanity test" );
|
|
else
|
|
ut.failure( "Basic sanity test" );
|
|
|
|
// Clear the list
|
|
while ( list.remove( [](int) { return true; } ) != -1 ) {}
|
|
|
|
// Create a list of known values
|
|
//std::vector<int> data0(512);
|
|
std::vector<int> data0(5*N_threads);
|
|
for (size_t i=0; i<data0.size(); i++)
|
|
data0[i] = rand();
|
|
auto data = data0;
|
|
std::sort( data.begin(), data.end() );
|
|
|
|
// Test the cost to insert
|
|
int N_it = 20;
|
|
for (int i=0; i<list.size(); i++)
|
|
list.remove( [](int) { return true; } );
|
|
std::chrono::duration<double> time;
|
|
std::chrono::time_point<std::chrono::high_resolution_clock> start, stop;
|
|
time = time.zero();
|
|
for (int it=0; it<N_it; it++ ) {
|
|
clear_list( list );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
for (size_t i=0; i<data0.size(); i++)
|
|
list.insert( data0[i] );
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
time += ( stop - start );
|
|
}
|
|
printf("insert time/item = %0.0f ns\n",1e9*time.count()/(N_it*data0.size()));
|
|
|
|
// Test the cost to remove (first)
|
|
time = time.zero();
|
|
for (int it=0; it<N_it; it++ ) {
|
|
check_list( data, list );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
for (size_t i=0; i<data0.size(); i++)
|
|
list.remove_first( );
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
time += ( stop - start );
|
|
}
|
|
printf("remove (first) time/item = %0.0f ns\n",1e9*time.count()/(N_it*data0.size()));
|
|
|
|
// Test the cost to remove (in order)
|
|
time = time.zero();
|
|
for (int it=0; it<N_it; it++ ) {
|
|
check_list( data, list );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
for (size_t i=0; i<data0.size(); i++)
|
|
list.remove( [](int) { return true; } );
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
time += ( stop - start );
|
|
}
|
|
printf("remove (ordered) time/item = %0.0f ns\n",1e9*time.count()/(N_it*data0.size()));
|
|
|
|
// Test the cost to remove (out order)
|
|
time = time.zero();
|
|
for (int it=0; it<N_it; it++ ) {
|
|
check_list( data, list );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
for (size_t i=0; i<data0.size(); i++) {
|
|
int tmp = data0[i];
|
|
list.remove( [tmp](int v) { return v==tmp; } );
|
|
}
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
time += ( stop - start );
|
|
}
|
|
printf("remove (unordered) time/item = %0.0f ns\n",1e9*time.count()/(N_it*data0.size()));
|
|
|
|
// Read/write to the list and check the results
|
|
int64_t N0 = list.N_remove();
|
|
check_list( data, list );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
modify_list( list );
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
double time_serial = std::chrono::duration<double>(stop-start).count();
|
|
int64_t N1 = list.N_remove();
|
|
bool pass = check_list( data, list );
|
|
if ( pass )
|
|
ut.passes( "Serial get/insert" );
|
|
else
|
|
ut.failure( "Serial get/insert" );
|
|
printf("serial time = %0.5f s\n",time_serial);
|
|
printf("serial time/item = %0.0f ns\n",1e9*time_serial/(N1-N0));
|
|
|
|
// Have multiple threads reading/writing to the list simultaneously
|
|
std::vector<std::thread> threads( N_threads );
|
|
start = std::chrono::high_resolution_clock::now();
|
|
for ( int i = 0; i < N_threads; i++ )
|
|
threads[i] = std::thread( modify_list, std::ref(list) );
|
|
for ( int i = 0; i < N_threads; i++ )
|
|
threads[i].join();
|
|
stop = std::chrono::high_resolution_clock::now();
|
|
double time_parallel = std::chrono::duration<double>(stop-start).count();
|
|
int64_t N2 = list.N_remove();
|
|
pass = check_list( data, list );
|
|
if ( pass )
|
|
ut.passes( "Parallel get/insert" );
|
|
else
|
|
ut.failure( "Parallel get/insert" );
|
|
printf("parallel time = %0.5f s\n",time_parallel);
|
|
printf("parallel time/item = %0.0f ns\n",1e9*time_parallel/(N2-N1));
|
|
|
|
// Try to over-fill the list
|
|
while ( !list.empty() )
|
|
list.remove_first();
|
|
for (int i=1; i<=list.capacity(); i++)
|
|
list.insert( i );
|
|
try {
|
|
list.insert( list.capacity()+1 );
|
|
ut.failure( "List overflow" );
|
|
} catch (const std::exception& e) {
|
|
ut.passes( "List overflow" );
|
|
} catch(...) {
|
|
ut.failure( "List overflow (unknown exception)" );
|
|
}
|
|
|
|
// Finished
|
|
ut.report();
|
|
int N_errors = static_cast<int>( ut.NumFailGlobal() );
|
|
return N_errors;
|
|
}
|