PySDM.backends.impl_numba.methods.displacement_methods

CPU implementation of backend methods for particle displacement (advection and sedimentation)

  1"""
  2CPU implementation of backend methods for particle displacement (advection and sedimentation)
  3"""
  4
  5from functools import cached_property
  6
  7import numba
  8
  9from PySDM.backends.impl_numba import conf
 10
 11from ...impl_common.backend_methods import BackendMethods
 12
 13
 14@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
 15# pylint: disable=too-many-arguments
 16def calculate_displacement_body_common(
 17    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
 18):
 19    displacement[dim, droplet] = scheme(
 20        position_in_cell[dim, droplet],
 21        courant[_l] / n_substeps,
 22        courant[_r] / n_substeps,
 23    )
 24
 25
 26class DisplacementMethods(BackendMethods):
 27    @staticmethod
 28    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 29    # pylint: disable=too-many-arguments
 30    def calculate_displacement_body_1d(
 31        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 32    ):
 33        length = displacement.shape[1]
 34        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 35            # Arakawa-C grid
 36            _l = cell_origin[0, droplet]
 37            _r = cell_origin[0, droplet] + 1
 38            calculate_displacement_body_common(
 39                dim,
 40                droplet,
 41                scheme,
 42                _l,
 43                _r,
 44                displacement,
 45                courant,
 46                position_in_cell,
 47                n_substeps,
 48            )
 49
 50    @staticmethod
 51    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 52    # pylint: disable=too-many-arguments
 53    def calculate_displacement_body_2d(
 54        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 55    ):
 56        length = displacement.shape[1]
 57        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 58            # Arakawa-C grid
 59            _l = (
 60                cell_origin[0, droplet],
 61                cell_origin[1, droplet],
 62            )
 63            _r = (
 64                cell_origin[0, droplet] + 1 * (dim == 0),
 65                cell_origin[1, droplet] + 1 * (dim == 1),
 66            )
 67            calculate_displacement_body_common(
 68                dim,
 69                droplet,
 70                scheme,
 71                _l,
 72                _r,
 73                displacement,
 74                courant,
 75                position_in_cell,
 76                n_substeps,
 77            )
 78
 79    @staticmethod
 80    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 81    # pylint: disable=too-many-arguments
 82    def calculate_displacement_body_3d(
 83        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 84    ):
 85        n_sd = displacement.shape[1]
 86        for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
 87            # Arakawa-C grid
 88            _l = (
 89                cell_origin[0, droplet],
 90                cell_origin[1, droplet],
 91                cell_origin[2, droplet],
 92            )
 93            _r = (
 94                cell_origin[0, droplet] + 1 * (dim == 0),
 95                cell_origin[1, droplet] + 1 * (dim == 1),
 96                cell_origin[2, droplet] + 1 * (dim == 2),
 97            )
 98            calculate_displacement_body_common(
 99                dim,
100                droplet,
101                scheme,
102                _l,
103                _r,
104                displacement,
105                courant,
106                position_in_cell,
107                n_substeps,
108            )
109
110    def calculate_displacement(
111        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
112    ):
113        n_dims = len(courant.shape)
114        scheme = self.formulae.particle_advection.displacement
115        if n_dims == 1:
116            DisplacementMethods.calculate_displacement_body_1d(
117                dim,
118                scheme,
119                displacement.data,
120                courant.data,
121                cell_origin.data,
122                position_in_cell.data,
123                n_substeps,
124            )
125        elif n_dims == 2:
126            DisplacementMethods.calculate_displacement_body_2d(
127                dim,
128                scheme,
129                displacement.data,
130                courant.data,
131                cell_origin.data,
132                position_in_cell.data,
133                n_substeps,
134            )
135        elif n_dims == 3:
136            DisplacementMethods.calculate_displacement_body_3d(
137                dim,
138                scheme,
139                displacement.data,
140                courant.data,
141                cell_origin.data,
142                position_in_cell.data,
143                n_substeps,
144            )
145        else:
146            raise NotImplementedError()
147
148    @cached_property
149    def _flag_precipitated_body(self):
150        @numba.njit(**{**self.default_jit_flags, "parallel": False})
151        # pylint: disable=too-many-arguments
152        def body(
153            cell_origin,
154            position_in_cell,
155            water_mass,
156            multiplicity,
157            idx,
158            length,
159            healthy,
160            precipitation_counting_level_index,
161            displacement,
162        ):
163            rainfall_mass = 0.0
164            flag = len(idx)
165            for i in range(length):
166                position_within_column = (
167                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
168                )
169                if (
170                    # falling
171                    displacement[-1, idx[i]] < 0
172                    and
173                    # and crossed precip-counting level
174                    position_within_column < precipitation_counting_level_index
175                ):
176                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
177                    idx[i] = flag
178                    healthy[0] = 0
179            return rainfall_mass
180
181        return body
182
183    @cached_property
184    def _flag_out_of_column_body(self):
185        @numba.njit(**{**self.default_jit_flags, "parallel": False})
186        # pylint: disable=too-many-arguments
187        def body(
188            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
189        ):
190            flag = len(idx)
191            for i in range(length):
192                position_within_column = (
193                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
194                )
195                if (
196                    position_within_column < 0
197                    or position_within_column > domain_top_level_index
198                ):
199                    idx[i] = flag
200                    healthy[0] = 0
201
202        return body
203
204    # pylint: disable=too-many-arguments
205    def flag_precipitated(
206        self,
207        *,
208        cell_origin,
209        position_in_cell,
210        water_mass,
211        multiplicity,
212        idx,
213        length,
214        healthy,
215        precipitation_counting_level_index,
216        displacement,
217    ) -> float:
218        """return a scalar value corresponding to the mass of water (all phases) that crossed
219        the bottom boundary of the entire domain"""
220        return self._flag_precipitated_body(
221            cell_origin.data,
222            position_in_cell.data,
223            water_mass.data,
224            multiplicity.data,
225            idx.data,
226            length,
227            healthy.data,
228            precipitation_counting_level_index,
229            displacement.data,
230        )
231
232    # pylint: disable=too-many-arguments
233    def flag_out_of_column(
234        self,
235        cell_origin,
236        position_in_cell,
237        idx,
238        length,
239        healthy,
240        domain_top_level_index,
241    ):
242        self._flag_out_of_column_body(
243            cell_origin.data,
244            position_in_cell.data,
245            idx.data,
246            length,
247            healthy.data,
248            domain_top_level_index,
249        )
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False}})
def calculate_displacement_body_common( dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps):
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Store the displacement of a single droplet along dimension `dim`.

    The entries of `courant` at indices `_l` and `_r` are each divided by
    `n_substeps` and handed, together with `position_in_cell[dim, droplet]`,
    to `scheme`; whatever `scheme` returns is written to
    `displacement[dim, droplet]`.
    """
    position = position_in_cell[dim, droplet]
    courant_left = courant[_l] / n_substeps
    courant_right = courant[_r] / n_substeps
    displacement[dim, droplet] = scheme(position, courant_left, courant_right)
class DisplacementMethods(PySDM.backends.impl_common.backend_methods.BackendMethods):
 27class DisplacementMethods(BackendMethods):
 28    @staticmethod
 29    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 30    # pylint: disable=too-many-arguments
 31    def calculate_displacement_body_1d(
 32        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 33    ):
 34        length = displacement.shape[1]
 35        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 36            # Arakawa-C grid
 37            _l = cell_origin[0, droplet]
 38            _r = cell_origin[0, droplet] + 1
 39            calculate_displacement_body_common(
 40                dim,
 41                droplet,
 42                scheme,
 43                _l,
 44                _r,
 45                displacement,
 46                courant,
 47                position_in_cell,
 48                n_substeps,
 49            )
 50
 51    @staticmethod
 52    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 53    # pylint: disable=too-many-arguments
 54    def calculate_displacement_body_2d(
 55        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 56    ):
 57        length = displacement.shape[1]
 58        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 59            # Arakawa-C grid
 60            _l = (
 61                cell_origin[0, droplet],
 62                cell_origin[1, droplet],
 63            )
 64            _r = (
 65                cell_origin[0, droplet] + 1 * (dim == 0),
 66                cell_origin[1, droplet] + 1 * (dim == 1),
 67            )
 68            calculate_displacement_body_common(
 69                dim,
 70                droplet,
 71                scheme,
 72                _l,
 73                _r,
 74                displacement,
 75                courant,
 76                position_in_cell,
 77                n_substeps,
 78            )
 79
 80    @staticmethod
 81    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 82    # pylint: disable=too-many-arguments
 83    def calculate_displacement_body_3d(
 84        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 85    ):
 86        n_sd = displacement.shape[1]
 87        for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
 88            # Arakawa-C grid
 89            _l = (
 90                cell_origin[0, droplet],
 91                cell_origin[1, droplet],
 92                cell_origin[2, droplet],
 93            )
 94            _r = (
 95                cell_origin[0, droplet] + 1 * (dim == 0),
 96                cell_origin[1, droplet] + 1 * (dim == 1),
 97                cell_origin[2, droplet] + 1 * (dim == 2),
 98            )
 99            calculate_displacement_body_common(
100                dim,
101                droplet,
102                scheme,
103                _l,
104                _r,
105                displacement,
106                courant,
107                position_in_cell,
108                n_substeps,
109            )
110
111    def calculate_displacement(
112        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
113    ):
114        n_dims = len(courant.shape)
115        scheme = self.formulae.particle_advection.displacement
116        if n_dims == 1:
117            DisplacementMethods.calculate_displacement_body_1d(
118                dim,
119                scheme,
120                displacement.data,
121                courant.data,
122                cell_origin.data,
123                position_in_cell.data,
124                n_substeps,
125            )
126        elif n_dims == 2:
127            DisplacementMethods.calculate_displacement_body_2d(
128                dim,
129                scheme,
130                displacement.data,
131                courant.data,
132                cell_origin.data,
133                position_in_cell.data,
134                n_substeps,
135            )
136        elif n_dims == 3:
137            DisplacementMethods.calculate_displacement_body_3d(
138                dim,
139                scheme,
140                displacement.data,
141                courant.data,
142                cell_origin.data,
143                position_in_cell.data,
144                n_substeps,
145            )
146        else:
147            raise NotImplementedError()
148
149    @cached_property
150    def _flag_precipitated_body(self):
151        @numba.njit(**{**self.default_jit_flags, "parallel": False})
152        # pylint: disable=too-many-arguments
153        def body(
154            cell_origin,
155            position_in_cell,
156            water_mass,
157            multiplicity,
158            idx,
159            length,
160            healthy,
161            precipitation_counting_level_index,
162            displacement,
163        ):
164            rainfall_mass = 0.0
165            flag = len(idx)
166            for i in range(length):
167                position_within_column = (
168                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
169                )
170                if (
171                    # falling
172                    displacement[-1, idx[i]] < 0
173                    and
174                    # and crossed precip-counting level
175                    position_within_column < precipitation_counting_level_index
176                ):
177                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
178                    idx[i] = flag
179                    healthy[0] = 0
180            return rainfall_mass
181
182        return body
183
184    @cached_property
185    def _flag_out_of_column_body(self):
186        @numba.njit(**{**self.default_jit_flags, "parallel": False})
187        # pylint: disable=too-many-arguments
188        def body(
189            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
190        ):
191            flag = len(idx)
192            for i in range(length):
193                position_within_column = (
194                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
195                )
196                if (
197                    position_within_column < 0
198                    or position_within_column > domain_top_level_index
199                ):
200                    idx[i] = flag
201                    healthy[0] = 0
202
203        return body
204
205    # pylint: disable=too-many-arguments
206    def flag_precipitated(
207        self,
208        *,
209        cell_origin,
210        position_in_cell,
211        water_mass,
212        multiplicity,
213        idx,
214        length,
215        healthy,
216        precipitation_counting_level_index,
217        displacement,
218    ) -> float:
219        """return a scalar value corresponding to the mass of water (all phases) that crossed
220        the bottom boundary of the entire domain"""
221        return self._flag_precipitated_body(
222            cell_origin.data,
223            position_in_cell.data,
224            water_mass.data,
225            multiplicity.data,
226            idx.data,
227            length,
228            healthy.data,
229            precipitation_counting_level_index,
230            displacement.data,
231        )
232
233    # pylint: disable=too-many-arguments
234    def flag_out_of_column(
235        self,
236        cell_origin,
237        position_in_cell,
238        idx,
239        length,
240        healthy,
241        domain_top_level_index,
242    ):
243        self._flag_out_of_column_body(
244            cell_origin.data,
245            position_in_cell.data,
246            idx.data,
247            length,
248            healthy.data,
249            domain_top_level_index,
250        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_1d( dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
28    @staticmethod
29    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
30    # pylint: disable=too-many-arguments
31    def calculate_displacement_body_1d(
32        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
33    ):
34        length = displacement.shape[1]
35        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
36            # Arakawa-C grid
37            _l = cell_origin[0, droplet]
38            _r = cell_origin[0, droplet] + 1
39            calculate_displacement_body_common(
40                dim,
41                droplet,
42                scheme,
43                _l,
44                _r,
45                displacement,
46                courant,
47                position_in_cell,
48                n_substeps,
49            )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_2d( dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
51    @staticmethod
52    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
53    # pylint: disable=too-many-arguments
54    def calculate_displacement_body_2d(
55        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
56    ):
57        length = displacement.shape[1]
58        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
59            # Arakawa-C grid
60            _l = (
61                cell_origin[0, droplet],
62                cell_origin[1, droplet],
63            )
64            _r = (
65                cell_origin[0, droplet] + 1 * (dim == 0),
66                cell_origin[1, droplet] + 1 * (dim == 1),
67            )
68            calculate_displacement_body_common(
69                dim,
70                droplet,
71                scheme,
72                _l,
73                _r,
74                displacement,
75                courant,
76                position_in_cell,
77                n_substeps,
78            )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_3d( dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
 80    @staticmethod
 81    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 82    # pylint: disable=too-many-arguments
 83    def calculate_displacement_body_3d(
 84        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 85    ):
 86        n_sd = displacement.shape[1]
 87        for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
 88            # Arakawa-C grid
 89            _l = (
 90                cell_origin[0, droplet],
 91                cell_origin[1, droplet],
 92                cell_origin[2, droplet],
 93            )
 94            _r = (
 95                cell_origin[0, droplet] + 1 * (dim == 0),
 96                cell_origin[1, droplet] + 1 * (dim == 1),
 97                cell_origin[2, droplet] + 1 * (dim == 2),
 98            )
 99            calculate_displacement_body_common(
100                dim,
101                droplet,
102                scheme,
103                _l,
104                _r,
105                displacement,
106                courant,
107                position_in_cell,
108                n_substeps,
109            )
def calculate_displacement( self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps):
111    def calculate_displacement(
112        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
113    ):
114        n_dims = len(courant.shape)
115        scheme = self.formulae.particle_advection.displacement
116        if n_dims == 1:
117            DisplacementMethods.calculate_displacement_body_1d(
118                dim,
119                scheme,
120                displacement.data,
121                courant.data,
122                cell_origin.data,
123                position_in_cell.data,
124                n_substeps,
125            )
126        elif n_dims == 2:
127            DisplacementMethods.calculate_displacement_body_2d(
128                dim,
129                scheme,
130                displacement.data,
131                courant.data,
132                cell_origin.data,
133                position_in_cell.data,
134                n_substeps,
135            )
136        elif n_dims == 3:
137            DisplacementMethods.calculate_displacement_body_3d(
138                dim,
139                scheme,
140                displacement.data,
141                courant.data,
142                cell_origin.data,
143                position_in_cell.data,
144                n_substeps,
145            )
146        else:
147            raise NotImplementedError()
def flag_precipitated( self, *, cell_origin, position_in_cell, water_mass, multiplicity, idx, length, healthy, precipitation_counting_level_index, displacement) -> float:
206    def flag_precipitated(
207        self,
208        *,
209        cell_origin,
210        position_in_cell,
211        water_mass,
212        multiplicity,
213        idx,
214        length,
215        healthy,
216        precipitation_counting_level_index,
217        displacement,
218    ) -> float:
219        """return a scalar value corresponding to the mass of water (all phases) that crossed
220        the bottom boundary of the entire domain"""
221        return self._flag_precipitated_body(
222            cell_origin.data,
223            position_in_cell.data,
224            water_mass.data,
225            multiplicity.data,
226            idx.data,
227            length,
228            healthy.data,
229            precipitation_counting_level_index,
230            displacement.data,
231        )

Return a scalar value corresponding to the mass of water (all phases) that crossed the bottom boundary of the entire domain.

def flag_out_of_column( self, cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index):
234    def flag_out_of_column(
235        self,
236        cell_origin,
237        position_in_cell,
238        idx,
239        length,
240        healthy,
241        domain_top_level_index,
242    ):
243        self._flag_out_of_column_body(
244            cell_origin.data,
245            position_in_cell.data,
246            idx.data,
247            length,
248            healthy.data,
249            domain_top_level_index,
250        )