PyMPDATA.stepper

MPDATA iteration logic

  1"""MPDATA iteration logic"""
  2
  3import sys
  4import warnings
  5from functools import lru_cache
  6
  7import numba
  8import numpy as np
  9from numba.core.errors import NumbaExperimentalFeatureWarning
 10
 11from .impl.clock import clock
 12from .impl.enumerations import ARG_DATA, IMPL_BC, IMPL_META_AND_DATA, MAX_DIM_NUM
 13from .impl.formulae_antidiff import make_antidiff
 14from .impl.formulae_axpy import make_axpy
 15from .impl.formulae_flux import make_flux_first_pass, make_flux_subsequent
 16from .impl.formulae_laplacian import make_laplacian
 17from .impl.formulae_nonoscillatory import make_beta, make_correction, make_psi_extrema
 18from .impl.formulae_upwind import make_upwind
 19from .impl.meta import _Impl
 20from .impl.traversals import Traversals
 21from .options import Options
 22
 23
 24class Stepper:
 25    """MPDATA stepper specialised for given options, dimensionality and optionally grid
 26    (instances of Stepper can be shared among `Solver`s)"""
 27
 28    def __init__(
 29        self,
 30        *,
 31        options: Options,
 32        n_dims: (int, None) = None,
 33        non_unit_g_factor: bool = False,
 34        grid: (tuple, None) = None,
 35        n_threads: (int, None) = None,
 36        left_first: (tuple, None) = None,
 37        buffer_size: int = 0
 38    ):
 39        if n_dims is not None and grid is not None:
 40            raise ValueError()
 41        if n_dims is None and grid is None:
 42            raise ValueError()
 43        if grid is None:
 44            grid = tuple([-1] * n_dims)
 45        if n_dims is None:
 46            n_dims = len(grid)
 47        if n_dims > 1 and options.DPDC:
 48            raise NotImplementedError()
 49        if n_threads is None:
 50            n_threads = numba.get_num_threads()
 51        if left_first is None:
 52            left_first = tuple([True] * MAX_DIM_NUM)
 53
 54        self.__options = options
 55        self.__n_threads = 1 if n_dims == 1 else n_threads
 56
 57        if self.__n_threads > 1:
 58            try:
 59                numba.parfors.parfor.ensure_parallel_support()
 60            except numba.core.errors.UnsupportedParforsError:
 61                print(
 62                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
 63                    file=sys.stderr,
 64                )
 65                self.__n_threads = 1
 66
 67            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member
 68
 69                @numba.jit(parallel=True, nopython=True)
 70                def fill_array_with_thread_id(arr):
 71                    """writes thread id to corresponding array element"""
 72                    for i in numba.prange(  # pylint: disable=not-an-iterable
 73                        numba.get_num_threads()
 74                    ):
 75                        arr[i] = numba.get_thread_id()
 76
 77                arr = np.full(numba.get_num_threads(), -1)
 78                fill_array_with_thread_id(arr)
 79                if not max(arr) > 0:
 80                    raise ValueError(
 81                        "n_threads>1 requested, but Numba does not seem to parallelize"
 82                        " (try changing Numba threading backend?)"
 83                    )
 84
 85        self.__n_dims = n_dims
 86        self.__call, self.traversals = make_step_impl(
 87            options,
 88            non_unit_g_factor,
 89            grid,
 90            self.n_threads,
 91            left_first=left_first,
 92            buffer_size=buffer_size,
 93        )
 94
 95    @property
 96    def options(self) -> Options:
 97        """`Options` instance used"""
 98        return self.__options
 99
100    @property
101    def n_threads(self) -> int:
102        """actual n_threads used (may be different than passed to __init__ if n_dims==1
103        or if on a platform where Numba does not support threading)"""
104        return self.__n_threads
105
106    @property
107    def n_dims(self) -> int:
108        """dimensionality (1, 2 or 3)"""
109        return self.__n_dims
110
111    def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
112        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
113        with warnings.catch_warnings():
114            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
115            wall_time_per_timestep = self.__call(
116                n_steps,
117                mu_coeff,
118                post_step,
119                post_iter,
120                *(
121                    _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
122                    for v in fields.values()
123                ),
124                self.traversals.data,
125            )
126        return wall_time_per_timestep
127
128
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    The result is the tuple ``(step, traversals)``. ``traversals`` is the
    ``Traversals`` instance built for the given grid, halo and threading setup.
    ``step`` is the jitted time-stepping loop assembled from kernels
    specialised for ``options``, ``non_unit_g_factor`` and that
    ``traversals``. Because of ``lru_cache``, repeated calls with equal
    (hashable) arguments return the very same pair.
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # plain Python values captured by the jitted closure below
    iters_per_step = options.n_iters
    use_mu_coeff = options.non_zero_mu_coeff
    use_nonoscillatory = options.nonoscillatory

    # building-block kernels, each specialised for these options and traversals
    upwind_kernel = make_upwind(options, non_unit_g_factor, traversals)
    first_pass_flux = make_flux_first_pass(options, traversals)
    subsequent_flux = make_flux_subsequent(options, traversals)
    antidiff_kernel = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_final_kernel = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian_kernel = make_laplacian(non_unit_g_factor, options, traversals)
    extrema_kernel = make_psi_extrema(options, traversals)
    beta_kernel = make_beta(non_unit_g_factor, options, traversals)
    correction_kernel = make_correction(options, traversals)
    axpy_kernel = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        post_step,
        post_iter,
        advectee,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        psi_extrema,
        beta,
        null_impl,
    ):
        """runs n_steps timesteps; returns the mean wall time per step
        (NaN unless n_steps > 0)"""
        started = clock()
        for step_idx in range(n_steps):
            if use_mu_coeff:
                # vectmp_c stands in for the advector during this step;
                # the caller's advector is restored once the step is done
                caller_advector = advector
                advector = vectmp_c
            for iter_idx in range(iters_per_step):
                if iter_idx == 0:
                    if use_nonoscillatory:
                        extrema_kernel(null_impl, psi_extrema, advectee)
                    if use_mu_coeff:
                        laplacian_kernel(null_impl, advector, advectee)
                        axpy_kernel(
                            *advector.field,
                            mu_coeff,
                            *advector.field,
                            *caller_advector.field,
                        )
                    first_pass_flux(null_impl, vectmp_a, advector, advectee)
                    flux = vectmp_a
                else:
                    # vectmp_a / vectmp_b alternate between holding the previous
                    # iteration's advector and the one computed in this iteration
                    if iter_idx == 1:
                        prev_adv = advector
                        new_adv = vectmp_a
                        flux = vectmp_b
                    elif iter_idx % 2:
                        prev_adv = vectmp_b
                        new_adv = vectmp_a
                        flux = vectmp_b
                    else:
                        prev_adv = vectmp_a
                        new_adv = vectmp_b
                        flux = vectmp_a
                    if iter_idx == iters_per_step - 1:
                        antidiff_final_kernel(
                            null_impl,
                            new_adv,
                            advectee,
                            prev_adv,
                            g_factor,
                        )
                    else:
                        antidiff_kernel(
                            null_impl,
                            new_adv,
                            advectee,
                            prev_adv,
                            g_factor,
                        )
                    subsequent_flux(null_impl, flux, advectee, new_adv)
                    if use_nonoscillatory:
                        beta_kernel(
                            null_impl, beta, flux, advectee, psi_extrema, g_factor
                        )
                        # note: in libmpdata++, the oscillatory advector from prev iter is used
                        correction_kernel(null_impl, new_adv, beta)
                        subsequent_flux(null_impl, flux, advectee, new_adv)
                upwind_kernel(null_impl, advectee, flux, g_factor)
                post_iter.call(flux.field, g_factor.field, step_idx, iter_idx)
            if use_mu_coeff:
                advector = caller_advector
            post_step.call(advectee.field[ARG_DATA], step_idx)
        if n_steps > 0:
            return (clock() - started) / n_steps
        return np.nan

    return step, traversals
class Stepper:
 25class Stepper:
 26    """MPDATA stepper specialised for given options, dimensionality and optionally grid
 27    (instances of Stepper can be shared among `Solver`s)"""
 28
 29    def __init__(
 30        self,
 31        *,
 32        options: Options,
 33        n_dims: (int, None) = None,
 34        non_unit_g_factor: bool = False,
 35        grid: (tuple, None) = None,
 36        n_threads: (int, None) = None,
 37        left_first: (tuple, None) = None,
 38        buffer_size: int = 0
 39    ):
 40        if n_dims is not None and grid is not None:
 41            raise ValueError()
 42        if n_dims is None and grid is None:
 43            raise ValueError()
 44        if grid is None:
 45            grid = tuple([-1] * n_dims)
 46        if n_dims is None:
 47            n_dims = len(grid)
 48        if n_dims > 1 and options.DPDC:
 49            raise NotImplementedError()
 50        if n_threads is None:
 51            n_threads = numba.get_num_threads()
 52        if left_first is None:
 53            left_first = tuple([True] * MAX_DIM_NUM)
 54
 55        self.__options = options
 56        self.__n_threads = 1 if n_dims == 1 else n_threads
 57
 58        if self.__n_threads > 1:
 59            try:
 60                numba.parfors.parfor.ensure_parallel_support()
 61            except numba.core.errors.UnsupportedParforsError:
 62                print(
 63                    "Numba ensure_parallel_support() failed, forcing n_threads=1",
 64                    file=sys.stderr,
 65                )
 66                self.__n_threads = 1
 67
 68            if not numba.config.DISABLE_JIT:  # pylint: disable=no-member
 69
 70                @numba.jit(parallel=True, nopython=True)
 71                def fill_array_with_thread_id(arr):
 72                    """writes thread id to corresponding array element"""
 73                    for i in numba.prange(  # pylint: disable=not-an-iterable
 74                        numba.get_num_threads()
 75                    ):
 76                        arr[i] = numba.get_thread_id()
 77
 78                arr = np.full(numba.get_num_threads(), -1)
 79                fill_array_with_thread_id(arr)
 80                if not max(arr) > 0:
 81                    raise ValueError(
 82                        "n_threads>1 requested, but Numba does not seem to parallelize"
 83                        " (try changing Numba threading backend?)"
 84                    )
 85
 86        self.__n_dims = n_dims
 87        self.__call, self.traversals = make_step_impl(
 88            options,
 89            non_unit_g_factor,
 90            grid,
 91            self.n_threads,
 92            left_first=left_first,
 93            buffer_size=buffer_size,
 94        )
 95
 96    @property
 97    def options(self) -> Options:
 98        """`Options` instance used"""
 99        return self.__options
100
101    @property
102    def n_threads(self) -> int:
103        """actual n_threads used (may be different than passed to __init__ if n_dims==1
104        or if on a platform where Numba does not support threading)"""
105        return self.__n_threads
106
107    @property
108    def n_dims(self) -> int:
109        """dimensionality (1, 2 or 3)"""
110        return self.__n_dims
111
112    def __call__(self, *, n_steps, mu_coeff, post_step, post_iter, fields):
113        assert self.n_threads == 1 or numba.get_num_threads() == self.n_threads
114        with warnings.catch_warnings():
115            warnings.simplefilter("ignore", category=NumbaExperimentalFeatureWarning)
116            wall_time_per_timestep = self.__call(
117                n_steps,
118                mu_coeff,
119                post_step,
120                post_iter,
121                *(
122                    _Impl(field=v.impl[IMPL_META_AND_DATA], bc=v.impl[IMPL_BC])
123                    for v in fields.values()
124                ),
125                self.traversals.data,
126            )
127        return wall_time_per_timestep

MPDATA stepper specialised for given options, dimensionality and optionally grid (instances of Stepper can be shared among Solvers)

Stepper( *, options: PyMPDATA.options.Options, n_dims: (<class 'int'>, None) = None, non_unit_g_factor: bool = False, grid: (<class 'tuple'>, None) = None, n_threads: (<class 'int'>, None) = None, left_first: (<class 'tuple'>, None) = None, buffer_size: int = 0)
def __init__(
    self,
    *,
    options: Options,
    n_dims: (int, None) = None,
    non_unit_g_factor: bool = False,
    grid: (tuple, None) = None,
    n_threads: (int, None) = None,
    left_first: (tuple, None) = None,
    buffer_size: int = 0
):
    """Validates the configuration and builds (or fetches from the
    `make_step_impl` cache) the jitted stepping function.

    Args:
        options: `Options` instance selecting the MPDATA variant.
        n_dims: dimensionality; mutually exclusive with `grid`. When given,
            the grid is recorded as a tuple of -1 entries (presumably meaning
            "size not fixed" - see `Traversals`).
        non_unit_g_factor: forwarded to the kernel factories.
        grid: grid shape; mutually exclusive with `n_dims`, which is then
            inferred as `len(grid)`.
        n_threads: requested thread count (defaults to
            `numba.get_num_threads()`); forced to 1 for 1D problems or when
            Numba reports no parallel support.
        left_first: forwarded to `Traversals`; defaults to `True` for each of
            the `MAX_DIM_NUM` dimensions.
        buffer_size: forwarded to `Traversals`.

    Raises:
        ValueError: if both or neither of `n_dims` and `grid` are given, or if
            more than one thread is requested but the probe finds that Numba
            runs `prange` loops on thread 0 only.
        NotImplementedError: if `options.DPDC` is set and `n_dims > 1`.
    """
    if n_dims is not None and grid is not None:
        raise ValueError("pass either n_dims or grid, not both")
    if n_dims is None and grid is None:
        raise ValueError("either n_dims or grid must be passed")
    if grid is None:
        grid = (-1,) * n_dims
    if n_dims is None:
        n_dims = len(grid)
    if n_dims > 1 and options.DPDC:
        raise NotImplementedError("DPDC is implemented for n_dims == 1 only")
    if n_threads is None:
        n_threads = numba.get_num_threads()
    if left_first is None:
        left_first = (True,) * MAX_DIM_NUM

    self.__options = options
    self.__n_threads = 1 if n_dims == 1 else n_threads

    if self.__n_threads > 1:
        try:
            numba.parfors.parfor.ensure_parallel_support()
        except numba.core.errors.UnsupportedParforsError:
            print(
                "Numba ensure_parallel_support() failed, forcing n_threads=1",
                file=sys.stderr,
            )
            self.__n_threads = 1

    # n_threads is re-tested here on purpose: the fallback above may just have
    # forced serial mode, and the probe below compiles a parallel=True
    # function, i.e. it needs exactly the parfor support found missing.
    probe_parallelism = (
        self.__n_threads > 1
        and not numba.config.DISABLE_JIT  # pylint: disable=no-member
    )
    if probe_parallelism:

        @numba.jit(parallel=True, nopython=True)
        def fill_array_with_thread_id(arr):
            """writes thread id to corresponding array element"""
            for i in numba.prange(  # pylint: disable=not-an-iterable
                numba.get_num_threads()
            ):
                arr[i] = numba.get_thread_id()

        arr = np.full(numba.get_num_threads(), -1)
        fill_array_with_thread_id(arr)
        # no element holding a thread id above 0 => nothing ran in parallel
        if arr.max() <= 0:
            raise ValueError(
                "n_threads>1 requested, but Numba does not seem to parallelize"
                " (try changing Numba threading backend?)"
            )

    self.__n_dims = n_dims
    # `traversals` is public; its `.data` is handed to the jitted step on
    # every call
    self.__call, self.traversals = make_step_impl(
        options,
        non_unit_g_factor,
        grid,
        self.n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )
options: PyMPDATA.options.Options
@property
def options(self) -> Options:
    """the `Options` instance this stepper was built with (read-only)"""
    configured = self.__options
    return configured

Options instance used

n_threads: int
@property
def n_threads(self) -> int:
    """number of threads actually in use; may differ from the value passed to
    __init__ (it is forced to 1 when n_dims == 1 or when Numba lacks threading
    support on this platform)"""
    threads_in_use = self.__n_threads
    return threads_in_use

actual n_threads used (may differ from the value passed to __init__ if n_dims==1 or if on a platform where Numba does not support threading)

n_dims: int
@property
def n_dims(self) -> int:
    """dimensionality of the problem the stepper handles (1, 2 or 3)"""
    dimensionality = self.__n_dims
    return dimensionality

dimensionality (1, 2 or 3)

@lru_cache()
def make_step_impl( options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size):
@lru_cache()
# pylint: disable=too-many-locals,too-many-statements,too-many-arguments
def make_step_impl(
    options, non_unit_g_factor, grid, n_threads, left_first: tuple, buffer_size
):
    """returns (and caches) an njit-ted stepping function and a traversals pair

    The result is the tuple ``(step, traversals)``. ``traversals`` is the
    ``Traversals`` instance built for the given grid, halo and threading setup.
    ``step`` is the jitted time-stepping loop assembled from kernels
    specialised for ``options``, ``non_unit_g_factor`` and that
    ``traversals``. Because of ``lru_cache``, repeated calls with equal
    (hashable) arguments return the very same pair.
    """
    traversals = Traversals(
        grid=grid,
        halo=options.n_halo,
        jit_flags=options.jit_flags,
        n_threads=n_threads,
        left_first=left_first,
        buffer_size=buffer_size,
    )

    # plain Python values captured by the jitted closure below
    iters_per_step = options.n_iters
    use_mu_coeff = options.non_zero_mu_coeff
    use_nonoscillatory = options.nonoscillatory

    # building-block kernels, each specialised for these options and traversals
    upwind_kernel = make_upwind(options, non_unit_g_factor, traversals)
    first_pass_flux = make_flux_first_pass(options, traversals)
    subsequent_flux = make_flux_subsequent(options, traversals)
    antidiff_kernel = make_antidiff(non_unit_g_factor, options, traversals)
    antidiff_final_kernel = make_antidiff(
        non_unit_g_factor, options, traversals, last_pass=True
    )
    laplacian_kernel = make_laplacian(non_unit_g_factor, options, traversals)
    extrema_kernel = make_psi_extrema(options, traversals)
    beta_kernel = make_beta(non_unit_g_factor, options, traversals)
    correction_kernel = make_correction(options, traversals)
    axpy_kernel = make_axpy(options, traversals)

    @numba.njit(**options.jit_flags)
    # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unnecessary-dunder-call
    def step(
        n_steps,
        mu_coeff,
        post_step,
        post_iter,
        advectee,
        advector,
        g_factor,
        vectmp_a,
        vectmp_b,
        vectmp_c,
        psi_extrema,
        beta,
        null_impl,
    ):
        """runs n_steps timesteps; returns the mean wall time per step
        (NaN unless n_steps > 0)"""
        started = clock()
        for step_idx in range(n_steps):
            if use_mu_coeff:
                # vectmp_c stands in for the advector during this step;
                # the caller's advector is restored once the step is done
                caller_advector = advector
                advector = vectmp_c
            for iter_idx in range(iters_per_step):
                if iter_idx == 0:
                    if use_nonoscillatory:
                        extrema_kernel(null_impl, psi_extrema, advectee)
                    if use_mu_coeff:
                        laplacian_kernel(null_impl, advector, advectee)
                        axpy_kernel(
                            *advector.field,
                            mu_coeff,
                            *advector.field,
                            *caller_advector.field,
                        )
                    first_pass_flux(null_impl, vectmp_a, advector, advectee)
                    flux = vectmp_a
                else:
                    # vectmp_a / vectmp_b alternate between holding the previous
                    # iteration's advector and the one computed in this iteration
                    if iter_idx == 1:
                        prev_adv = advector
                        new_adv = vectmp_a
                        flux = vectmp_b
                    elif iter_idx % 2:
                        prev_adv = vectmp_b
                        new_adv = vectmp_a
                        flux = vectmp_b
                    else:
                        prev_adv = vectmp_a
                        new_adv = vectmp_b
                        flux = vectmp_a
                    if iter_idx == iters_per_step - 1:
                        antidiff_final_kernel(
                            null_impl,
                            new_adv,
                            advectee,
                            prev_adv,
                            g_factor,
                        )
                    else:
                        antidiff_kernel(
                            null_impl,
                            new_adv,
                            advectee,
                            prev_adv,
                            g_factor,
                        )
                    subsequent_flux(null_impl, flux, advectee, new_adv)
                    if use_nonoscillatory:
                        beta_kernel(
                            null_impl, beta, flux, advectee, psi_extrema, g_factor
                        )
                        # note: in libmpdata++, the oscillatory advector from prev iter is used
                        correction_kernel(null_impl, new_adv, beta)
                        subsequent_flux(null_impl, flux, advectee, new_adv)
                upwind_kernel(null_impl, advectee, flux, g_factor)
                post_iter.call(flux.field, g_factor.field, step_idx, iter_idx)
            if use_mu_coeff:
                advector = caller_advector
            post_step.call(advectee.field[ARG_DATA], step_idx)
        if n_steps > 0:
            return (clock() - started) / n_steps
        return np.nan

    return step, traversals

returns (and caches) a pair: an njit-ted stepping function and the corresponding traversals