PySDM.backends.impl_numba.methods.displacement_methods
CPU implementation of backend methods for particle displacement (advection and sedimentation)
"""
CPU implementation of backend methods for particle displacement (advection and sedimentation)
"""

from functools import cached_property

import numba

from PySDM.backends.impl_numba import conf

from ...impl_common.backend_methods import BackendMethods


@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Compute and store the displacement of one super-droplet along one dimension.

    Calls ``scheme(position, c_l, c_r)``:
    - ``position`` is ``position_in_cell[dim, droplet]``;
    - ``c_l`` / ``c_r`` are ``courant[_l]`` / ``courant[_r]``, each divided by
      ``n_substeps``.

    The result is written in place to ``displacement[dim, droplet]``; nothing
    is returned. ``_l`` and ``_r`` are scalar indices (1-D) or index tuples
    (2-D/3-D) built by the calling kernel.
    """
    displacement[dim, droplet] = scheme(
        position_in_cell[dim, droplet],
        courant[_l] / n_substeps,
        courant[_r] / n_substeps,
    )


class DisplacementMethods(BackendMethods):
    """Numba (CPU) backend methods for super-droplet displacement.

    Provides:
    - 1-D, 2-D and 3-D kernels that evaluate the particle-advection
      displacement formula for every super-droplet;
    - helpers that flag super-droplets for removal when, along the last
      dimension, they fall below a precipitation-counting level or lie outside
      ``[0, domain_top_level_index]``.
    """

    # NOTE(review): the kernels below disable on-disk caching, presumably
    # because ``scheme`` is passed in as a first-class jitted function
    # argument -- TODO confirm.

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_1d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 1-D ``courant`` field.

        For each droplet, ``courant`` is read at ``cell_origin[0, droplet]``
        (left) and at the next index (right).
        """
        length = displacement.shape[1]
        # with parallel=False, numba.prange iterates serially like range
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: left/right values at indices i and i + 1
            _l = cell_origin[0, droplet]
            _r = cell_origin[0, droplet] + 1
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_2d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 2-D ``courant`` field.

        The right-hand index equals the droplet's cell origin shifted by one
        only along ``dim``.
        """
        length = displacement.shape[1]
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
            )
            # 1 * (dim == k) adds one only on the axis being displaced
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_3d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 3-D ``courant`` field.

        The right-hand index equals the droplet's cell origin shifted by one
        only along ``dim``.
        """
        length = displacement.shape[1]
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
                cell_origin[2, droplet],
            )
            # 1 * (dim == k) adds one only on the axis being displaced
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
                cell_origin[2, droplet] + 1 * (dim == 2),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    def calculate_displacement(
        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Compute ``displacement[dim, :]`` for all super-droplets.

        Dispatches on the rank of ``courant`` to the matching 1-D, 2-D or 3-D
        kernel. The displacement formula is
        ``self.formulae.particle_advection.displacement``, and the raw ``.data``
        arrays of the storage arguments are passed through.

        Raises:
            NotImplementedError: if ``courant`` is not 1-, 2- or 3-dimensional.
        """
        n_dims = len(courant.shape)
        scheme = self.formulae.particle_advection.displacement
        kernels = {
            1: DisplacementMethods.calculate_displacement_body_1d,
            2: DisplacementMethods.calculate_displacement_body_2d,
            3: DisplacementMethods.calculate_displacement_body_3d,
        }
        if n_dims not in kernels:
            raise NotImplementedError(
                f"displacement not implemented for {n_dims}-dimensional courant field"
            )
        kernels[n_dims](
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )

    @cached_property
    def _flag_precipitated_body(self):
        """Compile (once per backend instance) the kernel behind ``flag_precipitated``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        # pylint: disable=too-many-arguments
        def body(
            cell_origin,
            position_in_cell,
            water_mass,
            multiplicity,
            idx,
            length,
            healthy,
            precipitation_counting_level_index,
            displacement,
        ):
            rainfall_mass = 0.0
            # len(idx) is one past the last valid index and is written into idx
            # as a marker (presumably consumed by a later removal step -- TODO confirm)
            flag = len(idx)
            for i in range(length):
                # position along the last dimension, in cell units
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    # falling
                    displacement[-1, idx[i]] < 0
                    and
                    # and crossed precip-counting level
                    position_within_column < precipitation_counting_level_index
                ):
                    # abs(): water_mass may be signed (e.g. by phase) -- TODO confirm
                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
                    idx[i] = flag
                    healthy[0] = 0
            return rainfall_mass

        return body

    @cached_property
    def _flag_out_of_column_body(self):
        """Compile (once per backend instance) the kernel behind ``flag_out_of_column``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        # pylint: disable=too-many-arguments
        def body(
            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
        ):
            # out-of-range index used as a marker, as in _flag_precipitated_body
            flag = len(idx)
            for i in range(length):
                # position along the last dimension, in cell units
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    position_within_column < 0
                    or position_within_column > domain_top_level_index
                ):
                    idx[i] = flag
                    healthy[0] = 0

        return body

    # pylint: disable=too-many-arguments
    def flag_precipitated(
        self,
        *,
        cell_origin,
        position_in_cell,
        water_mass,
        multiplicity,
        idx,
        length,
        healthy,
        precipitation_counting_level_index,
        displacement,
    ) -> float:
        """return a scalar value corresponding to the mass of water (all phases) that crossed
        the bottom boundary of the entire domain

        Only the first ``length`` entries of ``idx`` are examined. Each flagged
        entry of ``idx`` is overwritten with ``len(idx)``, and ``healthy[0]`` is
        set to 0 when anything is flagged.
        """
        return self._flag_precipitated_body(
            cell_origin.data,
            position_in_cell.data,
            water_mass.data,
            multiplicity.data,
            idx.data,
            length,
            healthy.data,
            precipitation_counting_level_index,
            displacement.data,
        )

    # pylint: disable=too-many-arguments
    def flag_out_of_column(
        self,
        cell_origin,
        position_in_cell,
        idx,
        length,
        healthy,
        domain_top_level_index,
    ):
        """Flag super-droplets whose last-dimension position lies outside the column.

        A droplet is flagged when its position along the last dimension is
        ``< 0`` or ``> domain_top_level_index``. Flagging overwrites the entry in
        ``idx`` with ``len(idx)`` and sets ``healthy[0]`` to 0. Only the first
        ``length`` entries of ``idx`` are examined; nothing is returned.
        """
        self._flag_out_of_column_body(
            cell_origin.data,
            position_in_cell.data,
            idx.data,
            length,
            healthy.data,
            domain_top_level_index,
        )
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_common(
    dim, droplet, scheme, _l, _r, displacement, courant, position_in_cell, n_substeps
):
    """Compute and store the displacement of one super-droplet along one dimension.

    Calls ``scheme(position, c_l, c_r)``:
    - ``position`` is ``position_in_cell[dim, droplet]``;
    - ``c_l`` / ``c_r`` are ``courant[_l]`` / ``courant[_r]``, each divided by
      ``n_substeps``.

    The result is written in place to ``displacement[dim, droplet]``; nothing
    is returned. ``_l`` and ``_r`` are scalar indices (1-D) or index tuples
    (2-D/3-D) built by the calling kernel.
    """
    displacement[dim, droplet] = scheme(
        position_in_cell[dim, droplet],
        courant[_l] / n_substeps,
        courant[_r] / n_substeps,
    )
class DisplacementMethods(BackendMethods):
    """Numba (CPU) backend methods for super-droplet displacement.

    Provides:
    - 1-D, 2-D and 3-D kernels that evaluate the particle-advection
      displacement formula for every super-droplet;
    - helpers that flag super-droplets for removal when, along the last
      dimension, they fall below a precipitation-counting level or lie outside
      ``[0, domain_top_level_index]``.
    """

    # NOTE(review): the kernels below disable on-disk caching, presumably
    # because ``scheme`` is passed in as a first-class jitted function
    # argument -- TODO confirm.

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_1d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 1-D ``courant`` field.

        For each droplet, ``courant`` is read at ``cell_origin[0, droplet]``
        (left) and at the next index (right).
        """
        length = displacement.shape[1]
        # with parallel=False, numba.prange iterates serially like range
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid: left/right values at indices i and i + 1
            _l = cell_origin[0, droplet]
            _r = cell_origin[0, droplet] + 1
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_2d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 2-D ``courant`` field.

        The right-hand index equals the droplet's cell origin shifted by one
        only along ``dim``.
        """
        length = displacement.shape[1]
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
            )
            # 1 * (dim == k) adds one only on the axis being displaced
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    @staticmethod
    @numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
    # pylint: disable=too-many-arguments
    def calculate_displacement_body_3d(
        dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Fill ``displacement[dim, :]`` for a 3-D ``courant`` field.

        The right-hand index equals the droplet's cell origin shifted by one
        only along ``dim``.
        """
        length = displacement.shape[1]
        for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
            # Arakawa-C grid
            _l = (
                cell_origin[0, droplet],
                cell_origin[1, droplet],
                cell_origin[2, droplet],
            )
            # 1 * (dim == k) adds one only on the axis being displaced
            _r = (
                cell_origin[0, droplet] + 1 * (dim == 0),
                cell_origin[1, droplet] + 1 * (dim == 1),
                cell_origin[2, droplet] + 1 * (dim == 2),
            )
            calculate_displacement_body_common(
                dim,
                droplet,
                scheme,
                _l,
                _r,
                displacement,
                courant,
                position_in_cell,
                n_substeps,
            )

    def calculate_displacement(
        self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
    ):
        """Compute ``displacement[dim, :]`` for all super-droplets.

        Dispatches on the rank of ``courant`` to the matching 1-D, 2-D or 3-D
        kernel. The displacement formula is
        ``self.formulae.particle_advection.displacement``, and the raw ``.data``
        arrays of the storage arguments are passed through.

        Raises:
            NotImplementedError: if ``courant`` is not 1-, 2- or 3-dimensional.
        """
        n_dims = len(courant.shape)
        scheme = self.formulae.particle_advection.displacement
        kernels = {
            1: DisplacementMethods.calculate_displacement_body_1d,
            2: DisplacementMethods.calculate_displacement_body_2d,
            3: DisplacementMethods.calculate_displacement_body_3d,
        }
        if n_dims not in kernels:
            raise NotImplementedError(
                f"displacement not implemented for {n_dims}-dimensional courant field"
            )
        kernels[n_dims](
            dim,
            scheme,
            displacement.data,
            courant.data,
            cell_origin.data,
            position_in_cell.data,
            n_substeps,
        )

    @cached_property
    def _flag_precipitated_body(self):
        """Compile (once per backend instance) the kernel behind ``flag_precipitated``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        # pylint: disable=too-many-arguments
        def body(
            cell_origin,
            position_in_cell,
            water_mass,
            multiplicity,
            idx,
            length,
            healthy,
            precipitation_counting_level_index,
            displacement,
        ):
            rainfall_mass = 0.0
            # len(idx) is one past the last valid index and is written into idx
            # as a marker (presumably consumed by a later removal step -- TODO confirm)
            flag = len(idx)
            for i in range(length):
                # position along the last dimension, in cell units
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    # falling
                    displacement[-1, idx[i]] < 0
                    and
                    # and crossed precip-counting level
                    position_within_column < precipitation_counting_level_index
                ):
                    # abs(): water_mass may be signed (e.g. by phase) -- TODO confirm
                    rainfall_mass += abs(water_mass[idx[i]]) * multiplicity[idx[i]]
                    idx[i] = flag
                    healthy[0] = 0
            return rainfall_mass

        return body

    @cached_property
    def _flag_out_of_column_body(self):
        """Compile (once per backend instance) the kernel behind ``flag_out_of_column``."""

        @numba.njit(**{**self.default_jit_flags, "parallel": False})
        # pylint: disable=too-many-arguments
        def body(
            cell_origin, position_in_cell, idx, length, healthy, domain_top_level_index
        ):
            # out-of-range index used as a marker, as in _flag_precipitated_body
            flag = len(idx)
            for i in range(length):
                # position along the last dimension, in cell units
                position_within_column = (
                    cell_origin[-1, idx[i]] + position_in_cell[-1, idx[i]]
                )
                if (
                    position_within_column < 0
                    or position_within_column > domain_top_level_index
                ):
                    idx[i] = flag
                    healthy[0] = 0

        return body

    # pylint: disable=too-many-arguments
    def flag_precipitated(
        self,
        *,
        cell_origin,
        position_in_cell,
        water_mass,
        multiplicity,
        idx,
        length,
        healthy,
        precipitation_counting_level_index,
        displacement,
    ) -> float:
        """return a scalar value corresponding to the mass of water (all phases) that crossed
        the bottom boundary of the entire domain

        Only the first ``length`` entries of ``idx`` are examined. Each flagged
        entry of ``idx`` is overwritten with ``len(idx)``, and ``healthy[0]`` is
        set to 0 when anything is flagged.
        """
        return self._flag_precipitated_body(
            cell_origin.data,
            position_in_cell.data,
            water_mass.data,
            multiplicity.data,
            idx.data,
            length,
            healthy.data,
            precipitation_counting_level_index,
            displacement.data,
        )

    # pylint: disable=too-many-arguments
    def flag_out_of_column(
        self,
        cell_origin,
        position_in_cell,
        idx,
        length,
        healthy,
        domain_top_level_index,
    ):
        """Flag super-droplets whose last-dimension position lies outside the column.

        A droplet is flagged when its position along the last dimension is
        ``< 0`` or ``> domain_top_level_index``. Flagging overwrites the entry in
        ``idx`` with ``len(idx)`` and sets ``healthy[0]`` to 0. Only the first
        ``length`` entries of ``idx`` are examined; nothing is returned.
        """
        self._flag_out_of_column_body(
            cell_origin.data,
            position_in_cell.data,
            idx.data,
            length,
            healthy.data,
            domain_top_level_index,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_1d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill ``displacement[dim, :]`` for a 1-D ``courant`` field.

    For each droplet, ``courant`` is read at ``cell_origin[0, droplet]`` (left)
    and at the next index (right). Caching is disabled, presumably because
    ``scheme`` is a first-class jitted function argument -- TODO confirm.
    """
    length = displacement.shape[1]
    # with parallel=False, numba.prange iterates serially like range
    for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
        # Arakawa-C grid: left/right values at indices i and i + 1
        _l = cell_origin[0, droplet]
        _r = cell_origin[0, droplet] + 1
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_2d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill ``displacement[dim, :]`` for a 2-D ``courant`` field.

    The left index is the droplet's cell origin. The right index is the same
    origin shifted by one only along ``dim``.
    """
    length = displacement.shape[1]
    for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
        # Arakawa-C grid
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
        )
        # 1 * (dim == k) adds one only on the axis being displaced
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
@staticmethod
@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False, "cache": False}})
# pylint: disable=too-many-arguments
def calculate_displacement_body_3d(
    dim, scheme, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Fill ``displacement[dim, :]`` for a 3-D ``courant`` field.

    The left index is the droplet's cell origin. The right index is the same
    origin shifted by one only along ``dim``.
    """
    length = displacement.shape[1]
    for droplet in numba.prange(length):  # pylint: disable=not-an-iterable
        # Arakawa-C grid
        _l = (
            cell_origin[0, droplet],
            cell_origin[1, droplet],
            cell_origin[2, droplet],
        )
        # 1 * (dim == k) adds one only on the axis being displaced
        _r = (
            cell_origin[0, droplet] + 1 * (dim == 0),
            cell_origin[1, droplet] + 1 * (dim == 1),
            cell_origin[2, droplet] + 1 * (dim == 2),
        )
        calculate_displacement_body_common(
            dim,
            droplet,
            scheme,
            _l,
            _r,
            displacement,
            courant,
            position_in_cell,
            n_substeps,
        )
def calculate_displacement(
    self, *, dim, displacement, courant, cell_origin, position_in_cell, n_substeps
):
    """Compute ``displacement[dim, :]`` for all super-droplets.

    Dispatches on the rank of ``courant`` to the matching 1-D, 2-D or 3-D
    kernel of ``DisplacementMethods``. The displacement formula is
    ``self.formulae.particle_advection.displacement``, and the raw ``.data``
    arrays of the storage arguments are passed through.

    Raises:
        NotImplementedError: if ``courant`` is not 1-, 2- or 3-dimensional.
    """
    n_dims = len(courant.shape)
    scheme = self.formulae.particle_advection.displacement
    kernels = {
        1: DisplacementMethods.calculate_displacement_body_1d,
        2: DisplacementMethods.calculate_displacement_body_2d,
        3: DisplacementMethods.calculate_displacement_body_3d,
    }
    if n_dims not in kernels:
        raise NotImplementedError(
            f"displacement not implemented for {n_dims}-dimensional courant field"
        )
    kernels[n_dims](
        dim,
        scheme,
        displacement.data,
        courant.data,
        cell_origin.data,
        position_in_cell.data,
        n_substeps,
    )
# pylint: disable=too-many-arguments
def flag_precipitated(
    self,
    *,
    cell_origin,
    position_in_cell,
    water_mass,
    multiplicity,
    idx,
    length,
    healthy,
    precipitation_counting_level_index,
    displacement,
) -> float:
    """return a scalar value corresponding to the mass of water (all phases) that crossed
    the bottom boundary of the entire domain

    Delegates to the jitted ``self._flag_precipitated_body`` with the raw
    ``.data`` arrays. Only the first ``length`` entries of ``idx`` are examined.
    """
    return self._flag_precipitated_body(
        cell_origin.data,
        position_in_cell.data,
        water_mass.data,
        multiplicity.data,
        idx.data,
        length,
        healthy.data,
        precipitation_counting_level_index,
        displacement.data,
    )
# pylint: disable=too-many-arguments
def flag_out_of_column(
    self,
    cell_origin,
    position_in_cell,
    idx,
    length,
    healthy,
    domain_top_level_index,
):
    """Flag super-droplets whose last-dimension position lies outside the column.

    Delegates to the jitted ``self._flag_out_of_column_body`` with the raw
    ``.data`` arrays. Only the first ``length`` entries of ``idx`` are examined;
    nothing is returned.
    """
    self._flag_out_of_column_body(
        cell_origin.data,
        position_in_cell.data,
        idx.data,
        length,
        healthy.data,
        domain_top_level_index,
    )