PySDM.backends.impl_numba.methods.displacement_methods
CPU implementation of backend methods for particle displacement (advection and sedimentation)
# pylint: disable=too-many-positional-arguments
"""
CPU implementation of backend methods for particle displacement (advection and sedimentation)
"""

from functools import cached_property

import numba

from PySDM.backends.impl_numba import conf

from ...impl_common.backend_methods import BackendMethods


@numba.njit(**{**conf.JIT_FLAGS, "parallel": False})
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Store the displacement of a single droplet along dimension `dim`.

    `scheme` is called with the droplet's in-cell position along `dim` and with the
    Courant values read at indices `_l` and `_r`, each divided by `n_substeps`;
    the result is written in place into `displacement[dim, droplet]`.
    """
    displacement[dim, droplet] = scheme(
        position_in_cell[dim, droplet],
        courant[_l] / n_substeps,
        courant[_r] / n_substeps,
    )


class DisplacementMethods(BackendMethods):
    """CPU (Numba) backend methods for particle displacement and boundary flagging."""

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_1d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a one-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: the two Courant values are read at index
            # `cell_origin` and at `cell_origin + 1`
            _l = cell_origin[0, droplet]
            _r = cell_origin[0, droplet] + 1
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_2d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a two-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
            )
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_3d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a three-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
                cell_origin[2, droplet],
            )
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
                cell_origin[2, droplet] + 1 * (dim == 2),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    def calculate_displacement(
        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Compute particle displacements along dimension `dim`, in place.

        The kernel is selected from the number of dimensions of `courant`
        (`len(courant.shape)`); every storage argument is passed as its `.data`.

        Raises:
            NotImplementedError: if `courant` is not 1-, 2- or 3-dimensional.
        """
        n_dims = len(courant.shape)
        scheme = self.formulae.particle_advection.displacement
        if n_dims == 1:
            body = DisplacementMethods.calculate_displacement_body_1d
        elif n_dims == 2:
            body = DisplacementMethods.calculate_displacement_body_2d
        elif n_dims == 3:
            body = DisplacementMethods.calculate_displacement_body_3d
        else:
            raise NotImplementedError(
                f"displacement is implemented for 1, 2 or 3 dimensions, got {n_dims}"
            )
        body(
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )

    @cached_property
    def _flag_precipitated_body(self):
        """JIT-compiled kernel behind `flag_precipitated` (built on first access)."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin,
            position_in_cell,
            water_mass,
            multiplicity,
            idx,
            length,
            healthy,
            precipitation_counting_level_index,
            displacement,
        ):
            rainfall_mass = 0.0
            flag = len(idx)  # value stored in `idx` to mark a flagged entry
            for i in range(length):
                sd_id = idx[i]
                position_within_column = (
                    cell_origin[-1, sd_id] + position_in_cell[-1, sd_id]
                )
                if (
                    # falling
                    displacement[-1, sd_id] < 0
                    and
                    # and crossed precip-counting level
                    position_within_column < precipitation_counting_level_index
                ):
                    # abs(): presumably the sign of water_mass encodes the phase
                    # ("all phases" per flag_precipitated docstring) - TODO confirm
                    rainfall_mass += abs(water_mass[sd_id]) * multiplicity[sd_id]
                    idx[i] = flag
                    healthy[0] = 0
            return rainfall_mass

        return body

    @cached_property
    def _flag_out_of_column_body(self):
        """JIT-compiled kernel behind `flag_out_of_column` (built on first access)."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
        ):
            flag = len(idx)  # value stored in `idx` to mark a flagged entry
            for i in range(length):
                sd_id = idx[i]
                position_within_column = (
                    cell_origin[-1, sd_id] + position_in_cell[-1, sd_id]
                )
                if (
                    position_within_column < 0
                    or position_within_column > domain_top_level_index
                ):
                    idx[i] = flag
                    healthy[0] = 0

        return body

    def flag_precipitated(
        self,
        *,
        cell_origin,
        position_in_cell,
        water_mass,
        multiplicity,
        idx,
        length,
        healthy,
        precipitation_counting_level_index,
        displacement,
    ) -> float:
        """Return the mass of water (all phases) carried by falling particles that
        are below the precipitation-counting level, and flag those particles.

        A particle is counted when its last-axis displacement is negative and
        `cell_origin[-1] + position_in_cell[-1]` is smaller than
        `precipitation_counting_level_index`; its `idx` entry is then set to
        `len(idx)` and `healthy[0]` to 0.

        NOTE(review): the previous docstring called that level "the bottom boundary
        of the entire domain"; this holds only if callers pass the matching index.
        """
        return self._flag_precipitated_body(
            cell_origin.data,
            position_in_cell.data,
            water_mass.data,
            multiplicity.data,
            idx.data,
            length,
            healthy.data,
            precipitation_counting_level_index,
            displacement.data,
        )

    def flag_out_of_column(
        self,
        cell_origin,
        position_in_cell,
        idx,
        length,
        healthy,
        domain_top_level_index,
    ):
        """Flag particles whose last-axis position left the column.

        A particle is flagged when `cell_origin[-1] + position_in_cell[-1]` is
        negative or greater than `domain_top_level_index`; its `idx` entry is then
        set to `len(idx)` and `healthy[0]` to 0. Returns nothing.
        """
        self._flag_out_of_column_body(
            cell_origin.data,
            position_in_cell.data,
            idx.data,
            length,
            healthy.data,
            domain_top_level_index,
        )
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False}})
def calculate_displacement_body_common(dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps):
@numba.njit(**{**conf.JIT_FLAGS, "parallel": False})
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Store the displacement of a single droplet along dimension `dim`.

    `scheme` is called with the droplet's in-cell position along `dim` and with the
    Courant values read at indices `_l` and `_r`, each divided by `n_substeps`;
    the result is written in place into `displacement[dim, droplet]`.
    """
    displacement[dim, droplet] = scheme(
        position_in_cell[dim, droplet],
        courant[_l] / n_substeps,
        courant[_r] / n_substeps,
    )
class DisplacementMethods(BackendMethods):
    """CPU (Numba) backend methods for particle displacement and boundary flagging."""

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_1d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a one-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: the two Courant values are read at index
            # `cell_origin` and at `cell_origin + 1`
            _l = cell_origin[0, droplet]
            _r = cell_origin[0, droplet] + 1
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_2d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a two-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
            )
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
    def calculate_displacement_body_3d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill `displacement[dim, :]` for a three-dimensional Courant field."""
        n_droplets = displacement.shape[1]
        for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
                cell_origin[2, droplet],
            )
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
                cell_origin[2, droplet] + 1 * (dim == 2),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    def calculate_displacement(
        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Compute particle displacements along dimension `dim`, in place.

        The kernel is selected from the number of dimensions of `courant`
        (`len(courant.shape)`); every storage argument is passed as its `.data`.

        Raises:
            NotImplementedError: if `courant` is not 1-, 2- or 3-dimensional.
        """
        n_dims = len(courant.shape)
        scheme = self.formulae.particle_advection.displacement
        if n_dims == 1:
            body = DisplacementMethods.calculate_displacement_body_1d
        elif n_dims == 2:
            body = DisplacementMethods.calculate_displacement_body_2d
        elif n_dims == 3:
            body = DisplacementMethods.calculate_displacement_body_3d
        else:
            raise NotImplementedError(
                f"displacement is implemented for 1, 2 or 3 dimensions, got {n_dims}"
            )
        body(
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )

    @cached_property
    def _flag_precipitated_body(self):
        """JIT-compiled kernel behind `flag_precipitated` (built on first access)."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin,
            position_in_cell,
            water_mass,
            multiplicity,
            idx,
            length,
            healthy,
            precipitation_counting_level_index,
            displacement,
        ):
            rainfall_mass = 0.0
            flag = len(idx)  # value stored in `idx` to mark a flagged entry
            for i in range(length):
                sd_id = idx[i]
                position_within_column = (
                    cell_origin[-1, sd_id] + position_in_cell[-1, sd_id]
                )
                if (
                    # falling
                    displacement[-1, sd_id] < 0
                    and
                    # and crossed precip-counting level
                    position_within_column < precipitation_counting_level_index
                ):
                    # abs(): presumably the sign of water_mass encodes the phase
                    # ("all phases" per flag_precipitated docstring) - TODO confirm
                    rainfall_mass += abs(water_mass[sd_id]) * multiplicity[sd_id]
                    idx[i] = flag
                    healthy[0] = 0
            return rainfall_mass

        return body

    @cached_property
    def _flag_out_of_column_body(self):
        """JIT-compiled kernel behind `flag_out_of_column` (built on first access)."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        def body(
            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
        ):
            flag = len(idx)  # value stored in `idx` to mark a flagged entry
            for i in range(length):
                sd_id = idx[i]
                position_within_column = (
                    cell_origin[-1, sd_id] + position_in_cell[-1, sd_id]
                )
                if (
                    position_within_column < 0
                    or position_within_column > domain_top_level_index
                ):
                    idx[i] = flag
                    healthy[0] = 0

        return body

    def flag_precipitated(
        self,
        *,
        cell_origin,
        position_in_cell,
        water_mass,
        multiplicity,
        idx,
        length,
        healthy,
        precipitation_counting_level_index,
        displacement,
    ) -> float:
        """Return the mass of water (all phases) carried by falling particles that
        are below the precipitation-counting level, and flag those particles.

        A particle is counted when its last-axis displacement is negative and
        `cell_origin[-1] + position_in_cell[-1]` is smaller than
        `precipitation_counting_level_index`; its `idx` entry is then set to
        `len(idx)` and `healthy[0]` to 0.

        NOTE(review): the previous docstring called that level "the bottom boundary
        of the entire domain"; this holds only if callers pass the matching index.
        """
        return self._flag_precipitated_body(
            cell_origin.data,
            position_in_cell.data,
            water_mass.data,
            multiplicity.data,
            idx.data,
            length,
            healthy.data,
            precipitation_counting_level_index,
            displacement.data,
        )

    def flag_out_of_column(
        self,
        cell_origin,
        position_in_cell,
        idx,
        length,
        healthy,
        domain_top_level_index,
    ):
        """Flag particles whose last-axis position left the column.

        A particle is flagged when `cell_origin[-1] + position_in_cell[-1]` is
        negative or greater than `domain_top_level_index`; its `idx` entry is then
        set to `len(idx)` and `healthy[0]` to 0. Returns nothing.
        """
        self._flag_out_of_column_body(
            cell_origin.data,
            position_in_cell.data,
            idx.data,
            length,
            healthy.data,
            domain_top_level_index,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_1d(dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
def calculate_displacement_body_1d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill `displacement[dim, :]` for a one-dimensional Courant field.

    Loops over the second axis of `displacement` (one entry per droplet) and
    delegates the per-droplet evaluation to `calculate_displacement_body_common`.
    """
    n_droplets = displacement.shape[1]
    for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
        # Arakawa-C grid: the two Courant values are read at index
        # `cell_origin` and at `cell_origin + 1`
        _l = cell_origin[0, droplet]
        _r = cell_origin[0, droplet] + 1
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_2d(dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
def calculate_displacement_body_2d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill `displacement[dim, :]` for a two-dimensional Courant field.

    Loops over the second axis of `displacement` (one entry per droplet) and
    delegates the per-droplet evaluation to `calculate_displacement_body_common`.
    """
    n_droplets = displacement.shape[1]
    for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
        # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
        )
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{'parallel': False, 'cache': False}})
def calculate_displacement_body_3d(dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps):
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, "parallel": False, "cache": False})
def calculate_displacement_body_3d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill `displacement[dim, :]` for a three-dimensional Courant field.

    Loops over the second axis of `displacement` (one entry per droplet) and
    delegates the per-droplet evaluation to `calculate_displacement_body_common`.
    """
    n_droplets = displacement.shape[1]
    for droplet in numba.prange(n_droplets):  # pylint: disable=not-an-iterable
        # Arakawa-C grid: `_r` differs from `_l` by one index along `dim` only
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
            cell_origin[2, droplet],
        )
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
            cell_origin[2, droplet] + 1 * (dim == 2),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
def calculate_displacement(self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps):
def calculate_displacement(
    self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Compute particle displacements along dimension `dim`, in place.

    The kernel is selected from the number of dimensions of `courant`
    (`len(courant.shape)`); the advection scheme is taken from
    `self.formulae.particle_advection.displacement` and every storage argument
    is passed to the kernel as its `.data`.

    Raises:
        NotImplementedError: if `courant` is not 1-, 2- or 3-dimensional.
    """
    n_dims = len(courant.shape)
    scheme = self.formulae.particle_advection.displacement
    if n_dims == 1:
        body = DisplacementMethods.calculate_displacement_body_1d
    elif n_dims == 2:
        body = DisplacementMethods.calculate_displacement_body_2d
    elif n_dims == 3:
        body = DisplacementMethods.calculate_displacement_body_3d
    else:
        raise NotImplementedError(
            f"displacement is implemented for 1, 2 or 3 dimensions, got {n_dims}"
        )
    # single call site: the three kernels share one argument list
    body(
        dim,
        scheme,
        displacement.data,
        courant.data,
        cell_origin.data,
        position_in_cell.data,
        n_substeps,
    )
def flag_precipitated(self, *, cell_origin, position_in_cell, water_mass, multiplicity, idx, length, healthy, precipitation_counting_level_index, displacement) -> float:
def flag_precipitated(
    self,
    *,
    cell_origin,
    position_in_cell,
    water_mass,
    multiplicity,
    idx,
    length,
    healthy,
    precipitation_counting_level_index,
    displacement,
) -> float:
    """Return the mass of water (all phases) carried by falling particles that
    are below the precipitation-counting level, and flag those particles.

    Thin wrapper: unwraps `.data` from every storage argument and forwards to
    the JIT-compiled `self._flag_precipitated_body`, whose result is returned.
    That kernel counts a particle when its last-axis displacement is negative
    and `cell_origin[-1] + position_in_cell[-1]` is smaller than
    `precipitation_counting_level_index`.

    NOTE(review): the previous docstring called that level "the bottom boundary
    of the entire domain"; this holds only if callers pass the matching index.
    """
    return self._flag_precipitated_body(
        cell_origin.data,
        position_in_cell.data,
        water_mass.data,
        multiplicity.data,
        idx.data,
        length,
        healthy.data,
        precipitation_counting_level_index,
        displacement.data,
    )
return a scalar value corresponding to the mass of water (all phases) that crossed the bottom boundary of the entire domain
def flag_out_of_column(self, cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index):
def flag_out_of_column(
    self,
    cell_origin,
    position_in_cell,
    idx,
    length,
    healthy,
    domain_top_level_index,
):
    """Flag particles whose last-axis position left the column.

    Thin wrapper: unwraps `.data` from every storage argument and forwards to
    the JIT-compiled `self._flag_out_of_column_body`. That kernel flags a
    particle when `cell_origin[-1] + position_in_cell[-1]` is negative or
    greater than `domain_top_level_index`, setting its `idx` entry to
    `len(idx)` and `healthy[0]` to 0. Returns nothing.
    """
    self._flag_out_of_column_body(
        cell_origin.data,
        position_in_cell.data,
        idx.data,
        length,
        healthy.data,
        domain_top_level_index,
    )