2 #include "../warp_mutex.h"
5 #include <cub/util_ptx.cuh>
20 template<
typename Allocator,
25 using Lock_t =
typename std::conditional<TYPE == OPERATION_TYPE::FIND, warp_shared_lock<T>,
warp_unique_lock<T>>::type;
27 template<
typename K,
typename V,
typename Fn>
28 LSLAB_DEVICE
void operator()(
warp_mutex* lock_table,
33 const unsigned bucket,
36 const unsigned laneId = threadIdx.x % 32;
39 unsigned work_queue = __ballot_sync(~0u, thread_mask);
41 unsigned last_work_queue = 0;
43 Lock_t<warp_mutex> lock;
45 while (work_queue != 0) {
47 next = (work_queue != last_work_queue) ?
nullptr : next;
48 unsigned src_lane = __ffs(work_queue) - 1;
50 K src_key = cub::ShuffleIndex<32>(key, src_lane, ~0);
52 unsigned src_bucket = __shfl_sync(~0u, bucket, src_lane);
54 bool found_empty =
false;
57 if(work_queue != last_work_queue) {
58 next = &buckets[src_bucket];
59 lock = Lock_t<warp_mutex>(lock_table[src_bucket]);
70 read_key = next->key[laneId];
71 unsigned valid = next->valid;
72 valid = (valid >> laneId) & 0x1;
74 found = read_key == src_key;
75 }
else if(TYPE == OPERATION_TYPE::INSERT && !found_empty) {
77 empty_location = next;
80 next_ptr = next->next;
83 auto masked_ballot = __ballot_sync(~0u, found);
85 if (masked_ballot != 0) {
86 unsigned found_lane = __ffs(masked_ballot) - 1;
87 if(laneId == src_lane) {
88 if(TYPE == OPERATION_TYPE::REMOVE) {
89 next->valid ^= (1 << found_lane);
91 fn(next->value[found_lane]);
95 next_ptr =
reinterpret_cast<slab_node<K, V>*
>(__shfl_sync(~0,
reinterpret_cast<unsigned long long>(next_ptr), 31));
96 if (next_ptr ==
nullptr) {
98 if(TYPE == OPERATION_TYPE::INSERT) {
102 masked_ballot = __ballot_sync(~0u, found_empty);
103 if(masked_ballot != 0) {
104 unsigned found_lane = __ffs(masked_ballot) - 1;
105 auto loc =
reinterpret_cast<unsigned long long>(empty_location);
106 next =
reinterpret_cast<slab_node<K, V>*
>(__shfl_sync(~0, loc, found_lane));
107 if(laneId == src_lane) {
108 next->key[found_lane] = key;
109 next->valid |= (1 << found_lane);
110 fn(next->value[found_lane]);
116 next_ptr = alloc.allocate(1);
118 next->next = next_ptr;
121 next =
reinterpret_cast<slab_node<K, V>*
>(__shfl_sync(~0,
reinterpret_cast<unsigned long long>(next), 31));
127 if ((TYPE == OPERATION_TYPE::FIND || TYPE == OPERATION_TYPE::UPDATE || TYPE == OPERATION_TYPE::REMOVE) && laneId == src_lane) {
136 last_work_queue = work_queue;
138 work_queue = __ballot_sync(~0u, thread_mask);
140 if(work_queue != last_work_queue){
142 lock = Lock_t<warp_mutex>();
148 LSLAB_DEVICE
void operator()(
warp_mutex* lock_table,
153 const unsigned bucket,
156 const unsigned laneId = threadIdx.x % 32;
159 unsigned work_queue = __ballot_sync(~0u, thread_mask);
161 unsigned last_work_queue = 0;
163 Lock_t<warp_mutex> lock;
165 while (work_queue != 0) {
167 next = (work_queue != last_work_queue) ?
nullptr : next;
168 unsigned src_lane = __ffs(work_queue) - 1;
170 K src_key = cub::ShuffleIndex<32>(key, src_lane, ~0);
172 unsigned src_bucket = __shfl_sync(~0u, bucket, src_lane);
174 bool found_empty =
false;
177 if(work_queue != last_work_queue) {
178 next = &buckets[src_bucket];
179 lock = Lock_t<warp_mutex>(lock_table[src_bucket]);
189 read_key = next->key[laneId];
190 unsigned valid = next->valid;
191 valid = (valid >> laneId) & 0x1;
193 found = read_key == src_key;
194 }
else if(TYPE == OPERATION_TYPE::INSERT && !found_empty) {
196 empty_location = next;
199 next_ptr = next->next;
202 auto masked_ballot = __ballot_sync(~0u, found);
204 if (masked_ballot != 0) {
205 unsigned found_lane = __ffs(masked_ballot) - 1;
206 if(laneId == src_lane) {
207 if(TYPE == OPERATION_TYPE::REMOVE) {
208 next->valid ^= (1 << found_lane);
214 next_ptr =
reinterpret_cast<set_node<K>*
>(__shfl_sync(~0,
reinterpret_cast<unsigned long long>(next_ptr), 31));
215 if (next_ptr ==
nullptr) {
217 if(TYPE == OPERATION_TYPE::INSERT) {
221 masked_ballot = __ballot_sync(~0u, found_empty);
222 if(masked_ballot != 0) {
223 unsigned found_lane = __ffs(masked_ballot) - 1;
224 auto loc =
reinterpret_cast<unsigned long long>(empty_location);
225 next =
reinterpret_cast<set_node<K>*
>(__shfl_sync(~0, loc, found_lane));
226 if(laneId == src_lane) {
227 next->key[found_lane] = key;
228 next->valid |= (1 << found_lane);
235 next_ptr = alloc.allocate(1);
236 next_ptr =
new (
static_cast<void*
>(next_ptr))
set_node<K>();
237 next->next = next_ptr;
240 next =
reinterpret_cast<set_node<K>*
>(__shfl_sync(~0,
reinterpret_cast<unsigned long long>(next), 31));
246 if ((TYPE == OPERATION_TYPE::FIND || TYPE == OPERATION_TYPE::REMOVE) && laneId == src_lane) {
256 last_work_queue = work_queue;
258 work_queue = __ballot_sync(~0u, thread_mask);
260 if(work_queue != last_work_queue){
261 lock = Lock_t<warp_mutex>();
Definition: traverse.h:22
Definition: warp_mutex.h:13
Definition: warp_mutex.h:62