PySDM.backends.impl_thrust_rtc.methods.pair_methods

GPU implementation of pairwise operations backend methods

  1"""
  2GPU implementation of pairwise operations backend methods
  3"""
  4
  5from functools import cached_property
  6
  7from PySDM.backends.impl_thrust_rtc.conf import NICE_THRUST_FLAGS
  8from PySDM.backends.impl_thrust_rtc.nice_thrust import nice_thrust
  9
 10from ..conf import trtc
 11from ..methods.thrust_rtc_backend_methods import ThrustRTCBackendMethods
 12
 13
 14class PairMethods(ThrustRTCBackendMethods):
 15    @cached_property
 16    def __distance_pair_body(self):
 17        return trtc.For(
 18            param_names=("data_out", "data_in", "is_first_in_pair"),
 19            name_iter="i",
 20            body="""
 21        if (is_first_in_pair[i]) {
 22            data_out[(int64_t)(i/2)] = abs(data_in[i] - data_in[i + 1]);
 23        }
 24        """,
 25        )
 26
 27    @nice_thrust(**NICE_THRUST_FLAGS)
 28    def distance_pair(self, data_out, data_in, is_first_in_pair, idx):
 29        perm_in = trtc.DVPermutation(data_in.data, idx.data)
 30        trtc.Fill(data_out.data, trtc.DVDouble(0))
 31        self.__distance_pair_body.launch_n(
 32            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
 33        )
 34
 35    @cached_property
 36    def __find_pairs_body(self):
 37        return trtc.For(
 38            param_names=("cell_start", "perm_cell_id", "is_first_in_pair", "length"),
 39            name_iter="i",
 40            body="""
 41            if (i < length -1) {
 42                auto is_in_same_cell = perm_cell_id[i] == perm_cell_id[i + 1];
 43                auto is_even_index = (i - cell_start[perm_cell_id[i]]) % 2 == 0;
 44
 45                is_first_in_pair[i] = is_in_same_cell && is_even_index;
 46            } else {
 47                is_first_in_pair[i] = false;
 48            }
 49            """,
 50        )
 51
 52    @nice_thrust(**NICE_THRUST_FLAGS)
 53    # TODO #330 handle cell_idx (_ below)
 54    def find_pairs(self, cell_start, is_first_in_pair, cell_id, _, idx):
 55        perm_cell_id = trtc.DVPermutation(cell_id.data, idx.data)
 56        d_length = trtc.DVInt64(len(idx))
 57        self.__find_pairs_body.launch_n(
 58            n=len(idx),
 59            args=(
 60                cell_start.data,
 61                perm_cell_id,
 62                is_first_in_pair.indicator.data,
 63                d_length,
 64            ),
 65        )
 66
 67    @cached_property
 68    def __max_pair_body(self):
 69        return trtc.For(
 70            param_names=("data_out", "perm_in", "is_first_in_pair"),
 71            name_iter="i",
 72            body="""
 73            if (is_first_in_pair[i]) {
 74                data_out[(int64_t)(i/2)] = max(perm_in[i], perm_in[i + 1]);
 75            }
 76            """,
 77        )
 78
 79    @nice_thrust(**NICE_THRUST_FLAGS)
 80    def max_pair(self, data_out, data_in, is_first_in_pair, idx):
 81        perm_in = trtc.DVPermutation(data_in.data, idx.data)
 82        trtc.Fill(data_out.data, trtc.DVDouble(0))
 83        self.__max_pair_body.launch_n(
 84            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
 85        )
 86
 87    @cached_property
 88    def __sort_pair_body(self):
 89        return trtc.For(
 90            param_names=("data_out", "data_in", "is_first_in_pair"),
 91            name_iter="i",
 92            body="""
 93            if (is_first_in_pair[i]) {
 94                if (data_in[i] < data_in[i + 1]) {
 95                    data_out[i] = data_in[i + 1];
 96                    data_out[i + 1] = data_in[i];
 97                }
 98                else {
 99                    data_out[i] = data_in[i];
100                    data_out[i + 1] = data_in[i + 1];
101                }
102            }
103            """,
104        )
105
106    @nice_thrust(**NICE_THRUST_FLAGS)
107    def sort_pair(self, data_out, data_in, is_first_in_pair, idx):
108        perm_in = trtc.DVPermutation(data_in.data, idx.data)
109        trtc.Fill(data_out.data, trtc.DVDouble(0))
110        if len(idx) > 1:
111            self.__sort_pair_body.launch_n(
112                len(idx) - 1, [data_out.data, perm_in, is_first_in_pair.indicator.data]
113            )
114
115    @cached_property
116    def __sort_within_pair_by_attr_body(self):
117        return trtc.For(
118            param_names=("idx", "is_first_in_pair", "attr"),
119            name_iter="i",
120            body="""
121            if (is_first_in_pair[i]) {
122                if (attr[idx[i]] < attr[idx[i + 1]]) {
123                    auto tmp = idx[i];
124                    idx[i] = idx[i + 1];
125                    idx[i + 1] = tmp;
126                }
127            }
128            """,
129        )
130
131    def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
132        if len(idx) < 2:
133            return
134        self.__sort_within_pair_by_attr_body.launch_n(
135            len(idx) - 1, [idx.data, is_first_in_pair.indicator.data, attr.data]
136        )
137
138    @cached_property
139    def __sum_pair_body(self):
140        return trtc.For(
141            param_names=("data_out", "perm_in", "is_first_in_pair"),
142            name_iter="i",
143            body="""
144            if (is_first_in_pair[i]) {
145                data_out[(int64_t)(i/2)] = perm_in[i] + perm_in[i + 1];
146            }
147            """,
148        )
149
150    @nice_thrust(**NICE_THRUST_FLAGS)
151    def sum_pair(self, data_out, data_in, is_first_in_pair, idx):
152        perm_in = trtc.DVPermutation(data_in.data, idx.data)
153        trtc.Fill(data_out.data, trtc.DVDouble(0))
154        self.__sum_pair_body.launch_n(
155            n=len(idx),
156            args=(data_out.data, perm_in, is_first_in_pair.indicator.data),
157        )
158
159    @cached_property
160    def __min_pair_body(self):
161        return trtc.For(
162            param_names=(
163                "data_out",
164                "data_in",
165                "is_first_in_pair",
166                "idx",
167            ),
168            name_iter="i",
169            body="""
170            if (is_first_in_pair[i]) {
171                data_out[(int64_t)(i/2)] = min(data_in[idx[i]], data_in[idx[i + 1]]);
172            }
173            """,
174        )
175
176    @nice_thrust(**NICE_THRUST_FLAGS)
177    def min_pair(self, data_out, data_in, is_first_in_pair, idx):
178        trtc.Fill(data_out.data, trtc.DVDouble(0))
179        self.__min_pair_body.launch_n(
180            n=len(idx),
181            args=(
182                data_out.data,
183                data_in.data,
184                is_first_in_pair.indicator.data,
185                idx.data,
186            ),
187        )
188
189    @cached_property
190    def __multiply_pair_body(self):
191        return trtc.For(
192            param_names=(
193                "data_out",
194                "data_in",
195                "is_first_in_pair",
196                "idx",
197            ),
198            name_iter="i",
199            body="""
200            if (is_first_in_pair[i]) {
201                data_out[(int64_t)(i/2)] = data_in[idx[i]] * data_in[idx[i + 1]];
202            }
203            """,
204        )
205
206    @nice_thrust(**NICE_THRUST_FLAGS)
207    def multiply_pair(self, data_out, data_in, is_first_in_pair, idx):
208        trtc.Fill(data_out.data, trtc.DVDouble(0))
209        self.__multiply_pair_body.launch_n(
210            n=len(idx),
211            args=(
212                data_out.data,
213                data_in.data,
214                is_first_in_pair.indicator.data,
215                idx.data,
216            ),
217        )
 15class PairMethods(ThrustRTCBackendMethods):
 16    @cached_property
 17    def __distance_pair_body(self):
 18        return trtc.For(
 19            param_names=("data_out", "data_in", "is_first_in_pair"),
 20            name_iter="i",
 21            body="""
 22        if (is_first_in_pair[i]) {
 23            data_out[(int64_t)(i/2)] = abs(data_in[i] - data_in[i + 1]);
 24        }
 25        """,
 26        )
 27
 28    @nice_thrust(**NICE_THRUST_FLAGS)
 29    def distance_pair(self, data_out, data_in, is_first_in_pair, idx):
 30        perm_in = trtc.DVPermutation(data_in.data, idx.data)
 31        trtc.Fill(data_out.data, trtc.DVDouble(0))
 32        self.__distance_pair_body.launch_n(
 33            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
 34        )
 35
 36    @cached_property
 37    def __find_pairs_body(self):
 38        return trtc.For(
 39            param_names=("cell_start", "perm_cell_id", "is_first_in_pair", "length"),
 40            name_iter="i",
 41            body="""
 42            if (i < length -1) {
 43                auto is_in_same_cell = perm_cell_id[i] == perm_cell_id[i + 1];
 44                auto is_even_index = (i - cell_start[perm_cell_id[i]]) % 2 == 0;
 45
 46                is_first_in_pair[i] = is_in_same_cell && is_even_index;
 47            } else {
 48                is_first_in_pair[i] = false;
 49            }
 50            """,
 51        )
 52
 53    @nice_thrust(**NICE_THRUST_FLAGS)
 54    # TODO #330 handle cell_idx (_ below)
 55    def find_pairs(self, cell_start, is_first_in_pair, cell_id, _, idx):
 56        perm_cell_id = trtc.DVPermutation(cell_id.data, idx.data)
 57        d_length = trtc.DVInt64(len(idx))
 58        self.__find_pairs_body.launch_n(
 59            n=len(idx),
 60            args=(
 61                cell_start.data,
 62                perm_cell_id,
 63                is_first_in_pair.indicator.data,
 64                d_length,
 65            ),
 66        )
 67
 68    @cached_property
 69    def __max_pair_body(self):
 70        return trtc.For(
 71            param_names=("data_out", "perm_in", "is_first_in_pair"),
 72            name_iter="i",
 73            body="""
 74            if (is_first_in_pair[i]) {
 75                data_out[(int64_t)(i/2)] = max(perm_in[i], perm_in[i + 1]);
 76            }
 77            """,
 78        )
 79
 80    @nice_thrust(**NICE_THRUST_FLAGS)
 81    def max_pair(self, data_out, data_in, is_first_in_pair, idx):
 82        perm_in = trtc.DVPermutation(data_in.data, idx.data)
 83        trtc.Fill(data_out.data, trtc.DVDouble(0))
 84        self.__max_pair_body.launch_n(
 85            len(idx), [data_out.data, perm_in, is_first_in_pair.indicator.data]
 86        )
 87
 88    @cached_property
 89    def __sort_pair_body(self):
 90        return trtc.For(
 91            param_names=("data_out", "data_in", "is_first_in_pair"),
 92            name_iter="i",
 93            body="""
 94            if (is_first_in_pair[i]) {
 95                if (data_in[i] < data_in[i + 1]) {
 96                    data_out[i] = data_in[i + 1];
 97                    data_out[i + 1] = data_in[i];
 98                }
 99                else {
100                    data_out[i] = data_in[i];
101                    data_out[i + 1] = data_in[i + 1];
102                }
103            }
104            """,
105        )
106
107    @nice_thrust(**NICE_THRUST_FLAGS)
108    def sort_pair(self, data_out, data_in, is_first_in_pair, idx):
109        perm_in = trtc.DVPermutation(data_in.data, idx.data)
110        trtc.Fill(data_out.data, trtc.DVDouble(0))
111        if len(idx) > 1:
112            self.__sort_pair_body.launch_n(
113                len(idx) - 1, [data_out.data, perm_in, is_first_in_pair.indicator.data]
114            )
115
116    @cached_property
117    def __sort_within_pair_by_attr_body(self):
118        return trtc.For(
119            param_names=("idx", "is_first_in_pair", "attr"),
120            name_iter="i",
121            body="""
122            if (is_first_in_pair[i]) {
123                if (attr[idx[i]] < attr[idx[i + 1]]) {
124                    auto tmp = idx[i];
125                    idx[i] = idx[i + 1];
126                    idx[i + 1] = tmp;
127                }
128            }
129            """,
130        )
131
132    def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
133        if len(idx) < 2:
134            return
135        self.__sort_within_pair_by_attr_body.launch_n(
136            len(idx) - 1, [idx.data, is_first_in_pair.indicator.data, attr.data]
137        )
138
139    @cached_property
140    def __sum_pair_body(self):
141        return trtc.For(
142            param_names=("data_out", "perm_in", "is_first_in_pair"),
143            name_iter="i",
144            body="""
145            if (is_first_in_pair[i]) {
146                data_out[(int64_t)(i/2)] = perm_in[i] + perm_in[i + 1];
147            }
148            """,
149        )
150
151    @nice_thrust(**NICE_THRUST_FLAGS)
152    def sum_pair(self, data_out, data_in, is_first_in_pair, idx):
153        perm_in = trtc.DVPermutation(data_in.data, idx.data)
154        trtc.Fill(data_out.data, trtc.DVDouble(0))
155        self.__sum_pair_body.launch_n(
156            n=len(idx),
157            args=(data_out.data, perm_in, is_first_in_pair.indicator.data),
158        )
159
160    @cached_property
161    def __min_pair_body(self):
162        return trtc.For(
163            param_names=(
164                "data_out",
165                "data_in",
166                "is_first_in_pair",
167                "idx",
168            ),
169            name_iter="i",
170            body="""
171            if (is_first_in_pair[i]) {
172                data_out[(int64_t)(i/2)] = min(data_in[idx[i]], data_in[idx[i + 1]]);
173            }
174            """,
175        )
176
177    @nice_thrust(**NICE_THRUST_FLAGS)
178    def min_pair(self, data_out, data_in, is_first_in_pair, idx):
179        trtc.Fill(data_out.data, trtc.DVDouble(0))
180        self.__min_pair_body.launch_n(
181            n=len(idx),
182            args=(
183                data_out.data,
184                data_in.data,
185                is_first_in_pair.indicator.data,
186                idx.data,
187            ),
188        )
189
190    @cached_property
191    def __multiply_pair_body(self):
192        return trtc.For(
193            param_names=(
194                "data_out",
195                "data_in",
196                "is_first_in_pair",
197                "idx",
198            ),
199            name_iter="i",
200            body="""
201            if (is_first_in_pair[i]) {
202                data_out[(int64_t)(i/2)] = data_in[idx[i]] * data_in[idx[i + 1]];
203            }
204            """,
205        )
206
207    @nice_thrust(**NICE_THRUST_FLAGS)
208    def multiply_pair(self, data_out, data_in, is_first_in_pair, idx):
209        trtc.Fill(data_out.data, trtc.DVDouble(0))
210        self.__multiply_pair_body.launch_n(
211            n=len(idx),
212            args=(
213                data_out.data,
214                data_in.data,
215                is_first_in_pair.indicator.data,
216                idx.data,
217            ),
218        )
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `distance_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.distance_pair`.
def distance_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `find_pairs` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.find_pairs`.
def find_pairs(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `max_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.max_pair`.
def max_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `sort_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.sort_pair`.
def sort_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the method signature followed by a numbered copy (listing lines
# 132-137) of `PairMethods.sort_within_pair_by_attr` itself. Unlike the other
# entries, it shows the method body rather than a `nice_thrust` wrapper, which
# matches the listing, where this method carries no decorator. The method returns
# early for fewer than two elements; otherwise it launches the in-place swap
# kernel over `len(idx) - 1` positions. The numeric prefixes make this fragment
# invalid Python.
def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
132    def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr):
133        if len(idx) < 2:
134            return
135        self.__sort_within_pair_by_attr_body.launch_n(
136            len(idx) - 1, [idx.data, is_first_in_pair.indicator.data, attr.data]
137        )
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `sum_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.sum_pair`.
def sum_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `min_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.min_pair`.
def min_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result
# NOTE(review): not a standalone definition. This appears to be documentation
# residue: the signature of the `nice_thrust`-decorated `multiply_pair` method
# followed by the numbered source of a wrapper (listing lines 11-17). The wrapper
# prints `func.__name__` when `debug_print` is set, calls `func`, calls
# `trtc.Wait()` when `wait` is set, and returns the result. `func`, `wait`,
# `debug_print` and `trtc` are not bound in this fragment, and the numeric
# prefixes make it invalid Python; the real implementation is
# `PairMethods.multiply_pair`.
def multiply_pair(*args, **kwargs):
11        def wrapper(*args, **kwargs):
12            if debug_print:
13                print(func.__name__)
14            result = func(*args, **kwargs)
15            if wait:
16                trtc.Wait()
17            return result