PyMPDATA.stepper

MPDATA iteration logic

  1"""MPDATA iteration logic"""
  2
  3import sys
  4import warnings
  5from functools import lru_cache
  6
  7import numba
  8import numpy as np
  9from numba.core.errors import NumbaExperimentalFeatureWarning
 10
 11from .impl.clock import clock
 12from .impl.enumerations import IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
 13from .impl.formulae_antidiff import make_antidiff
 14from .impl.formulae_axpy import make_axpy
 15from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent
 16from .impl.formulae_laplacian import make_laplacian
 17from .impl.formulae_nonoscillatory import make_beta, make_correction, make_psi_extrema
 18from .impl.formulae_upwind import make_upwind
 19from .impl.meta import _Impl
 20from .impl.traversals import Traversals
 21from .options import Options
 22
 23
 24class Stepper:
 25    """MPDATA stepper specialised for given options, dimensionality and optionally grid
 26    (instances of Stepper can be shared among `Solver`s)"""
 27
 28    def __init__(
 29        self,
 30        *,
 31        options: Options,
 32        n_dims: (int, None) = None,
 33        non_unit_g_factor: bool = False,
 34        grid: (tuple, None) = None,
 35        n_threads: (int, None) = None,
 36        left_first: (tuple, None) = None,
 37        buffer_size: int = 0
 38    ):
 39        if n_dims is not None and grid is not None:
 40            raise ValueError()
 41        if n_dims is None and grid is None:
 42            raise ValueError()
 43        if grid is None:
 44            grid = tuple([-1] * n_dims)
 45        if n_dims is None:
 46            n_dims = len(grid)
 47        if n_dims > 1 and options.DPDC:
 48            raise NotImplementedError()
 49        if n_threads is None:
 50            n_threads = numba.get_num_threads()
 51        if left_first is None:
 52            left_first = tuple([True] * MAX_DIM_NUM)
 53
 54        self.__options = options
 55        self.__n_threads = 1 if n_dims == 1 else n_threads
 56
 57        if self.__n_threads > 1:
 58            try:
 59                numba.parfors.parfor.ensure_parallel_support()
 60            except numba.core.errors.UnsupportedParforsError:
 61                print(
 62                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
 63                    file=sys.stderr,
 64                )
 65                self.__n_threads = 1
 66
 67            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member
 68
 69                @numba.jit(parallel=True, nopython=True)
 70                def fill_array_with_thread_id(arr):
 71                    """writes thread id to corresponding array element"""
 72                    for i in numba.prange(  # pylint: disable=not-an-iterable
 73                        numba.get_num_threads()
 74                    ):
 75                        arr[i] = numba.get_thread_id()
 76
 77                arr = np.full(numba.get_num_threads(), -1)
 78                fill_array_with_thread_id(arr)
 79                if not max(arr) > 0:
 80                    raise ValueError(
 81                        "n_threads>1 requested, but Numba does not seem to parallelize"
 82                        " (try changing Numba threading backend?)"
 83                    )
 84
 85        self.__n_dims = n_dims
 86        self.__call, self.traversals = make_step_impl(
 87            options,
 88            non_unit_g_factor,
 89            grid,
 90            self.n_threads,
 91            left_first=left_first,
 92            buffer_size=buffer_size,
 93        )
 94
 95    @property
 96    def options(self) -> Options:
 97        """`Options` instance used"""
 98        return self.__options
 99
100    @property
101    def n_threads(self) -> int:
102        """actual n_threads used (may be different than passed to __init__ if n_dims==1
103        or if on a platform where Numba does not support threading)"""
104        return self.__n_threads
105
106    @property
107    def n_dims(self) -> int:
108        """dimensionality (1, 2 or 3)"""
109        return self.__n_dims
110
111    def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields):
112        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
113        with warnings.catch_warnings():
114            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
115            wall_time_per_timestep = self.__call(
116                n_steps,
117                mu_coeff,
118                ante_step,
119                post_step,
120                post_iter,
121                *(
122                    (
123                        _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
124                        if k != "advectee"
125                        else tuple(
126                            _Impl(
127                                field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC]
128                            )
129                            for vv in v
130                        )
131                    )
132                    for k, v in fields.items()
133                ),
134                self.traversals.data,
135            )
136        return wall_time_per_timestep
137
138
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    The result is memoised with `lru_cache` keyed on all arguments, hence every
    argument has to be hashable and repeated calls with equal arguments return
    the very same `(step, traversals)` objects.
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # option values captured by the closure below (constants from the point
    # of view of the compiled function)
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    # building blocks of the algorithm, each created for the same options/traversals
    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        ante_step,
        post_step,
        post_iter,
        advectees,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        dynamic_advector_stash_outer,
        dynamic_advector_stash_mid3d,
        dynamic_advector_stash_inner,
        psi_extrema,
        beta,
        traversals_data,
    ):
        # performs n_steps timesteps for every advectee and returns the elapsed
        # clock() time divided by n_steps (NaN if n_steps is not positive)
        time = clock()
        for step in range(n_steps):
            for index, advectee in enumerate(advectees):
                ante_step.call(
                    traversals_data,
                    advectees,
                    advector,
                    step,
                    index,
                    dynamic_advector_stash_outer,
                    dynamic_advector_stash_mid3d,
                    dynamic_advector_stash_inner,
                )
                if non_zero_mu_coeff:
                    # work on vectmp_c for the duration of this advectee's
                    # iterations, the original advector is restored afterwards
                    advector_orig = advector
                    advector = vectmp_c
                for iteration in range(n_iters):
                    if iteration == 0:
                        if nonoscillatory:
                            nonoscillatory_psi_extrema(
                                traversals_data, psi_extrema, advectee
                            )
                        if non_zero_mu_coeff:
                            laplacian(traversals_data, advector, advectee)
                            # NOTE(review): presumably combines the laplacian term
                            # (scaled by mu_coeff) with the original advector,
                            # in place - confirm against formulae_axpy
                            axpy(
                                *advector.field,
                                mu_coeff,
                                *advector.field,
                                *advector_orig.field,
                            )
                        flux_first_pass(traversals_data, vectmp_a, advector, advectee)
                        flux = vectmp_a
                    else:
                        # vectmp_a and vectmp_b alternate roles between iterations;
                        # iteration 1 is special as it takes the advector as input
                        if iteration == 1:
                            advector_oscil = advector
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        elif iteration % 2 == 0:
                            advector_oscil = vectmp_a
                            advector_nonos = vectmp_b
                            flux = vectmp_a
                        else:
                            advector_oscil = vectmp_b
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        # the final iteration uses the last_pass=True variant
                        if iteration < n_iters - 1:
                            antidiff(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        else:
                            antidiff_last_pass(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        flux_subsequent(traversals_data, flux, advectee, advector_nonos)
                        if nonoscillatory:
                            nonoscillatory_beta(
                                traversals_data,
                                beta,
                                flux,
                                advectee,
                                psi_extrema,
                                g_factor,
                            )
                            # note: in libmpdata++, the oscillatory advector from prev iter is used
                            nonoscillatory_correction(
                                traversals_data, advector_nonos, beta
                            )
                            # flux is evaluated again after the correction call
                            flux_subsequent(
                                traversals_data, flux, advectee, advector_nonos
                            )
                    upwind(traversals_data, advectee, flux, g_factor)
                    post_iter.call(
                        traversals_data, flux.field, g_factor.field, step, iteration
                    )
                if non_zero_mu_coeff:
                    advector = advector_orig
                post_step.call(traversals_data, advectees, step, index)
        # guard against division by zero when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals
class Stepper:
 25class Stepper:
 26    """MPDATA stepper specialised for given options, dimensionality and optionally grid
 27    (instances of Stepper can be shared among `Solver`s)"""
 28
 29    def __init__(
 30        self,
 31        *,
 32        options: Options,
 33        n_dims: (int, None) = None,
 34        non_unit_g_factor: bool = False,
 35        grid: (tuple, None) = None,
 36        n_threads: (int, None) = None,
 37        left_first: (tuple, None) = None,
 38        buffer_size: int = 0
 39    ):
 40        if n_dims is not None and grid is not None:
 41            raise ValueError()
 42        if n_dims is None and grid is None:
 43            raise ValueError()
 44        if grid is None:
 45            grid = tuple([-1] * n_dims)
 46        if n_dims is None:
 47            n_dims = len(grid)
 48        if n_dims > 1 and options.DPDC:
 49            raise NotImplementedError()
 50        if n_threads is None:
 51            n_threads = numba.get_num_threads()
 52        if left_first is None:
 53            left_first = tuple([True] * MAX_DIM_NUM)
 54
 55        self.__options = options
 56        self.__n_threads = 1 if n_dims == 1 else n_threads
 57
 58        if self.__n_threads > 1:
 59            try:
 60                numba.parfors.parfor.ensure_parallel_support()
 61            except numba.core.errors.UnsupportedParforsError:
 62                print(
 63                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
 64                    file=sys.stderr,
 65                )
 66                self.__n_threads = 1
 67
 68            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member
 69
 70                @numba.jit(parallel=True, nopython=True)
 71                def fill_array_with_thread_id(arr):
 72                    """writes thread id to corresponding array element"""
 73                    for i in numba.prange(  # pylint: disable=not-an-iterable
 74                        numba.get_num_threads()
 75                    ):
 76                        arr[i] = numba.get_thread_id()
 77
 78                arr = np.full(numba.get_num_threads(), -1)
 79                fill_array_with_thread_id(arr)
 80                if not max(arr) > 0:
 81                    raise ValueError(
 82                        "n_threads>1 requested, but Numba does not seem to parallelize"
 83                        " (try changing Numba threading backend?)"
 84                    )
 85
 86        self.__n_dims = n_dims
 87        self.__call, self.traversals = make_step_impl(
 88            options,
 89            non_unit_g_factor,
 90            grid,
 91            self.n_threads,
 92            left_first=left_first,
 93            buffer_size=buffer_size,
 94        )
 95
 96    @property
 97    def options(self) -> Options:
 98        """`Options` instance used"""
 99        return self.__options
100
101    @property
102    def n_threads(self) -> int:
103        """actual n_threads used (may be different than passed to __init__ if n_dims==1
104        or if on a platform where Numba does not support threading)"""
105        return self.__n_threads
106
107    @property
108    def n_dims(self) -> int:
109        """dimensionality (1, 2 or 3)"""
110        return self.__n_dims
111
112    def __call__(self, *, n_steps, mu_coeff, ante_step, post_step, post_iter, fields):
113        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
114        with warnings.catch_warnings():
115            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
116            wall_time_per_timestep = self.__call(
117                n_steps,
118                mu_coeff,
119                ante_step,
120                post_step,
121                post_iter,
122                *(
123                    (
124                        _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
125                        if k != "advectee"
126                        else tuple(
127                            _Impl(
128                                field=vv.impl[IMPL_META_AND_DATA], bc=vv.impl[IMPL_BC]
129                            )
130                            for vv in v
131                        )
132                    )
133                    for k, v in fields.items()
134                ),
135                self.traversals.data,
136            )
137        return wall_time_per_timestep

MPDATA stepper specialised for given options, dimensionality and optionally grid (instances of Stepper can be shared among Solvers)

Stepper( *, options: PyMPDATA.options.Options, n_dims: (<class 'int'>, None) = None, non_unit_g_factor: bool = False, grid: (<class 'tuple'>, None) = None, n_threads: (<class 'int'>, None) = None, left_first: (<class 'tuple'>, None) = None, buffer_size: int = 0)
29    def __init__(
30        self,
31        *,
32        options: Options,
33        n_dims: (int, None) = None,
34        non_unit_g_factor: bool = False,
35        grid: (tuple, None) = None,
36        n_threads: (int, None) = None,
37        left_first: (tuple, None) = None,
38        buffer_size: int = 0
39    ):
40        if n_dims is not None and grid is not None:
41            raise ValueError()
42        if n_dims is None and grid is None:
43            raise ValueError()
44        if grid is None:
45            grid = tuple([-1] * n_dims)
46        if n_dims is None:
47            n_dims = len(grid)
48        if n_dims > 1 and options.DPDC:
49            raise NotImplementedError()
50        if n_threads is None:
51            n_threads = numba.get_num_threads()
52        if left_first is None:
53            left_first = tuple([True] * MAX_DIM_NUM)
54
55        self.__options = options
56        self.__n_threads = 1 if n_dims == 1 else n_threads
57
58        if self.__n_threads > 1:
59            try:
60                numba.parfors.parfor.ensure_parallel_support()
61            except numba.core.errors.UnsupportedParforsError:
62                print(
63                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
64                    file=sys.stderr,
65                )
66                self.__n_threads = 1
67
68            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member
69
70                @numba.jit(parallel=True, nopython=True)
71                def fill_array_with_thread_id(arr):
72                    """writes thread id to corresponding array element"""
73                    for i in numba.prange(  # pylint: disable=not-an-iterable
74                        numba.get_num_threads()
75                    ):
76                        arr[i] = numba.get_thread_id()
77
78                arr = np.full(numba.get_num_threads(), -1)
79                fill_array_with_thread_id(arr)
80                if not max(arr) > 0:
81                    raise ValueError(
82                        "n_threads>1 requested, but Numba does not seem to parallelize"
83                        " (try changing Numba threading backend?)"
84                    )
85
86        self.__n_dims = n_dims
87        self.__call, self.traversals = make_step_impl(
88            options,
89            non_unit_g_factor,
90            grid,
91            self.n_threads,
92            left_first=left_first,
93            buffer_size=buffer_size,
94        )
options: PyMPDATA.options.Options
    @property
    def options(self) -> Options:
        """`Options` instance used (read-only view of the private attribute
        holding the `Options` object)"""
        return self.__options

Options instance used

n_threads: int
    @property
    def n_threads(self) -> int:
        """actual number of threads used (may differ from the value passed to
        __init__ if n_dims==1 or if on a platform where Numba does not support
        threading - in both cases it equals 1)"""
        return self.__n_threads

actual n_threads used (may be different than passed to __init__ if n_dims==1 or if on a platform where Numba does not support threading)

n_dims: int
    @property
    def n_dims(self) -> int:
        """dimensionality (1, 2 or 3); read-only"""
        return self.__n_dims

dimensionality (1, 2 or 3)

@lru_cache()
def make_step_impl( options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size):
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    The result is memoised with `lru_cache` keyed on all arguments, hence every
    argument has to be hashable and repeated calls with equal arguments return
    the very same `(step, traversals)` objects.
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # option values captured by the closure below (constants from the point
    # of view of the compiled function)
    n_iters = options.n_iters
    non_zero_mu_coeff = options.non_zero_mu_coeff
    nonoscillatory = options.nonoscillatory

    # building blocks of the algorithm, each created for the same options/traversals
    upwind = make_upwind(options, non_unit_g_factor, traversals)
    flux_first_pass = make_flux_first_pass(options, traversals)
    flux_subsequent = make_flux_subsequent(options, traversals)
    antidiff = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_last_pass = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian = make_laplacian(non_unit_g_factor, options, traversals)
    nonoscillatory_psi_extrema = make_psi_extrema(options, traversals)
    nonoscillatory_beta = make_beta(non_unit_g_factor, options, traversals)
    nonoscillatory_correction = make_correction(options, traversals)
    axpy = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        ante_step,
        post_step,
        post_iter,
        advectees,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        dynamic_advector_stash_outer,
        dynamic_advector_stash_mid3d,
        dynamic_advector_stash_inner,
        psi_extrema,
        beta,
        traversals_data,
    ):
        # performs n_steps timesteps for every advectee and returns the elapsed
        # clock() time divided by n_steps (NaN if n_steps is not positive)
        time = clock()
        for step in range(n_steps):
            for index, advectee in enumerate(advectees):
                ante_step.call(
                    traversals_data,
                    advectees,
                    advector,
                    step,
                    index,
                    dynamic_advector_stash_outer,
                    dynamic_advector_stash_mid3d,
                    dynamic_advector_stash_inner,
                )
                if non_zero_mu_coeff:
                    # work on vectmp_c for the duration of this advectee's
                    # iterations, the original advector is restored afterwards
                    advector_orig = advector
                    advector = vectmp_c
                for iteration in range(n_iters):
                    if iteration == 0:
                        if nonoscillatory:
                            nonoscillatory_psi_extrema(
                                traversals_data, psi_extrema, advectee
                            )
                        if non_zero_mu_coeff:
                            laplacian(traversals_data, advector, advectee)
                            # NOTE(review): presumably combines the laplacian term
                            # (scaled by mu_coeff) with the original advector,
                            # in place - confirm against formulae_axpy
                            axpy(
                                *advector.field,
                                mu_coeff,
                                *advector.field,
                                *advector_orig.field,
                            )
                        flux_first_pass(traversals_data, vectmp_a, advector, advectee)
                        flux = vectmp_a
                    else:
                        # vectmp_a and vectmp_b alternate roles between iterations;
                        # iteration 1 is special as it takes the advector as input
                        if iteration == 1:
                            advector_oscil = advector
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        elif iteration % 2 == 0:
                            advector_oscil = vectmp_a
                            advector_nonos = vectmp_b
                            flux = vectmp_a
                        else:
                            advector_oscil = vectmp_b
                            advector_nonos = vectmp_a
                            flux = vectmp_b
                        # the final iteration uses the last_pass=True variant
                        if iteration < n_iters - 1:
                            antidiff(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        else:
                            antidiff_last_pass(
                                traversals_data,
                                advector_nonos,
                                advectee,
                                advector_oscil,
                                g_factor,
                            )
                        flux_subsequent(traversals_data, flux, advectee, advector_nonos)
                        if nonoscillatory:
                            nonoscillatory_beta(
                                traversals_data,
                                beta,
                                flux,
                                advectee,
                                psi_extrema,
                                g_factor,
                            )
                            # note: in libmpdata++, the oscillatory advector from prev iter is used
                            nonoscillatory_correction(
                                traversals_data, advector_nonos, beta
                            )
                            # flux is evaluated again after the correction call
                            flux_subsequent(
                                traversals_data, flux, advectee, advector_nonos
                            )
                    upwind(traversals_data, advectee, flux, g_factor)
                    post_iter.call(
                        traversals_data, flux.field, g_factor.field, step, iteration
                    )
                if non_zero_mu_coeff:
                    advector = advector_orig
                post_step.call(traversals_data, advectees, step, index)
        # guard against division by zero when no steps were requested
        return (clock() - time) / n_steps if n_steps > 0 else np.nan

    return step, traversals

returns (and caches) an njit-ted stepping function and a traversals pair