PyMPDATA.stepper
MPDATA iteration logic
"""MPDATA iteration logic"""

import sys
import warnings
from functools import lru_cache

import numba
import numpy as np
from numba.core.errors import NumbaExperimentalFeatureWarning

from .impl.clock import clock
from .impl.enumerations import IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
from .impl.formulae_antidiff import make_antidiff
from .impl.formulae_axpy import make_axpy
from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent
from .impl.formulae_laplacian import make_laplacian
from .impl.formulae_nonoscillatory import make_beta, make_correction, make_psi_extrema
from .impl.formulae_upwind import make_upwind
from .impl.meta import _Impl
from .impl.traversals import Traversals
from .options import Options


class Stepper:
    """MPDATA stepper specialised for given options, dimensionality and optionally grid
    (instances of Stepper can be shared among `Solver`s)"""

    def __init__(
        self,
        *,
        options: Options,
        n_dims: (int, None) = None,
        non_unit_g_factor: bool = False,
        grid: (tuple, None) = None,
        n_threads: (int, None) = None,
        left_first: (tuple, None) = None,
        buffer_size: int = 0
    ):
        """Validates the arguments, settles the thread count and obtains the
        (cached) stepping function together with its traversals.

        Args:
            options: `Options` instance; its `DPDC` flag is checked here, the
                instance itself is passed on to `make_step_impl`.
            n_dims: number of dimensions; exactly one of `n_dims` and `grid`
                must be given.
            non_unit_g_factor: passed on to `make_step_impl`.
            grid: grid as a tuple; when omitted, a tuple of `n_dims` values
                of -1 is used instead.
            n_threads: requested thread count, `numba.get_num_threads()` when
                omitted; reduced to 1 if `n_dims == 1` or if Numba reports no
                parallel support.
            left_first: tuple passed on to `make_step_impl`; defaults to
                `MAX_DIM_NUM` values of True.
            buffer_size: passed on to `make_step_impl`.

        Raises:
            ValueError: if both or neither of `n_dims` and `grid` are given,
                or if more than one thread is used but a parallel Numba loop
                reports no thread id above zero.
            NotImplementedError: if `options.DPDC` is set and `n_dims > 1`.
        """
        if n_dims is not None and grid is not None:
            raise ValueError("n_dims and grid are mutually exclusive (pass only one)")
        if n_dims is None and grid is None:
            raise ValueError("either n_dims or grid needs to be provided")
        if grid is None:
            grid = (-1,) * n_dims
        if n_dims is None:
            n_dims = len(grid)
        if n_dims > 1 and options.DPDC:
            raise NotImplementedError(
                "the DPDC option is not implemented for n_dims > 1"
            )
        if n_threads is None:
            n_threads = numba.get_num_threads()
        if left_first is None:
            left_first = (True,) * MAX_DIM_NUM

        self.__options = options
        self.__n_threads = 1 if n_dims == 1 else n_threads

        if self.__n_threads > 1:
            try:
                numba.parfors.parfor.ensure_parallel_support()
            except numba.core.errors.UnsupportedParforsError:
                print(
                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
                    file=sys.stderr,
                )
                self.__n_threads = 1
            else:
                # the probe below compiles a parallel=True function, so it is
                # run only once parallel support has been confirmed; running
                # it after the single-thread fallback would defeat the fallback
                if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                    @numba.jit(parallel=True, nopython=True)
                    def fill_array_with_thread_id(arr):
                        """writes thread id to corresponding array element"""
                        for i in numba.prange(  # pylint: disable=not-an-iterable
                            numba.get_num_threads()
                        ):
                            arr[i] = numba.get_thread_id()

                    arr = np.full(numba.get_num_threads(), -1)
                    fill_array_with_thread_id(arr)
                    # every recorded id being 0 (or the initial -1) means that
                    # no second thread took part in the loop
                    if not max(arr) > 0:
                        raise ValueError(
                            "n_threads>1 requested, but Numba does not seem to parallelize"
                            " (try changing Numba threading backend?)"
                        )

        self.__n_dims = n_dims
        self.__call, self.traversals = make_step_impl(
            options,
            non_unit_g_factor,
            grid,
            self.n_threads,
            left_first=left_first,
            buffer_size=buffer_size,
        )

    @property
    def options(self) -> Options:
        """`Options` instance used"""
        return self.__options

    @property
    def n_threads(self) -> int:
        """actual n_threads used (may be different than passed to __init__ if n_dims==1
        or if on a platform where Numba does not support threading)"""
        return self.__n_threads

    @property
    def n_dims(self) -> int:
        """dimensionality (1, 2 or 3)"""
        return self.__n_dims

    def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields):
        """Runs the stepping function obtained from `make_step_impl`.

        Each value of `fields` is wrapped into an `_Impl` built from its
        `impl[IMPL_META_AND_DATA]` and `impl[IMPL_BC]` entries; the value under
        the "advectee" key is iterated and wrapped element by element into a
        tuple. The wrapped values are passed positionally in the iteration
        order of `fields`, so that order has to match the parameter order of
        the stepping function. `NumbaExperimentalFeatureWarning` is silenced
        for the duration of the call.

        Returns:
            the value returned by the stepping function (wall time per step).
        """
        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
            wall_time_per_timestep = self.__call(
                n_steps,
                mu_coeff,
                ante_step,
                post_step,
                post_iter,
                *(
                    (
                        _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
                        if k != "advectee"
                        else tuple(
                            _Impl(
                                field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC]
                            )
                            for vv in v
                        )
                    )
                    for k, v in fields.items()
                ),
                self.traversals.data,
            )
        return wall_time_per_timestep


@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    All arguments take part in the `lru_cache` key and therefore need to be
    hashable. `grid`, `n_threads`, `left_first` and `buffer_size` (together
    with `options.n_halo` and `options.jit_flags`) are used to construct the
    `Traversals` instance; `options` and `non_unit_g_factor` parametrise the
    formulae factories."""
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # option values captured by the closure below: inside `step` they are
    # fixed for the lifetime of the returned function
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        ante_step,
        post_step,
        post_iter,
        advectees,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        dynamic_advector_stash_outer,
        dynamic_advector_stash_mid3d,
        dynamic_advector_stash_inner,
        psi_extrema,
        beta,
        traversals_data,
    ):
        # start of the wall-time measurement reported by the return statement
        time = clock()
        for step in range(n_steps):
            for index, advectee in enumerate(advectees):
                ante_step.call(
                    traversals_data,
                    advectees,
                    advector,
                    step,
                    index,
                    dynamic_advector_stash_outer,
                    dynamic_advector_stash_mid3d,
                    dynamic_advector_stash_inner,
                )
                if non_zero_mu_coeff:
                    # work on vectmp_c for this advectee; the caller-supplied
                    # advector is restored after the iteration loop
                    advector_orig = advector
                    advector = vectmp_c
                for iteration in range(n_iters):
                    if iteration == 0:
                        if nonoscillatory:
                            nonoscillatory_psi_extrema(
                                traversals_data, psi_extrema, advectee
                            )
                        if non_zero_mu_coeff:
                            # NOTE(review): presumably stores the Laplacian term
                            # in `advector` and then combines it with the original
                            # advector scaled by mu_coeff - confirm against
                            # make_laplacian / make_axpy
                            laplacian(traversals_data, advector, advectee)
                            axpy(
                                *advector.field,
                                mu_coeff,
                                *advector.field,
                                *advector_orig.field,
                            )
                        flux_first_pass(traversals_data, vectmp_a, advector, advectee)
                        flux = vectmp_a
                    else:
                        # vectmp_a and vectmp_b swap roles between iterations
                        if iteration == 1:
                            advector_oscil = advector
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        elif iteration % 2 == 0:
                            advector_oscil = vectmp_a
                            advector_nonos = vectmp_b
                            flux = vectmp_a
                        else:
                            advector_oscil = vectmp_b
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        # the final iteration uses the `last_pass=True` variant
                        if iteration < n_iters - 1:
                            antidiff(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        else:
                            antidiff_last_pass(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        flux_subsequent(traversals_data, flux, advectee, advector_nonos)
                        if nonoscillatory:
                            nonoscillatory_beta(
                                traversals_data,
                                beta,
                                flux,
                                advectee,
                                psi_extrema,
                                g_factor,
                            )
                            # note: in libmpdata++, the oscillatory advector from prev iter is used
                            nonoscillatory_correction(
                                traversals_data, advector_nonos, beta
                            )
                            # fluxes are evaluated again after the correction call
                            flux_subsequent(
                                traversals_data, flux, advectee, advector_nonos
                            )
                    upwind(traversals_data, advectee, flux, g_factor)
                    post_iter.call(
                        traversals_data, flux.field, g_factor.field, step, iteration
                    )
                if non_zero_mu_coeff:
                    advector = advector_orig
                post_step.call(traversals_data, advectees, step, index)
        # elapsed clock() difference per step; NaN when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals
class
Stepper:
class Stepper:
    """MPDATA stepper specialised for given options, dimensionality and optionally grid
    (instances of Stepper can be shared among `Solver`s)"""

    def __init__(
        self,
        *,
        options: Options,
        n_dims: (int, None) = None,
        non_unit_g_factor: bool = False,
        grid: (tuple, None) = None,
        n_threads: (int, None) = None,
        left_first: (tuple, None) = None,
        buffer_size: int = 0
    ):
        """Validates the arguments, settles the thread count and obtains the
        (cached) stepping function together with its traversals.

        Args:
            options: `Options` instance; its `DPDC` flag is checked here, the
                instance itself is passed on to `make_step_impl`.
            n_dims: number of dimensions; exactly one of `n_dims` and `grid`
                must be given.
            non_unit_g_factor: passed on to `make_step_impl`.
            grid: grid as a tuple; when omitted, a tuple of `n_dims` values
                of -1 is used instead.
            n_threads: requested thread count, `numba.get_num_threads()` when
                omitted; reduced to 1 if `n_dims == 1` or if Numba reports no
                parallel support.
            left_first: tuple passed on to `make_step_impl`; defaults to
                `MAX_DIM_NUM` values of True.
            buffer_size: passed on to `make_step_impl`.

        Raises:
            ValueError: if both or neither of `n_dims` and `grid` are given,
                or if more than one thread is used but a parallel Numba loop
                reports no thread id above zero.
            NotImplementedError: if `options.DPDC` is set and `n_dims > 1`.
        """
        if n_dims is not None and grid is not None:
            raise ValueError("n_dims and grid are mutually exclusive (pass only one)")
        if n_dims is None and grid is None:
            raise ValueError("either n_dims or grid needs to be provided")
        if grid is None:
            grid = (-1,) * n_dims
        if n_dims is None:
            n_dims = len(grid)
        if n_dims > 1 and options.DPDC:
            raise NotImplementedError(
                "the DPDC option is not implemented for n_dims > 1"
            )
        if n_threads is None:
            n_threads = numba.get_num_threads()
        if left_first is None:
            left_first = (True,) * MAX_DIM_NUM

        self.__options = options
        self.__n_threads = 1 if n_dims == 1 else n_threads

        if self.__n_threads > 1:
            try:
                numba.parfors.parfor.ensure_parallel_support()
            except numba.core.errors.UnsupportedParforsError:
                print(
                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
                    file=sys.stderr,
                )
                self.__n_threads = 1
            else:
                # the probe below compiles a parallel=True function, so it is
                # run only once parallel support has been confirmed; running
                # it after the single-thread fallback would defeat the fallback
                if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                    @numba.jit(parallel=True, nopython=True)
                    def fill_array_with_thread_id(arr):
                        """writes thread id to corresponding array element"""
                        for i in numba.prange(  # pylint: disable=not-an-iterable
                            numba.get_num_threads()
                        ):
                            arr[i] = numba.get_thread_id()

                    arr = np.full(numba.get_num_threads(), -1)
                    fill_array_with_thread_id(arr)
                    # every recorded id being 0 (or the initial -1) means that
                    # no second thread took part in the loop
                    if not max(arr) > 0:
                        raise ValueError(
                            "n_threads>1 requested, but Numba does not seem to parallelize"
                            " (try changing Numba threading backend?)"
                        )

        self.__n_dims = n_dims
        self.__call, self.traversals = make_step_impl(
            options,
            non_unit_g_factor,
            grid,
            self.n_threads,
            left_first=left_first,
            buffer_size=buffer_size,
        )

    @property
    def options(self) -> Options:
        """`Options` instance used"""
        return self.__options

    @property
    def n_threads(self) -> int:
        """actual n_threads used (may be different than passed to __init__ if n_dims==1
        or if on a platform where Numba does not support threading)"""
        return self.__n_threads

    @property
    def n_dims(self) -> int:
        """dimensionality (1, 2 or 3)"""
        return self.__n_dims

    def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields):
        """Runs the stepping function obtained from `make_step_impl`.

        Each value of `fields` is wrapped into an `_Impl` built from its
        `impl[IMPL_META_AND_DATA]` and `impl[IMPL_BC]` entries; the value under
        the "advectee" key is iterated and wrapped element by element into a
        tuple. The wrapped values are passed positionally in the iteration
        order of `fields`, so that order has to match the parameter order of
        the stepping function. `NumbaExperimentalFeatureWarning` is silenced
        for the duration of the call.

        Returns:
            the value returned by the stepping function (wall time per step).
        """
        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
            wall_time_per_timestep = self.__call(
                n_steps,
                mu_coeff,
                ante_step,
                post_step,
                post_iter,
                *(
                    (
                        _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
                        if k != "advectee"
                        else tuple(
                            _Impl(
                                field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC]
                            )
                            for vv in v
                        )
                    )
                    for k, v in fields.items()
                ),
                self.traversals.data,
            )
        return wall_time_per_timestep
MPDATA stepper specialised for given options, dimensionality and optionally grid
(instances of Stepper can be shared among Solvers)
Stepper( *, options: PyMPDATA.options.Options, n_dims: (<class 'int'>, None) = None, non_unit_g_factor: bool = False, grid: (<class 'tuple'>, None) = None, n_threads: (<class 'int'>, None) = None, left_first: (<class 'tuple'>, None) = None, buffer_size: int = 0)
def __init__(
    self,
    *,
    options: Options,
    n_dims: (int, None) = None,
    non_unit_g_factor: bool = False,
    grid: (tuple, None) = None,
    n_threads: (int, None) = None,
    left_first: (tuple, None) = None,
    buffer_size: int = 0
):
    """Validates the arguments, settles the thread count and obtains the
    (cached) stepping function together with its traversals.

    Args:
        options: `Options` instance; its `DPDC` flag is checked here, the
            instance itself is passed on to `make_step_impl`.
        n_dims: number of dimensions; exactly one of `n_dims` and `grid`
            must be given.
        non_unit_g_factor: passed on to `make_step_impl`.
        grid: grid as a tuple; when omitted, a tuple of `n_dims` values
            of -1 is used instead.
        n_threads: requested thread count, `numba.get_num_threads()` when
            omitted; reduced to 1 if `n_dims == 1` or if Numba reports no
            parallel support.
        left_first: tuple passed on to `make_step_impl`; defaults to
            `MAX_DIM_NUM` values of True.
        buffer_size: passed on to `make_step_impl`.

    Raises:
        ValueError: if both or neither of `n_dims` and `grid` are given,
            or if more than one thread is used but a parallel Numba loop
            reports no thread id above zero.
        NotImplementedError: if `options.DPDC` is set and `n_dims > 1`.
    """
    if n_dims is not None and grid is not None:
        raise ValueError("n_dims and grid are mutually exclusive (pass only one)")
    if n_dims is None and grid is None:
        raise ValueError("either n_dims or grid needs to be provided")
    if grid is None:
        grid = (-1,) * n_dims
    if n_dims is None:
        n_dims = len(grid)
    if n_dims > 1 and options.DPDC:
        raise NotImplementedError(
            "the DPDC option is not implemented for n_dims > 1"
        )
    if n_threads is None:
        n_threads = numba.get_num_threads()
    if left_first is None:
        left_first = (True,) * MAX_DIM_NUM

    self.__options = options
    self.__n_threads = 1 if n_dims == 1 else n_threads

    if self.__n_threads > 1:
        try:
            numba.parfors.parfor.ensure_parallel_support()
        except numba.core.errors.UnsupportedParforsError:
            print(
                "Numba ensure_parallel_support() failed, forcing n_threads=1",
                file=sys.stderr,
            )
            self.__n_threads = 1
        else:
            # the probe below compiles a parallel=True function, so it is
            # run only once parallel support has been confirmed; running
            # it after the single-thread fallback would defeat the fallback
            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                @numba.jit(parallel=True, nopython=True)
                def fill_array_with_thread_id(arr):
                    """writes thread id to corresponding array element"""
                    for i in numba.prange(  # pylint: disable=not-an-iterable
                        numba.get_num_threads()
                    ):
                        arr[i] = numba.get_thread_id()

                arr = np.full(numba.get_num_threads(), -1)
                fill_array_with_thread_id(arr)
                # every recorded id being 0 (or the initial -1) means that
                # no second thread took part in the loop
                if not max(arr) > 0:
                    raise ValueError(
                        "n_threads>1 requested, but Numba does not seem to parallelize"
                        " (try changing Numba threading backend?)"
                    )

    self.__n_dims = n_dims
    self.__call, self.traversals = make_step_impl(
        options,
        non_unit_g_factor,
        grid,
        self.n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )
options: PyMPDATA.options.Options
@property
def options(self) -> Options:
    """`Options` instance used (read-only view of the object stored by
    `__init__`)"""
    return self.__options
Options instance used
n_threads: int
@property
def n_threads(self) -> int:
    """actual n_threads used (may be different than passed to __init__ if n_dims==1
    or if on a platform where Numba does not support threading; in both of
    these cases the value is 1)"""
    return self.__n_threads
actual n_threads used (may be different than passed to __init__ if n_dims==1 or if on a platform where Numba does not support threading)
@lru_cache()
def
make_step_impl( options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size):
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    All arguments take part in the `lru_cache` key and therefore need to be
    hashable. `grid`, `n_threads`, `left_first` and `buffer_size` (together
    with `options.n_halo` and `options.jit_flags`) are used to construct the
    `Traversals` instance; `options` and `non_unit_g_factor` parametrise the
    formulae factories."""
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # option values captured by the closure below: inside `step` they are
    # fixed for the lifetime of the returned function
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        ante_step,
        post_step,
        post_iter,
        advectees,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        dynamic_advector_stash_outer,
        dynamic_advector_stash_mid3d,
        dynamic_advector_stash_inner,
        psi_extrema,
        beta,
        traversals_data,
    ):
        # start of the wall-time measurement reported by the return statement
        time = clock()
        for step in range(n_steps):
            for index, advectee in enumerate(advectees):
                ante_step.call(
                    traversals_data,
                    advectees,
                    advector,
                    step,
                    index,
                    dynamic_advector_stash_outer,
                    dynamic_advector_stash_mid3d,
                    dynamic_advector_stash_inner,
                )
                if non_zero_mu_coeff:
                    # work on vectmp_c for this advectee; the caller-supplied
                    # advector is restored after the iteration loop
                    advector_orig = advector
                    advector = vectmp_c
                for iteration in range(n_iters):
                    if iteration == 0:
                        if nonoscillatory:
                            nonoscillatory_psi_extrema(
                                traversals_data, psi_extrema, advectee
                            )
                        if non_zero_mu_coeff:
                            # NOTE(review): presumably stores the Laplacian term
                            # in `advector` and then combines it with the original
                            # advector scaled by mu_coeff - confirm against
                            # make_laplacian / make_axpy
                            laplacian(traversals_data, advector, advectee)
                            axpy(
                                *advector.field,
                                mu_coeff,
                                *advector.field,
                                *advector_orig.field,
                            )
                        flux_first_pass(traversals_data, vectmp_a, advector, advectee)
                        flux = vectmp_a
                    else:
                        # vectmp_a and vectmp_b swap roles between iterations
                        if iteration == 1:
                            advector_oscil = advector
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        elif iteration % 2 == 0:
                            advector_oscil = vectmp_a
                            advector_nonos = vectmp_b
                            flux = vectmp_a
                        else:
                            advector_oscil = vectmp_b
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        # the final iteration uses the `last_pass=True` variant
                        if iteration < n_iters - 1:
                            antidiff(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        else:
                            antidiff_last_pass(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        flux_subsequent(traversals_data, flux, advectee, advector_nonos)
                        if nonoscillatory:
                            nonoscillatory_beta(
                                traversals_data,
                                beta,
                                flux,
                                advectee,
                                psi_extrema,
                                g_factor,
                            )
                            # note: in libmpdata++, the oscillatory advector from prev iter is used
                            nonoscillatory_correction(
                                traversals_data, advector_nonos, beta
                            )
                            # fluxes are evaluated again after the correction call
                            flux_subsequent(
                                traversals_data, flux, advectee, advector_nonos
                            )
                    upwind(traversals_data, advectee, flux, g_factor)
                    post_iter.call(
                        traversals_data, flux.field, g_factor.field, step, iteration
                    )
                if non_zero_mu_coeff:
                    advector = advector_orig
                post_step.call(traversals_data, advectees, step, index)
        # elapsed clock() difference per step; NaN when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals
returns (and caches) an njit-ted stepping function and a traversals pair