PySDM.backends.impl_thrust_rtc.methods.pair_methods
GPU implementation of pairwise operations backend methods
"""
GPU implementation of pairwise operations backend methods
"""

from functools import cached_property

from PySDM.backends.impl_thrust_rtc.conf import NICE_THRUST_FLAGS
from PySDM.backends.impl_thrust_rtc.nice_thrust import nice_thrust

from ..conf import trtc
from ..methods.thrust_rtc_backend_methods import ThrustRTCBackendMethods


class PairMethods(ThrustRTCBackendMethods):
    """Pairwise kernels operating on consecutive entries of a permuted index.

    Every public method launches a ``trtc.For`` kernel in which position ``i``
    is processed only when ``is_first_in_pair[i]`` is set; that position is
    then combined with its successor at ``i + 1``.  The kernel objects are
    exposed through ``cached_property``, so each one is constructed at most
    once per instance.

    The reducing methods (``distance_pair``, ``max_pair``, ``min_pair``,
    ``sum_pair``, ``multiply_pair``) zero-fill ``data_out`` and write the
    result for the pair starting at ``i`` into ``data_out[i / 2]`` (integer
    division).

    NOTE(review): the kernels read position ``i + 1`` whenever the flag at
    ``i`` is set, so they rely on the flag never being set for the last
    position.  ``find_pairs`` guarantees this for the flags it produces.
    """

    @cached_property
    def __distance_pair_body(self):
        """Kernel: ``data_out[i / 2] = |data_in[i] - data_in[i + 1]|`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "data_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = abs(data_in[i] - data_in[i + 1]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def distance_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the absolute difference of the two members of each pair.

        ``data_in`` is accessed through ``idx`` (via ``trtc.DVPermutation``);
        ``data_out`` is filled with zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__distance_pair_body.launch_n(
            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
        )

    @cached_property
    def __find_pairs_body(self):
        """Kernel: flag ``i`` when ``i`` and ``i + 1`` share a cell and ``i`` is
        at an even offset from that cell's ``cell_start``; the last position
        (``i == length - 1``) is always flagged ``false``."""
        return trtc.For(
            param_names=("cell_start", "perm_cell_id", "is_first_in_pair", "length"),
            name_iter="i",
            body="""
            if (i < length -1) {
                auto is_in_same_cell = perm_cell_id[i] == perm_cell_id[i + 1];
                auto is_even_index = (i - cell_start[perm_cell_id[i]]) % 2 == 0;

                is_first_in_pair[i] = is_in_same_cell && is_even_index;
            } else {
                is_first_in_pair[i] = false;
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    # TODO #330 handle cell_idx (_ below)
    def find_pairs(self, cell_start, is_first_in_pair, cell_id, _, idx):
        """Populate ``is_first_in_pair`` for all ``len(idx)`` positions.

        ``cell_id`` is accessed through ``idx``; the fourth positional
        argument is accepted but currently unused (see the TODO above).
        """
        perm_cell_id = trtc.DVPermutation(cell_id.data, idx.data)
        d_length = trtc.DVInt64(len(idx))
        self.__find_pairs_body.launch_n(
            n=len(idx),
            args=(
                cell_start.data,
                perm_cell_id,
                is_first_in_pair.indicator.data,
                d_length,
            ),
        )

    @cached_property
    def __max_pair_body(self):
        """Kernel: ``data_out[i / 2] = max(perm_in[i], perm_in[i + 1])`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "perm_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = max(perm_in[i], perm_in[i + 1]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def max_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the larger of the two members of each pair.

        ``data_in`` is accessed through ``idx``; ``data_out`` is filled with
        zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__max_pair_body.launch_n(
            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
        )

    @cached_property
    def __sort_pair_body(self):
        """Kernel: copy each flagged pair to ``data_out[i]``/``data_out[i + 1]``
        with the larger value first (equal values keep their input order)."""
        return trtc.For(
            param_names=("data_out", "data_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                if (data_in[i] < data_in[i + 1]) {
                    data_out[i] = data_in[i + 1];
                    data_out[i + 1] = data_in[i];
                }
                else {
                    data_out[i] = data_in[i];
                    data_out[i + 1] = data_in[i + 1];
                }
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def sort_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Write each pair into ``data_out`` ordered larger-value-first.

        ``data_out`` is zero-filled first, so positions that do not belong to
        a flagged pair remain ``0``.  The kernel is launched over
        ``len(idx) - 1`` positions and skipped entirely when fewer than two
        entries exist.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        if len(idx) > 1:
            self.__sort_pair_body.launch_n(
                len(idx) - 1, [data_out.data, perm_in, is_first_in_pair.indicator.data]
            )

    @cached_property
    def __sort_within_pair_by_attr_body(self):
        """Kernel: swap ``idx[i]`` and ``idx[i + 1]`` for flagged ``i`` when
        ``attr[idx[i]] < attr[idx[i + 1]]``."""
        return trtc.For(
            param_names=("idx", "is_first_in_pair", "attr"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                if (attr[idx[i]] < attr[idx[i + 1]]) {
                    auto tmp = idx[i];
                    idx[i] = idx[i + 1];
                    idx[i + 1] = tmp;
                }
            }
            """,
        )

    # Decorated like every other kernel-launching method of this class so that
    # NICE_THRUST_FLAGS are applied uniformly across the pairwise operations.
    @nice_thrust(**NICE_THRUST_FLAGS)
    def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
        """Reorder ``idx`` in place so that, within each flagged pair, the
        entry with the larger ``attr`` value comes first.

        Does nothing when ``idx`` has fewer than two entries; otherwise the
        kernel is launched over ``len(idx) - 1`` positions.
        """
        if len(idx) < 2:
            return
        self.__sort_within_pair_by_attr_body.launch_n(
            len(idx) - 1, [idx.data, is_first_in_pair.indicator.data, attr.data]
        )

    @cached_property
    def __sum_pair_body(self):
        """Kernel: ``data_out[i / 2] = perm_in[i] + perm_in[i + 1]`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "perm_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = perm_in[i] + perm_in[i + 1];
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def sum_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the sum of the two members of each pair.

        ``data_in`` is accessed through ``idx``; ``data_out`` is filled with
        zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__sum_pair_body.launch_n(
            n=len(idx),
            args=(data_out.data, perm_in, is_first_in_pair.indicator.data),
        )

    @cached_property
    def __min_pair_body(self):
        """Kernel: ``data_out[i / 2] = min(data_in[idx[i]], data_in[idx[i + 1]])``
        for flagged ``i``; the index is passed explicitly rather than through
        a permutation view."""
        return trtc.For(
            param_names=(
                "data_out",
                "data_in",
                "is_first_in_pair",
                "idx",
            ),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = min(data_in[idx[i]], data_in[idx[i + 1]]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def min_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the smaller of the two members of each pair.

        ``data_out`` is filled with zeros before the kernel runs.
        """
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__min_pair_body.launch_n(
            n=len(idx),
            args=(
                data_out.data,
                data_in.data,
                is_first_in_pair.indicator.data,
                idx.data,
            ),
        )

    @cached_property
    def __multiply_pair_body(self):
        """Kernel: ``data_out[i / 2] = data_in[idx[i]] * data_in[idx[i + 1]]``
        for flagged ``i``; the index is passed explicitly rather than through
        a permutation view."""
        return trtc.For(
            param_names=(
                "data_out",
                "data_in",
                "is_first_in_pair",
                "idx",
            ),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = data_in[idx[i]] * data_in[idx[i + 1]];
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def multiply_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the product of the two members of each pair.

        ``data_out`` is filled with zeros before the kernel runs.
        """
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__multiply_pair_body.launch_n(
            n=len(idx),
            args=(
                data_out.data,
                data_in.data,
                is_first_in_pair.indicator.data,
                idx.data,
            ),
        )
class PairMethods(PySDM.backends.impl_thrust_rtc.methods.thrust_rtc_backend_methods.ThrustRTCBackendMethods):
class PairMethods(ThrustRTCBackendMethods):
    """Pairwise kernels operating on consecutive entries of a permuted index.

    Every public method launches a ``trtc.For`` kernel in which position ``i``
    is processed only when ``is_first_in_pair[i]`` is set; that position is
    then combined with its successor at ``i + 1``.  The kernel objects are
    exposed through ``cached_property``, so each one is constructed at most
    once per instance.

    The reducing methods (``distance_pair``, ``max_pair``, ``min_pair``,
    ``sum_pair``, ``multiply_pair``) zero-fill ``data_out`` and write the
    result for the pair starting at ``i`` into ``data_out[i / 2]`` (integer
    division).

    NOTE(review): the kernels read position ``i + 1`` whenever the flag at
    ``i`` is set, so they rely on the flag never being set for the last
    position.  ``find_pairs`` guarantees this for the flags it produces.
    """

    @cached_property
    def __distance_pair_body(self):
        """Kernel: ``data_out[i / 2] = |data_in[i] - data_in[i + 1]|`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "data_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = abs(data_in[i] - data_in[i + 1]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def distance_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the absolute difference of the two members of each pair.

        ``data_in`` is accessed through ``idx`` (via ``trtc.DVPermutation``);
        ``data_out`` is filled with zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__distance_pair_body.launch_n(
            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
        )

    @cached_property
    def __find_pairs_body(self):
        """Kernel: flag ``i`` when ``i`` and ``i + 1`` share a cell and ``i`` is
        at an even offset from that cell's ``cell_start``; the last position
        (``i == length - 1``) is always flagged ``false``."""
        return trtc.For(
            param_names=("cell_start", "perm_cell_id", "is_first_in_pair", "length"),
            name_iter="i",
            body="""
            if (i < length -1) {
                auto is_in_same_cell = perm_cell_id[i] == perm_cell_id[i + 1];
                auto is_even_index = (i - cell_start[perm_cell_id[i]]) % 2 == 0;

                is_first_in_pair[i] = is_in_same_cell && is_even_index;
            } else {
                is_first_in_pair[i] = false;
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    # TODO #330 handle cell_idx (_ below)
    def find_pairs(self, cell_start, is_first_in_pair, cell_id, _, idx):
        """Populate ``is_first_in_pair`` for all ``len(idx)`` positions.

        ``cell_id`` is accessed through ``idx``; the fourth positional
        argument is accepted but currently unused (see the TODO above).
        """
        perm_cell_id = trtc.DVPermutation(cell_id.data, idx.data)
        d_length = trtc.DVInt64(len(idx))
        self.__find_pairs_body.launch_n(
            n=len(idx),
            args=(
                cell_start.data,
                perm_cell_id,
                is_first_in_pair.indicator.data,
                d_length,
            ),
        )

    @cached_property
    def __max_pair_body(self):
        """Kernel: ``data_out[i / 2] = max(perm_in[i], perm_in[i + 1])`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "perm_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = max(perm_in[i], perm_in[i + 1]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def max_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the larger of the two members of each pair.

        ``data_in`` is accessed through ``idx``; ``data_out`` is filled with
        zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__max_pair_body.launch_n(
            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
        )

    @cached_property
    def __sort_pair_body(self):
        """Kernel: copy each flagged pair to ``data_out[i]``/``data_out[i + 1]``
        with the larger value first (equal values keep their input order)."""
        return trtc.For(
            param_names=("data_out", "data_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                if (data_in[i] < data_in[i + 1]) {
                    data_out[i] = data_in[i + 1];
                    data_out[i + 1] = data_in[i];
                }
                else {
                    data_out[i] = data_in[i];
                    data_out[i + 1] = data_in[i + 1];
                }
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def sort_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Write each pair into ``data_out`` ordered larger-value-first.

        ``data_out`` is zero-filled first, so positions that do not belong to
        a flagged pair remain ``0``.  The kernel is launched over
        ``len(idx) - 1`` positions and skipped entirely when fewer than two
        entries exist.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        if len(idx) > 1:
            self.__sort_pair_body.launch_n(
                len(idx) - 1, [data_out.data, perm_in, is_first_in_pair.indicator.data]
            )

    @cached_property
    def __sort_within_pair_by_attr_body(self):
        """Kernel: swap ``idx[i]`` and ``idx[i + 1]`` for flagged ``i`` when
        ``attr[idx[i]] < attr[idx[i + 1]]``."""
        return trtc.For(
            param_names=("idx", "is_first_in_pair", "attr"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                if (attr[idx[i]] < attr[idx[i + 1]]) {
                    auto tmp = idx[i];
                    idx[i] = idx[i + 1];
                    idx[i + 1] = tmp;
                }
            }
            """,
        )

    # Decorated like every other kernel-launching method of this class so that
    # NICE_THRUST_FLAGS are applied uniformly across the pairwise operations.
    @nice_thrust(**NICE_THRUST_FLAGS)
    def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
        """Reorder ``idx`` in place so that, within each flagged pair, the
        entry with the larger ``attr`` value comes first.

        Does nothing when ``idx`` has fewer than two entries; otherwise the
        kernel is launched over ``len(idx) - 1`` positions.
        """
        if len(idx) < 2:
            return
        self.__sort_within_pair_by_attr_body.launch_n(
            len(idx) - 1, [idx.data, is_first_in_pair.indicator.data, attr.data]
        )

    @cached_property
    def __sum_pair_body(self):
        """Kernel: ``data_out[i / 2] = perm_in[i] + perm_in[i + 1]`` for flagged ``i``."""
        return trtc.For(
            param_names=("data_out", "perm_in", "is_first_in_pair"),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = perm_in[i] + perm_in[i + 1];
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def sum_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the sum of the two members of each pair.

        ``data_in`` is accessed through ``idx``; ``data_out`` is filled with
        zeros before the kernel runs.
        """
        perm_in = trtc.DVPermutation(data_in.data, idx.data)
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__sum_pair_body.launch_n(
            n=len(idx),
            args=(data_out.data, perm_in, is_first_in_pair.indicator.data),
        )

    @cached_property
    def __min_pair_body(self):
        """Kernel: ``data_out[i / 2] = min(data_in[idx[i]], data_in[idx[i + 1]])``
        for flagged ``i``; the index is passed explicitly rather than through
        a permutation view."""
        return trtc.For(
            param_names=(
                "data_out",
                "data_in",
                "is_first_in_pair",
                "idx",
            ),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = min(data_in[idx[i]], data_in[idx[i + 1]]);
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def min_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the smaller of the two members of each pair.

        ``data_out`` is filled with zeros before the kernel runs.
        """
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__min_pair_body.launch_n(
            n=len(idx),
            args=(
                data_out.data,
                data_in.data,
                is_first_in_pair.indicator.data,
                idx.data,
            ),
        )

    @cached_property
    def __multiply_pair_body(self):
        """Kernel: ``data_out[i / 2] = data_in[idx[i]] * data_in[idx[i + 1]]``
        for flagged ``i``; the index is passed explicitly rather than through
        a permutation view."""
        return trtc.For(
            param_names=(
                "data_out",
                "data_in",
                "is_first_in_pair",
                "idx",
            ),
            name_iter="i",
            body="""
            if (is_first_in_pair[i]) {
                data_out[(int64_t)(i/2)] = data_in[idx[i]] * data_in[idx[i + 1]];
            }
            """,
        )

    @nice_thrust(**NICE_THRUST_FLAGS)
    def multiply_pair(self, data_out, data_in, is_first_in_pair, idx):
        """Store the product of the two members of each pair.

        ``data_out`` is filled with zeros before the kernel runs.
        """
        trtc.Fill(data_out.data, trtc.DVDouble(0))
        self.__multiply_pair_body.launch_n(
            n=len(idx),
            args=(
                data_out.data,
                data_in.data,
                is_first_in_pair.indicator.data,
                idx.data,
            ),
        )