PyMPDATA.stepper
MPDATA iteration logic
"""MPDATA iteration logic"""

import sys
import warnings
from functools import lru_cache

import numba
import numpy as np
from numba.core.errors import NumbaExperimentalFeatureWarning

from .impl.clock import clock
from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
from .impl.formulae_antidiff import make_antidiff
from .impl.formulae_axpy import make_axpy
from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent
from .impl.formulae_laplacian import make_laplacian
from .impl.formulae_nonoscillatory import make_beta, make_correction, make_psi_extrema
from .impl.formulae_upwind import make_upwind
from .impl.meta import _Impl
from .impl.traversals import Traversals
from .options import Options


class Stepper:
    """MPDATA stepper specialised for given options, dimensionality and optionally grid
    (instances of Stepper can be shared among `Solver`s)"""

    def __init__(
        self,
        *,
        options: Options,
        n_dims: (int, None) = None,
        non_unit_g_factor: bool = False,
        grid: (tuple, None) = None,
        n_threads: (int, None) = None,
        left_first: (tuple, None) = None,
        buffer_size: int = 0
    ):
        """
        Exactly one of `n_dims` and `grid` must be given: with `n_dims` the stepper is
        not specialised to a grid size (grid is set to `(-1,) * n_dims`), with `grid`
        it is, and `n_dims` is inferred as `len(grid)`.

        `n_threads` defaults to `numba.get_num_threads()`; it is forced to 1 for
        one-dimensional setups and when Numba reports no parallel support.
        `left_first` defaults to `True` for each of `MAX_DIM_NUM` dimensions.
        `non_unit_g_factor`, `left_first` and `buffer_size` are forwarded to
        `make_step_impl()`.

        Raises `ValueError` if both or neither of `n_dims`/`grid` are given, or if
        more than one thread is requested but Numba does not run `prange` loops on
        several threads; raises `NotImplementedError` for `options.DPDC` with
        `n_dims > 1`.
        """
        if n_dims is not None and grid is not None:
            raise ValueError("n_dims and grid are mutually exclusive, pass only one")
        if n_dims is None and grid is None:
            raise ValueError("either n_dims or grid must be passed")
        if grid is None:
            grid = tuple([-1] * n_dims)
        if n_dims is None:
            n_dims = len(grid)
        if n_dims > 1 and options.DPDC:
            raise NotImplementedError("DPDC is supported only for n_dims == 1")
        if n_threads is None:
            n_threads = numba.get_num_threads()
        if left_first is None:
            left_first = tuple([True] * MAX_DIM_NUM)

        self.__options = options
        # one-dimensional setups always run single-threaded
        self.__n_threads = 1 if n_dims == 1 else n_threads

        if self.__n_threads > 1:
            try:
                numba.parfors.parfor.ensure_parallel_support()
            except numba.core.errors.UnsupportedParforsError:
                print(
                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
                    file=sys.stderr,
                )
                self.__n_threads = 1
            else:
                # the self-test below only makes sense if we did not just fall back
                # to a single thread: compiling a parallel=True function there is
                # pointless and its "n_threads>1 requested" error would be bogus
                if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                    @numba.jit(parallel=True, nopython=True)
                    def fill_array_with_thread_id(arr):
                        """writes thread id to corresponding array element"""
                        for i in numba.prange(  # pylint: disable=not-an-iterable
                            numba.get_num_threads()
                        ):
                            arr[i] = numba.get_thread_id()

                    arr = np.full(numba.get_num_threads(), -1)
                    fill_array_with_thread_id(arr)
                    # no thread id above 0 was written -> prange did not spread
                    # the iterations over several threads
                    if not max(arr) > 0:
                        raise ValueError(
                            "n_threads>1 requested, but Numba does not seem to parallelize"
                            " (try changing Numba threading backend?)"
                        )

        self.__n_dims = n_dims
        self.__call, self.traversals = make_step_impl(
            options,
            non_unit_g_factor,
            grid,
            self.n_threads,
            left_first=left_first,
            buffer_size=buffer_size,
        )

    @property
    def options(self) -> Options:
        """`Options` instance used"""
        return self.__options

    @property
    def n_threads(self) -> int:
        """actual n_threads used (may be different than passed to __init__ if n_dims==1
        or if on a platform where Numba does not support threading)"""
        return self.__n_threads

    @property
    def n_dims(self) -> int:
        """dimensionality (1, 2 or 3)"""
        return self.__n_dims

    def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
        """performs `n_steps` timesteps and returns the mean wall time per timestep
        as measured by `clock()` (NaN when `n_steps` is 0)

        `fields` is a mapping whose values (objects exposing `.impl`) are passed
        positionally, in iteration order, to the compiled step function built by
        `make_step_impl()`; that order has to be: advectee, advector, g_factor,
        vectmp_a, vectmp_b, vectmp_c, psi_extrema, beta. `post_step.call()` and
        `post_iter.call()` are invoked after every step and every iteration.
        """
        # the traversals were built for self.n_threads threads, so (unless running
        # single-threaded) Numba must currently be configured with that same count
        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads, (
            "numba.get_num_threads() does not match Stepper.n_threads"
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
            wall_time_per_timestep = self.__call(
                n_steps,
                mu_coeff,
                post_step,
                post_iter,
                *(
                    _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
                    for v in fields.values()
                ),
                self.traversals.data,
            )
        return wall_time_per_timestep


@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) a pair: an njit-ted stepping function and the
    `Traversals` instance it was built with

    Results are memoised with `functools.lru_cache`, keyed on all arguments, which
    therefore must be hashable (including `options`).
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # plain Python values captured by the njit-ted closure below (compile-time constants)
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        post_step,
        post_iter,
        advectee,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        psi_extrema,
        beta,
        null_impl,
    ):
        """performs `n_steps` timesteps of `n_iters` iterations each and returns
        the mean time per step measured with `clock()` (NaN if `n_steps` is 0)"""
        time = clock()
        for step in range(n_steps):
            if non_zero_mu_coeff:
                # for this step, work on a modified advector held in vectmp_c;
                # the caller's advector is restored after the iterations
                advector_orig = advector
                advector = vectmp_c
            for iteration in range(n_iters):
                if iteration == 0:
                    if nonoscillatory:
                        # extrema of advectee before the first pass, later used
                        # by nonoscillatory_beta()
                        nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee)
                    if non_zero_mu_coeff:
                        laplacian(null_impl, advector, advectee)
                        # presumably advector := mu_coeff * advector + advector_orig
                        # (axpy argument order: out, a, x, y -- TODO confirm)
                        axpy(
                            *advector.field,
                            mu_coeff,
                            *advector.field,
                            *advector_orig.field,
                        )
                    flux_first_pass(null_impl, vectmp_a, advector, advectee)
                    flux = vectmp_a
                else:
                    # vectmp_a/vectmp_b rotate roles: the advector of the previous
                    # iteration (advector_oscil) is read, a new antidiffusive
                    # advector (advector_nonos) is written; from iteration 2 on the
                    # flux overwrites the buffer that held the previous advector
                    if iteration == 1:
                        advector_oscil = advector
                        advector_nonos = vectmp_a
                        flux = vectmp_b
                    elif iteration % 2 == 0:
                        advector_oscil = vectmp_a
                        advector_nonos = vectmp_b
                        flux = vectmp_a
                    else:
                        advector_oscil = vectmp_b
                        advector_nonos = vectmp_a
                        flux = vectmp_b
                    if iteration < n_iters - 1:
                        antidiff(
                            null_impl,
                            advector_nonos,
                            advectee,
                            advector_oscil,
                            g_factor,
                        )
                    else:
                        # the final iteration uses the last_pass variant
                        antidiff_last_pass(
                            null_impl,
                            advector_nonos,
                            advectee,
                            advector_oscil,
                            g_factor,
                        )
                    flux_subsequent(null_impl, flux, advectee, advector_nonos)
                    if nonoscillatory:
                        # beta from the provisional flux, correct the advector,
                        # then recompute the flux from the corrected advector
                        nonoscillatory_beta(
                            null_impl, beta, flux, advectee, psi_extrema, g_factor
                        )
                        # note: in libmpdata++, the oscillatory advector from prev iter is used
                        nonoscillatory_correction(null_impl, advector_nonos, beta)
                        flux_subsequent(null_impl, flux, advectee, advector_nonos)
                upwind(null_impl, advectee, flux, g_factor)
                post_iter.call(flux.field, g_factor.field, step, iteration)
            if non_zero_mu_coeff:
                advector = advector_orig
            post_step.call(advectee.field[ARG_DATA], step)
        # guard against division by zero when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals
class Stepper:
class Stepper:
    """MPDATA stepper specialised for given options, dimensionality and optionally grid
    (instances of Stepper can be shared among `Solver`s)"""

    def __init__(
        self,
        *,
        options: Options,
        n_dims: (int, None) = None,
        non_unit_g_factor: bool = False,
        grid: (tuple, None) = None,
        n_threads: (int, None) = None,
        left_first: (tuple, None) = None,
        buffer_size: int = 0
    ):
        """
        Exactly one of `n_dims` and `grid` must be given: with `n_dims` the stepper is
        not specialised to a grid size (grid is set to `(-1,) * n_dims`), with `grid`
        it is, and `n_dims` is inferred as `len(grid)`.

        `n_threads` defaults to `numba.get_num_threads()`; it is forced to 1 for
        one-dimensional setups and when Numba reports no parallel support.
        `left_first` defaults to `True` for each of `MAX_DIM_NUM` dimensions.
        `non_unit_g_factor`, `left_first` and `buffer_size` are forwarded to
        `make_step_impl()`.

        Raises `ValueError` if both or neither of `n_dims`/`grid` are given, or if
        more than one thread is requested but Numba does not run `prange` loops on
        several threads; raises `NotImplementedError` for `options.DPDC` with
        `n_dims > 1`.
        """
        if n_dims is not None and grid is not None:
            raise ValueError("n_dims and grid are mutually exclusive, pass only one")
        if n_dims is None and grid is None:
            raise ValueError("either n_dims or grid must be passed")
        if grid is None:
            grid = tuple([-1] * n_dims)
        if n_dims is None:
            n_dims = len(grid)
        if n_dims > 1 and options.DPDC:
            raise NotImplementedError("DPDC is supported only for n_dims == 1")
        if n_threads is None:
            n_threads = numba.get_num_threads()
        if left_first is None:
            left_first = tuple([True] * MAX_DIM_NUM)

        self.__options = options
        # one-dimensional setups always run single-threaded
        self.__n_threads = 1 if n_dims == 1 else n_threads

        if self.__n_threads > 1:
            try:
                numba.parfors.parfor.ensure_parallel_support()
            except numba.core.errors.UnsupportedParforsError:
                print(
                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
                    file=sys.stderr,
                )
                self.__n_threads = 1
            else:
                # the self-test below only makes sense if we did not just fall back
                # to a single thread: compiling a parallel=True function there is
                # pointless and its "n_threads>1 requested" error would be bogus
                if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                    @numba.jit(parallel=True, nopython=True)
                    def fill_array_with_thread_id(arr):
                        """writes thread id to corresponding array element"""
                        for i in numba.prange(  # pylint: disable=not-an-iterable
                            numba.get_num_threads()
                        ):
                            arr[i] = numba.get_thread_id()

                    arr = np.full(numba.get_num_threads(), -1)
                    fill_array_with_thread_id(arr)
                    # no thread id above 0 was written -> prange did not spread
                    # the iterations over several threads
                    if not max(arr) > 0:
                        raise ValueError(
                            "n_threads>1 requested, but Numba does not seem to parallelize"
                            " (try changing Numba threading backend?)"
                        )

        self.__n_dims = n_dims
        self.__call, self.traversals = make_step_impl(
            options,
            non_unit_g_factor,
            grid,
            self.n_threads,
            left_first=left_first,
            buffer_size=buffer_size,
        )

    @property
    def options(self) -> Options:
        """`Options` instance used"""
        return self.__options

    @property
    def n_threads(self) -> int:
        """actual n_threads used (may be different than passed to __init__ if n_dims==1
        or if on a platform where Numba does not support threading)"""
        return self.__n_threads

    @property
    def n_dims(self) -> int:
        """dimensionality (1, 2 or 3)"""
        return self.__n_dims

    def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
        """performs `n_steps` timesteps and returns the mean wall time per timestep
        as measured by `clock()` (NaN when `n_steps` is 0)

        `fields` is a mapping whose values (objects exposing `.impl`) are passed
        positionally, in iteration order, to the compiled step function built by
        `make_step_impl()`; that order has to be: advectee, advector, g_factor,
        vectmp_a, vectmp_b, vectmp_c, psi_extrema, beta. `post_step.call()` and
        `post_iter.call()` are invoked after every step and every iteration.
        """
        # the traversals were built for self.n_threads threads, so (unless running
        # single-threaded) Numba must currently be configured with that same count
        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads, (
            "numba.get_num_threads() does not match Stepper.n_threads"
        )
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
            wall_time_per_timestep = self.__call(
                n_steps,
                mu_coeff,
                post_step,
                post_iter,
                *(
                    _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
                    for v in fields.values()
                ),
                self.traversals.data,
            )
        return wall_time_per_timestep
MPDATA stepper specialised for given options, dimensionality and optionally grid
(instances of Stepper can be shared among `Solver`s)
Stepper( *, options: PyMPDATA.options.Options, n_dims: (<class 'int'>, None) = None, non_unit_g_factor: bool = False, grid: (<class 'tuple'>, None) = None, n_threads: (<class 'int'>, None) = None, left_first: (<class 'tuple'>, None) = None, buffer_size: int = 0)
def __init__(
    self,
    *,
    options: Options,
    n_dims: (int, None) = None,
    non_unit_g_factor: bool = False,
    grid: (tuple, None) = None,
    n_threads: (int, None) = None,
    left_first: (tuple, None) = None,
    buffer_size: int = 0
):
    """
    Exactly one of `n_dims` and `grid` must be given: with `n_dims` the stepper is
    not specialised to a grid size (grid is set to `(-1,) * n_dims`), with `grid`
    it is, and `n_dims` is inferred as `len(grid)`.

    `n_threads` defaults to `numba.get_num_threads()`; it is forced to 1 for
    one-dimensional setups and when Numba reports no parallel support.
    `left_first` defaults to `True` for each of `MAX_DIM_NUM` dimensions.
    `non_unit_g_factor`, `left_first` and `buffer_size` are forwarded to
    `make_step_impl()`.

    Raises `ValueError` if both or neither of `n_dims`/`grid` are given, or if
    more than one thread is requested but Numba does not run `prange` loops on
    several threads; raises `NotImplementedError` for `options.DPDC` with
    `n_dims > 1`.
    """
    if n_dims is not None and grid is not None:
        raise ValueError("n_dims and grid are mutually exclusive, pass only one")
    if n_dims is None and grid is None:
        raise ValueError("either n_dims or grid must be passed")
    if grid is None:
        grid = tuple([-1] * n_dims)
    if n_dims is None:
        n_dims = len(grid)
    if n_dims > 1 and options.DPDC:
        raise NotImplementedError("DPDC is supported only for n_dims == 1")
    if n_threads is None:
        n_threads = numba.get_num_threads()
    if left_first is None:
        left_first = tuple([True] * MAX_DIM_NUM)

    self.__options = options
    # one-dimensional setups always run single-threaded
    self.__n_threads = 1 if n_dims == 1 else n_threads

    if self.__n_threads > 1:
        try:
            numba.parfors.parfor.ensure_parallel_support()
        except numba.core.errors.UnsupportedParforsError:
            print(
                "Numba ensure_parallel_support() failed, forcing n_threads=1",
                file=sys.stderr,
            )
            self.__n_threads = 1
        else:
            # the self-test below only makes sense if we did not just fall back
            # to a single thread: compiling a parallel=True function there is
            # pointless and its "n_threads>1 requested" error would be bogus
            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member

                @numba.jit(parallel=True, nopython=True)
                def fill_array_with_thread_id(arr):
                    """writes thread id to corresponding array element"""
                    for i in numba.prange(  # pylint: disable=not-an-iterable
                        numba.get_num_threads()
                    ):
                        arr[i] = numba.get_thread_id()

                arr = np.full(numba.get_num_threads(), -1)
                fill_array_with_thread_id(arr)
                # no thread id above 0 was written -> prange did not spread
                # the iterations over several threads
                if not max(arr) > 0:
                    raise ValueError(
                        "n_threads>1 requested, but Numba does not seem to parallelize"
                        " (try changing Numba threading backend?)"
                    )

    self.__n_dims = n_dims
    self.__call, self.traversals = make_step_impl(
        options,
        non_unit_g_factor,
        grid,
        self.n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )
options: PyMPDATA.options.Options
@property
def options(self) -> Options:
    """the `Options` instance this stepper was specialised for"""
    return self.__options
`Options` instance used
n_threads: int
@property
def n_threads(self) -> int:
    """number of threads actually in use; this can be lower than the value given
    to __init__, since 1 is enforced for n_dims==1 and on platforms where Numba
    lacks threading support"""
    return self.__n_threads
actual number of threads used (may differ from the value passed to __init__ if n_dims==1, or if Numba does not support threading on the current platform)
@lru_cache()
def make_step_impl(options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size):
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) a pair: an njit-ted stepping function and the
    `Traversals` instance it was built with

    Results are memoised with `functools.lru_cache`, keyed on all arguments, which
    therefore must be hashable (including `options`).
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # plain Python values captured by the njit-ted closure below (compile-time constants)
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        post_step,
        post_iter,
        advectee,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        psi_extrema,
        beta,
        null_impl,
    ):
        """performs `n_steps` timesteps of `n_iters` iterations each and returns
        the mean time per step measured with `clock()` (NaN if `n_steps` is 0)"""
        time = clock()
        for step in range(n_steps):
            if non_zero_mu_coeff:
                # for this step, work on a modified advector held in vectmp_c;
                # the caller's advector is restored after the iterations
                advector_orig = advector
                advector = vectmp_c
            for iteration in range(n_iters):
                if iteration == 0:
                    if nonoscillatory:
                        # extrema of advectee before the first pass, later used
                        # by nonoscillatory_beta()
                        nonoscillatory_psi_extrema(null_impl, psi_extrema, advectee)
                    if non_zero_mu_coeff:
                        laplacian(null_impl, advector, advectee)
                        # presumably advector := mu_coeff * advector + advector_orig
                        # (axpy argument order: out, a, x, y -- TODO confirm)
                        axpy(
                            *advector.field,
                            mu_coeff,
                            *advector.field,
                            *advector_orig.field,
                        )
                    flux_first_pass(null_impl, vectmp_a, advector, advectee)
                    flux = vectmp_a
                else:
                    # vectmp_a/vectmp_b rotate roles: the advector of the previous
                    # iteration (advector_oscil) is read, a new antidiffusive
                    # advector (advector_nonos) is written; from iteration 2 on the
                    # flux overwrites the buffer that held the previous advector
                    if iteration == 1:
                        advector_oscil = advector
                        advector_nonos = vectmp_a
                        flux = vectmp_b
                    elif iteration % 2 == 0:
                        advector_oscil = vectmp_a
                        advector_nonos = vectmp_b
                        flux = vectmp_a
                    else:
                        advector_oscil = vectmp_b
                        advector_nonos = vectmp_a
                        flux = vectmp_b
                    if iteration < n_iters - 1:
                        antidiff(
                            null_impl,
                            advector_nonos,
                            advectee,
                            advector_oscil,
                            g_factor,
                        )
                    else:
                        # the final iteration uses the last_pass variant
                        antidiff_last_pass(
                            null_impl,
                            advector_nonos,
                            advectee,
                            advector_oscil,
                            g_factor,
                        )
                    flux_subsequent(null_impl, flux, advectee, advector_nonos)
                    if nonoscillatory:
                        # beta from the provisional flux, correct the advector,
                        # then recompute the flux from the corrected advector
                        nonoscillatory_beta(
                            null_impl, beta, flux, advectee, psi_extrema, g_factor
                        )
                        # note: in libmpdata++, the oscillatory advector from prev iter is used
                        nonoscillatory_correction(null_impl, advector_nonos, beta)
                        flux_subsequent(null_impl, flux, advectee, advector_nonos)
                upwind(null_impl, advectee, flux, g_factor)
                post_iter.call(flux.field, g_factor.field, step, iteration)
            if non_zero_mu_coeff:
                advector = advector_orig
            post_step.call(advectee.field[ARG_DATA], step)
        # guard against division by zero when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals
returns (and caches) a pair: an njit-ted stepping function and the `Traversals` instance it was built with