LSLab
traverse.h
#include "../lslab.h"
#include "../warp_mutex.h"
#include "slab_node.h"
#include <cuda.h>
#include <cub/util_ptx.cuh>
#include <type_traits> // std::conditional
#include <new>         // placement new

#pragma once

namespace lslab {

namespace detail {

enum OPERATION_TYPE {
    INSERT,
    UPDATE,
    FIND,
    REMOVE
};

template<typename Allocator,
         int TYPE>
struct traverse {

    // FIND takes the bucket lock shared; all mutating operations take it unique.
    template<typename T>
    using Lock_t = typename std::conditional<TYPE == OPERATION_TYPE::FIND, warp_shared_lock<T>, warp_unique_lock<T>>::type;

    template<typename K, typename V, typename Fn>
    LSLAB_DEVICE void operator()(warp_mutex* lock_table,
                                 slab_node<K, V>* buckets,
                                 const K& key,
                                 Fn&& fn,
                                 Allocator& alloc,
                                 const unsigned bucket,
                                 bool thread_mask) {

        const unsigned laneId = threadIdx.x % 32;
        slab_node<K, V>* next = nullptr;

        // Warp-cooperative work queue: each set bit marks a lane with an
        // outstanding request; the warp services one lane's request at a time.
        unsigned work_queue = __ballot_sync(~0u, thread_mask);

        unsigned last_work_queue = 0;

        Lock_t<warp_mutex> lock;

        while (work_queue != 0) {

            // Restart from the bucket head whenever the serviced lane changes.
            next = (work_queue != last_work_queue) ? nullptr : next;
            unsigned src_lane = __ffs(work_queue) - 1;

            // Broadcast the serviced lane's key and bucket to the whole warp.
            K src_key = cub::ShuffleIndex<32>(key, src_lane, ~0);

            unsigned src_bucket = __shfl_sync(~0u, bucket, src_lane);

            bool found_empty = false;
            slab_node<K, V>* empty_location = nullptr;

            if (work_queue != last_work_queue) {
                next = &buckets[src_bucket];
                lock = Lock_t<warp_mutex>(lock_table[src_bucket]);
                //__threadfence_system();
            }

            slab_node<K, V>* next_ptr = nullptr;
            K read_key;

            bool found = false;

            // Lanes 0..30 each probe one slot of the node; lane 31 reads the next pointer.
            if (laneId < 31) {
                read_key = next->key[laneId];
                unsigned valid = next->valid;
                valid = (valid >> laneId) & 0x1;
                if (valid) {
                    found = read_key == src_key;
                } else if (TYPE == OPERATION_TYPE::INSERT && !found_empty) {
                    found_empty = true;
                    empty_location = next;
                }
            } else {
                next_ptr = next->next;
            }

            auto masked_ballot = __ballot_sync(~0u, found); //& 0x7fffffffu;

            if (masked_ballot != 0) {
                // Key found in this node: apply fn and retire the request.
                unsigned found_lane = __ffs(masked_ballot) - 1;
                if (laneId == src_lane) {
                    if (TYPE == OPERATION_TYPE::REMOVE) {
                        next->valid ^= (1 << found_lane);
                    }
                    fn(next->value[found_lane]);
                    thread_mask = false;
                }
            } else {
                next_ptr = reinterpret_cast<slab_node<K, V>*>(__shfl_sync(~0, reinterpret_cast<unsigned long long>(next_ptr), 31));
                if (next_ptr == nullptr) {

                    if (TYPE == OPERATION_TYPE::INSERT) {
                        // If we are doing an insert here,
                        // check whether an empty slot was found and insert there;
                        // otherwise allocate a new node.
                        masked_ballot = __ballot_sync(~0u, found_empty);
                        if (masked_ballot != 0) {
                            unsigned found_lane = __ffs(masked_ballot) - 1;
                            auto loc = reinterpret_cast<unsigned long long>(empty_location);
                            next = reinterpret_cast<slab_node<K, V>*>(__shfl_sync(~0, loc, found_lane));
                            if (laneId == src_lane) {
                                next->key[found_lane] = key;
                                // mark the slot valid
                                next->valid |= (1 << found_lane);
                                fn(next->value[found_lane]);
                                thread_mask = false;
                            }
                        } else {
                            if (laneId == 31) {
                                next_ptr = alloc.allocate(1);
                                next_ptr = new (static_cast<void*>(next_ptr)) slab_node<K, V>();
                                next->next = next_ptr;
                                next = next_ptr;
                            }
                            next = reinterpret_cast<slab_node<K, V>*>(__shfl_sync(~0, reinterpret_cast<unsigned long long>(next), 31));
                            if (next == nullptr)
                                __trap();
                        }
                    }

                    if ((TYPE == OPERATION_TYPE::FIND || TYPE == OPERATION_TYPE::UPDATE || TYPE == OPERATION_TYPE::REMOVE) && laneId == src_lane) {
                        // reached the end of the chain without finding the key
                        thread_mask = false;
                    }
                } else {
                    // Follow the chain to the next node.
                    next = next_ptr;
                }
            }

            last_work_queue = work_queue;

            work_queue = __ballot_sync(~0u, thread_mask);

            // Release the bucket lock before moving on to a different lane's bucket.
            if (work_queue != last_work_queue) {
                //__threadfence_system();
                lock = Lock_t<warp_mutex>();
            }
        }
    }

    template<typename K>
    LSLAB_DEVICE void operator()(warp_mutex* lock_table,
                                 set_node<K>* buckets,
                                 const K& key,
                                 bool& result,
                                 Allocator& alloc,
                                 const unsigned bucket,
                                 bool thread_mask) {

        const unsigned laneId = threadIdx.x % 32;
        set_node<K>* next = nullptr;

        // Same warp-cooperative work queue as the map overload above,
        // but for a set: result reports presence instead of invoking fn.
        unsigned work_queue = __ballot_sync(~0u, thread_mask);

        unsigned last_work_queue = 0;

        Lock_t<warp_mutex> lock;

        while (work_queue != 0) {

            next = (work_queue != last_work_queue) ? nullptr : next;
            unsigned src_lane = __ffs(work_queue) - 1;

            K src_key = cub::ShuffleIndex<32>(key, src_lane, ~0);

            unsigned src_bucket = __shfl_sync(~0u, bucket, src_lane);

            bool found_empty = false;
            set_node<K>* empty_location = nullptr;

            if (work_queue != last_work_queue) {
                next = &buckets[src_bucket];
                lock = Lock_t<warp_mutex>(lock_table[src_bucket]);
            }

            set_node<K>* next_ptr = nullptr;
            K read_key;

            bool found = false;

            // Lanes 0..30 each probe one slot of the node; lane 31 reads the next pointer.
            if (laneId < 31) {
                read_key = next->key[laneId];
                unsigned valid = next->valid;
                valid = (valid >> laneId) & 0x1;
                if (valid) {
                    found = read_key == src_key;
                } else if (TYPE == OPERATION_TYPE::INSERT && !found_empty) {
                    found_empty = true;
                    empty_location = next;
                }
            } else {
                next_ptr = next->next;
            }

            auto masked_ballot = __ballot_sync(~0u, found); //& 0x7fffffffu;

            if (masked_ballot != 0) {
                unsigned found_lane = __ffs(masked_ballot) - 1;
                if (laneId == src_lane) {
                    if (TYPE == OPERATION_TYPE::REMOVE) {
                        next->valid ^= (1 << found_lane);
                    }
                    result = true;
                    thread_mask = false;
                }
            } else {
                next_ptr = reinterpret_cast<set_node<K>*>(__shfl_sync(~0, reinterpret_cast<unsigned long long>(next_ptr), 31));
                if (next_ptr == nullptr) {

                    if (TYPE == OPERATION_TYPE::INSERT) {
                        // If we are doing an insert here,
                        // check whether an empty slot was found and insert there;
                        // otherwise allocate a new node.
                        masked_ballot = __ballot_sync(~0u, found_empty);
                        if (masked_ballot != 0) {
                            unsigned found_lane = __ffs(masked_ballot) - 1;
                            auto loc = reinterpret_cast<unsigned long long>(empty_location);
                            next = reinterpret_cast<set_node<K>*>(__shfl_sync(~0, loc, found_lane));
                            if (laneId == src_lane) {
                                next->key[found_lane] = key;
                                // mark the slot valid
                                next->valid |= (1 << found_lane);
                                thread_mask = false;
                                result = true;
                            }
                        } else {
                            if (laneId == 31) {
                                next_ptr = alloc.allocate(1);
                                next_ptr = new (static_cast<void*>(next_ptr)) set_node<K>();
                                next->next = next_ptr;
                                next = next_ptr;
                            }
                            next = reinterpret_cast<set_node<K>*>(__shfl_sync(~0, reinterpret_cast<unsigned long long>(next), 31));
                            if (next == nullptr)
                                __trap();
                        }
                    }

                    if ((TYPE == OPERATION_TYPE::FIND || TYPE == OPERATION_TYPE::REMOVE) && laneId == src_lane) {
                        // reached the end of the chain without finding the key
                        thread_mask = false;
                        result = false;
                    }
                } else {
                    next = next_ptr;
                }
            }

            last_work_queue = work_queue;

            work_queue = __ballot_sync(~0u, thread_mask);

            // Release the bucket lock before moving on to a different lane's bucket.
            if (work_queue != last_work_queue) {
                lock = Lock_t<warp_mutex>();
            }
        }
    }

};
} // namespace detail
} // namespace lslab
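
A minimal usage sketch (not part of traverse.h): the kernel below illustrates how a warp-cooperative FIND might call traverse. The kernel name example_find_kernel, the 64-bit integer key/value types, the modulo bucket computation, and an Allocator whose allocate(1) is callable from device code are all assumptions for illustration; the sketch is placed inside lslab::detail only so unqualified names resolve as they do in this header. Every lane of a warp must reach the call, since the internal ballots and shuffles use the full-warp mask, so the block size should be a multiple of 32.

// Hypothetical usage sketch, not part of LSLab itself.
namespace lslab {
namespace detail {

template<typename Allocator>
__global__ void example_find_kernel(warp_mutex* lock_table,
                                    slab_node<unsigned long long, unsigned long long>* buckets,
                                    const unsigned long long* keys,
                                    unsigned long long* results,
                                    unsigned n,
                                    unsigned num_buckets,
                                    Allocator alloc) {
    const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
    // Lanes with nothing to look up (tid >= n) still enter traverse with
    // thread_mask = false: they only help service the other lanes' requests.
    const bool active = tid < n;
    unsigned long long key = active ? keys[tid] : 0;
    unsigned bucket = active ? static_cast<unsigned>(key % num_buckets) : 0;

    unsigned long long out = 0;
    traverse<Allocator, OPERATION_TYPE::FIND>{}(
        lock_table, buckets, key,
        [&](unsigned long long& v) { out = v; },  // fn sees the stored value by reference
        alloc, bucket, active);

    if (active)
        results[tid] = out;  // keys that were not found leave out at 0 in this sketch
}

} // namespace detail
} // namespace lslab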