PySDM.backends.impl_numba.methods.displacement_methods

CPU implementation of backend methods for particle displacement (advection and sedimentation)

  1# pylint: disable=too-many-positional-arguments
  2"""
  3CPU implementation of backend methods for particle displacement (advection and sedimentation)
  4"""
  5
  6from functools import cached_property
  7
  8import numba
  9
 10from PySDM.backends.impl_numba import conf
 11
 12from ...impl_common.backend_methods import BackendMethods
 13
 14
 15@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
 16def calculate_displacement_body_common(
 17    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
 18):
 19    displacement[dim, droplet] = scheme(
 20        position_in_cell[dim, droplet],
 21        courant[_l] / n_substeps,
 22        courant[_r] / n_substeps,
 23    )
 24
 25
 26class DisplacementMethods(BackendMethods):
 27    @staticmethod
 28    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 29    def calculate_displacement_body_1d(
 30        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 31    ):
 32        length = displacement.shape[1]
 33        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 34            # Arakawa-C grid
 35            _l = cell_origin[0, droplet]
 36            _r = cell_origin[0, droplet] + 1
 37            calculate_displacement_body_common(
 38                dim,
 39                droplet,
 40                scheme,
 41                _l,
 42                _r,
 43                displacement,
 44                courant,
 45                position_in_cell,
 46                n_substeps,
 47            )
 48
 49    @staticmethod
 50    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 51    def calculate_displacement_body_2d(
 52        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 53    ):
 54        length = displacement.shape[1]
 55        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
 56            # Arakawa-C grid
 57            _l = (
 58                cell_origin[0, droplet],
 59                cell_origin[1, droplet],
 60            )
 61            _r = (
 62                cell_origin[0, droplet] + 1 * (dim == 0),
 63                cell_origin[1, droplet] + 1 * (dim == 1),
 64            )
 65            calculate_displacement_body_common(
 66                dim,
 67                droplet,
 68                scheme,
 69                _l,
 70                _r,
 71                displacement,
 72                courant,
 73                position_in_cell,
 74                n_substeps,
 75            )
 76
 77    @staticmethod
 78    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
 79    def calculate_displacement_body_3d(
 80        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
 81    ):
 82        n_sd = displacement.shape[1]
 83        for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
 84            # Arakawa-C grid
 85            _l = (
 86                cell_origin[0, droplet],
 87                cell_origin[1, droplet],
 88                cell_origin[2, droplet],
 89            )
 90            _r = (
 91                cell_origin[0, droplet] + 1 * (dim == 0),
 92                cell_origin[1, droplet] + 1 * (dim == 1),
 93                cell_origin[2, droplet] + 1 * (dim == 2),
 94            )
 95            calculate_displacement_body_common(
 96                dim,
 97                droplet,
 98                scheme,
 99                _l,
100                _r,
101                displacement,
102                courant,
103                position_in_cell,
104                n_substeps,
105            )
106
107    def calculate_displacement(
108        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
109    ):
110        n_dims = len(courant.shape)
111        scheme = self.formulae.particle_advection.displacement
112        if n_dims == 1:
113            DisplacementMethods.calculate_displacement_body_1d(
114                dim,
115                scheme,
116                displacement.data,
117                courant.data,
118                cell_origin.data,
119                position_in_cell.data,
120                n_substeps,
121            )
122        elif n_dims == 2:
123            DisplacementMethods.calculate_displacement_body_2d(
124                dim,
125                scheme,
126                displacement.data,
127                courant.data,
128                cell_origin.data,
129                position_in_cell.data,
130                n_substeps,
131            )
132        elif n_dims == 3:
133            DisplacementMethods.calculate_displacement_body_3d(
134                dim,
135                scheme,
136                displacement.data,
137                courant.data,
138                cell_origin.data,
139                position_in_cell.data,
140                n_substeps,
141            )
142        else:
143            raise NotImplementedError()
144
145    @cached_property
146    def _flag_precipitated_body(self):
147        @numba.njit(**{**self.default_jit_flags, "parallel": False})
148        def body(
149            cell_origin,
150            position_in_cell,
151            water_mass,
152            multiplicity,
153            idx,
154            length,
155            healthy,
156            precipitation_counting_level_index,
157            displacement,
158        ):
159            rainfall_mass = 0.0
160            flag = len(idx)
161            for i in range(length):
162                position_within_column = (
163                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
164                )
165                if (
166                    # falling
167                    displacement[-1, idx[i]] < 0
168                    and
169                    # and crossed precip-counting level
170                    position_within_column < precipitation_counting_level_index
171                ):
172                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
173                    idx[i] = flag
174                    healthy[0] = 0
175            return rainfall_mass
176
177        return body
178
179    @cached_property
180    def _flag_out_of_column_body(self):
181        @numba.njit(**{**self.default_jit_flags, "parallel": False})
182        def body(
183            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
184        ):
185            flag = len(idx)
186            for i in range(length):
187                position_within_column = (
188                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
189                )
190                if (
191                    position_within_column < 0
192                    or position_within_column > domain_top_level_index
193                ):
194                    idx[i] = flag
195                    healthy[0] = 0
196
197        return body
198
199    def flag_precipitated(
200        self,
201        *,
202        cell_origin,
203        position_in_cell,
204        water_mass,
205        multiplicity,
206        idx,
207        length,
208        healthy,
209        precipitation_counting_level_index,
210        displacement,
211    ) -> float:
212        """return a scalar value corresponding to the mass of water (all phases) that crossed
213        the bottom boundary of the entire domain"""
214        return self._flag_precipitated_body(
215            cell_origin.data,
216            position_in_cell.data,
217            water_mass.data,
218            multiplicity.data,
219            idx.data,
220            length,
221            healthy.data,
222            precipitation_counting_level_index,
223            displacement.data,
224        )
225
226    def flag_out_of_column(
227        self,
228        cell_origin,
229        position_in_cell,
230        idx,
231        length,
232        healthy,
233        domain_top_level_index,
234    ):
235        self._flag_out_of_column_body(
236            cell_origin.data,
237            position_in_cell.data,
238            idx.data,
239            length,
240            healthy.data,
241            domain_top_level_index,
242        )
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Compute the displacement of one droplet along one dimension.

    Writes ``displacement[dim, droplet]`` as the result of ``scheme`` called with
    the droplet's in-cell position along ``dim`` and the two Courant numbers
    ``courant[_l]`` and ``courant[_r]``, each divided by ``n_substeps``.

    ``_l`` and ``_r`` are indices into ``courant`` (an int in 1D, a tuple in 2D/3D)
    for the two cell faces bounding the droplet along ``dim``.
    """
    displacement[dim, droplet] = scheme(
        position_in_cell[dim, droplet],
        # Courant numbers are split evenly across the substeps
        courant[_l] / n_substeps,
        courant[_r] / n_substeps,
    )
class DisplacementMethods(BackendMethods):
    """CPU (Numba) backend methods for particle displacement.

    Provides per-droplet displacement computed from Courant numbers, and the
    flagging of particles that precipitated or left the column.
    """

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    def calculate_displacement_body_1d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """1D kernel: fill ``displacement[dim, :]`` for every droplet.

        A droplet in cell ``i = cell_origin[0, droplet]`` uses the Courant
        numbers ``courant[i]`` and ``courant[i + 1]``.
        """
        length = displacement.shape[1]
        # parallel=False in the decorator, so prange runs like a serial range
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: Courant numbers at the two faces bounding the cell
            _l = cell_origin[0, droplet]
            _r = cell_origin[0, droplet] + 1
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    def calculate_displacement_body_2d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """2D kernel: fill ``displacement[dim, :]`` for every droplet.

        The left index is the droplet's cell origin. The right index equals it
        except that it is shifted by one along ``dim`` only.
        """
        length = displacement.shape[1]
        # parallel=False in the decorator, so prange runs like a serial range
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
            )
            # 1 * (dim == k) turns the comparison into a 0/1 offset
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    def calculate_displacement_body_3d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """3D kernel: fill ``displacement[dim, :]`` for every droplet.

        The left index is the droplet's cell origin. The right index equals it
        except that it is shifted by one along ``dim`` only.
        """
        n_sd = displacement.shape[1]
        # parallel=False in the decorator, so prange runs like a serial range
        for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
                cell_origin[2, droplet],
            )
            # 1 * (dim == k) turns the comparison into a 0/1 offset
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
                cell_origin[2, droplet] + 1 * (dim == 2),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    def calculate_displacement(
        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Compute ``displacement[dim, :]`` for all droplets.

        Dispatches on the rank of ``courant`` to the 1D, 2D or 3D kernel.
        Arguments other than ``dim`` and ``n_substeps`` are storage objects whose
        ``.data`` arrays are handed to the kernels. The scheme is
        ``self.formulae.particle_advection.displacement``.

        Raises:
            NotImplementedError: if ``courant`` is not 1-, 2- or 3-dimensional.
        """
        n_dims = len(courant.shape)
        scheme = self.formulae.particle_advection.displacement
        if n_dims == 1:
            DisplacementMethods.calculate_displacement_body_1d(
                dim,
                scheme,
                displacement.data,
                courant.data,
                cell_origin.data,
                position_in_cell.data,
                n_substeps,
            )
        elif n_dims == 2:
            DisplacementMethods.calculate_displacement_body_2d(
                dim,
                scheme,
                displacement.data,
                courant.data,
                cell_origin.data,
                position_in_cell.data,
                n_substeps,
            )
        elif n_dims == 3:
            DisplacementMethods.calculate_displacement_body_3d(
                dim,
                scheme,
                displacement.data,
                courant.data,
                cell_origin.data,
                position_in_cell.data,
                n_substeps,
            )
        else:
            raise NotImplementedError()

    @cached_property
    def _flag_precipitated_body(self):
        """Build (once per instance) the Numba kernel behind ``flag_precipitated``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin,
            position_in_cell,
            water_mass,
            multiplicity,
            idx,
            length,
            healthy,
            precipitation_counting_level_index,
            displacement,
        ):
            rainfall_mass = 0.0
            # len(idx) lies outside the valid index positions (presumably
            # 0..len(idx)-1), so it marks flagged entries
            flag = len(idx)
            for i in range(length):
                # cell index plus in-cell position along the last dimension
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    # falling
                    displacement[-1, idx[i]] < 0
                    and
                    # and crossed precip-counting level
                    position_within_column < precipitation_counting_level_index
                ):
                    # abs(): water_mass may be signed (presumably negative for
                    # ice -- TODO confirm); only the magnitude is accumulated
                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
                    idx[i] = flag
                    healthy[0] = 0
            return rainfall_mass

        return body

    @cached_property
    def _flag_out_of_column_body(self):
        """Build (once per instance) the Numba kernel behind ``flag_out_of_column``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
        ):
            # len(idx) lies outside the valid index positions, so it marks
            # flagged entries
            flag = len(idx)
            for i in range(length):
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    position_within_column < 0
                    or position_within_column > domain_top_level_index
                ):
                    idx[i] = flag
                    healthy[0] = 0

        return body

    def flag_precipitated(
        self,
        *,
        cell_origin,
        position_in_cell,
        water_mass,
        multiplicity,
        idx,
        length,
        healthy,
        precipitation_counting_level_index,
        displacement,
    ) -> float:
        """Return the mass of water (all phases) carried by particles that fell
        below the precipitation-counting level, and flag those particles.

        A particle among the first ``length`` entries of ``idx`` counts when:
        - its displacement along the last dimension is negative (falling), and
        - ``cell_origin[-1] + position_in_cell[-1]`` is below
          ``precipitation_counting_level_index``.

        With a level index of 0 this is the bottom boundary of the domain.

        For each counted particle:
        - ``abs(water_mass) * multiplicity`` is added to the returned sum;
        - its ``idx`` entry is overwritten with ``len(idx)``;
        - ``healthy[0]`` is set to 0.
        """
        return self._flag_precipitated_body(
            cell_origin.data,
            position_in_cell.data,
            water_mass.data,
            multiplicity.data,
            idx.data,
            length,
            healthy.data,
            precipitation_counting_level_index,
            displacement.data,
        )

    def flag_out_of_column(
        self,
        cell_origin,
        position_in_cell,
        idx,
        length,
        healthy,
        domain_top_level_index,
    ):
        """Flag particles whose position along the last dimension lies outside
        the column.

        A particle among the first ``length`` entries of ``idx`` is flagged when
        ``cell_origin[-1] + position_in_cell[-1]`` is below 0 or above
        ``domain_top_level_index``. Flagged entries of ``idx`` are overwritten
        with ``len(idx)`` and ``healthy[0]`` is set to 0. Returns nothing.
        """
        self._flag_out_of_column_body(
            cell_origin.data,
            position_in_cell.data,
            idx.data,
            length,
            healthy.data,
            domain_top_level_index,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
def calculate_displacement_body_1d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """1D kernel: fill ``displacement[dim, :]`` for every droplet.

    A droplet in cell ``i = cell_origin[0, droplet]`` uses the Courant numbers
    ``courant[i]`` and ``courant[i + 1]``.
    """
    length = displacement.shape[1]
    # parallel=False in the decorator, so prange runs like a serial range
    for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
        # Arakawa-C grid: Courant numbers at the two faces bounding the cell
        _l = cell_origin[0, droplet]
        _r = cell_origin[0, droplet] + 1
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
def calculate_displacement_body_2d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """2D kernel: fill ``displacement[dim, :]`` for every droplet.

    The left index is the droplet's cell origin. The right index equals it
    except that it is shifted by one along ``dim`` only.
    """
    length = displacement.shape[1]
    # parallel=False in the decorator, so prange runs like a serial range
    for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
        # Arakawa-C grid
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
        )
        # 1 * (dim == k) turns the comparison into a 0/1 offset
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
def calculate_displacement_body_3d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """3D kernel: fill ``displacement[dim, :]`` for every droplet.

    The left index is the droplet's cell origin. The right index equals it
    except that it is shifted by one along ``dim`` only.
    """
    n_sd = displacement.shape[1]
    # parallel=False in the decorator, so prange runs like a serial range
    for droplet in numba.prange(n_sd):  # pylint: disable=not-an-iterable
        # Arakawa-C grid
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
            cell_origin[2, droplet],
        )
        # 1 * (dim == k) turns the comparison into a 0/1 offset
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
            cell_origin[2, droplet] + 1 * (dim == 2),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
def calculate_displacement(
    self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Compute ``displacement[dim, :]`` for all droplets.

    Dispatches on the rank of ``courant`` to the 1D, 2D or 3D kernel.
    Arguments other than ``dim`` and ``n_substeps`` are storage objects whose
    ``.data`` arrays are handed to the kernels. The scheme is
    ``self.formulae.particle_advection.displacement``.

    Raises:
        NotImplementedError: if ``courant`` is not 1-, 2- or 3-dimensional.
    """
    n_dims = len(courant.shape)
    scheme = self.formulae.particle_advection.displacement
    if n_dims == 1:
        DisplacementMethods.calculate_displacement_body_1d(
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )
    elif n_dims == 2:
        DisplacementMethods.calculate_displacement_body_2d(
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )
    elif n_dims == 3:
        DisplacementMethods.calculate_displacement_body_3d(
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )
    else:
        raise NotImplementedError()
def flag_precipitated(
    self,
    *,
    cell_origin,
    position_in_cell,
    water_mass,
    multiplicity,
    idx,
    length,
    healthy,
    precipitation_counting_level_index,
    displacement,
) -> float:
    """Return the mass of water (all phases) carried by particles that fell
    below the precipitation-counting level, and flag those particles.

    A particle among the first ``length`` entries of ``idx`` counts when:
    - its displacement along the last dimension is negative (falling), and
    - ``cell_origin[-1] + position_in_cell[-1]`` is below
      ``precipitation_counting_level_index``.

    With a level index of 0 this is the bottom boundary of the domain.

    For each counted particle:
    - ``abs(water_mass) * multiplicity`` is added to the returned sum;
    - its ``idx`` entry is overwritten with ``len(idx)``;
    - ``healthy[0]`` is set to 0.
    """
    return self._flag_precipitated_body(
        cell_origin.data,
        position_in_cell.data,
        water_mass.data,
        multiplicity.data,
        idx.data,
        length,
        healthy.data,
        precipitation_counting_level_index,
        displacement.data,
    )

Return a scalar value: the total mass of water (all phases) carried by particles that fell below the precipitation-counting level (the bottom boundary of the entire domain when that level index is 0).

def flag_out_of_column(
    self,
    cell_origin,
    position_in_cell,
    idx,
    length,
    healthy,
    domain_top_level_index,
):
    """Flag particles whose position along the last dimension lies outside
    the column.

    A particle among the first ``length`` entries of ``idx`` is flagged when
    ``cell_origin[-1] + position_in_cell[-1]`` is below 0 or above
    ``domain_top_level_index``. Flagged entries of ``idx`` are overwritten with
    ``len(idx)`` and ``healthy[0]`` is set to 0. Returns nothing.
    """
    self._flag_out_of_column_body(
        cell_origin.data,
        position_in_cell.data,
        idx.data,
        length,
        healthy.data,
        domain_top_level_index,
    )