PySDM_examples.Szumowski_et_al_1998.mpdata_2d

  1import inspect
  2from functools import cached_property
  3from threading import Thread
  4
  5import numpy as np
  6from PyMPDATA import Options, ScalarField, Solver, Stepper, VectorField
  7from PyMPDATA.boundary_conditions import Periodic
  8from PySDM_examples.Szumowski_et_al_1998.fields import (
  9    nondivergent_vector_field_2d,
 10    x_vec_coord,
 11    z_vec_coord,
 12)
 13
 14from PySDM.backends.impl_numba import conf
 15from PySDM.impl.arakawa_c import make_rhod
 16
 17
 18class MPDATA_2D:
 19    def __init__(
 20        self,
 21        *,
 22        advectees,
 23        stream_function,
 24        rhod_of_zZ,
 25        dt,
 26        grid,
 27        size,
 28        n_iters=2,
 29        infinite_gauge=True,
 30        nonoscillatory=True,
 31        third_order_terms=False
 32    ):
 33        self._grid = grid
 34        self.size = size
 35        self.dt = dt
 36        self.stream_function = stream_function
 37        self.stream_function_time_dependent = (
 38            "t" in inspect.signature(stream_function).parameters
 39        )
 40        self.asynchronous = False
 41        self.thread: (Thread, None) = None
 42        self.t = 0
 43        self.advectees = advectees
 44
 45        self._options = Options(
 46            n_iters=n_iters,
 47            infinite_gauge=infinite_gauge,
 48            nonoscillatory=nonoscillatory,
 49            third_order_terms=third_order_terms,
 50        )
 51
 52        self.g_factor = make_rhod(grid, rhod_of_zZ)
 53        self.g_factor_vec = (
 54            rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
 55            rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
 56        )
 57
 58    @cached_property
 59    def mpdatas(self):
 60        disable_threads_if_needed = {}
 61        if not conf.JIT_FLAGS["parallel"]:
 62            disable_threads_if_needed["n_threads"] = 1
 63
 64        stepper = Stepper(
 65            options=self._options,
 66            grid=self._grid,
 67            non_unit_g_factor=True,
 68            **disable_threads_if_needed
 69        )
 70
 71        advector_impl = VectorField(
 72            (
 73                np.full((self._grid[0] + 1, self._grid[1]), np.nan),
 74                np.full((self._grid[0], self._grid[1] + 1), np.nan),
 75            ),
 76            halo=self._options.n_halo,
 77            boundary_conditions=(Periodic(), Periodic()),
 78        )
 79
 80        g_factor_impl = ScalarField(
 81            self.g_factor.astype(dtype=self._options.dtype),
 82            halo=self._options.n_halo,
 83            boundary_conditions=(Periodic(), Periodic()),
 84        )
 85
 86        mpdatas = {}
 87        for k, v in self.advectees.items():
 88            advectee_impl = ScalarField(
 89                np.asarray(v, dtype=self._options.dtype),
 90                halo=self._options.n_halo,
 91                boundary_conditions=(Periodic(), Periodic()),
 92            )
 93            mpdatas[k] = Solver(
 94                stepper=stepper,
 95                advectee=advectee_impl,
 96                advector=advector_impl,
 97                g_factor=g_factor_impl,
 98            )
 99        return mpdatas
100
101    def __getitem__(self, key: str):
102        if "mpdatas" in self.__dict__:
103            return self.mpdatas[key].advectee.get()
104        return self.advectees[key]
105
106    def __call__(self, displacement):
107        if self.asynchronous:
108            self.thread = Thread(target=self.step, args=())
109            self.thread.start()
110        else:
111            self.step(displacement)
112
113    def wait(self):
114        if self.asynchronous:
115            if self.thread is not None:
116                self.thread.join()
117
118    def refresh_advector(self, displacement):
119        for mpdata in self.mpdatas.values():
120            advector = nondivergent_vector_field_2d(
121                self._grid, self.size, self.dt, self.stream_function, t=self.t
122            )
123            for d in range(len(self._grid)):
124                np.testing.assert_array_less(np.abs(advector[d]), 1)
125                mpdata.advector.get_component(d)[:] = advector[d]
126            if displacement is not None:
127                for d in range(len(self._grid)):
128                    advector[d] /= self.g_factor_vec[d]
129                displacement.upload_courant_field(advector)
130            break  # the advector field is shared
131
132    def step(self, displacement):
133        if not self.stream_function_time_dependent and self.t == 0:
134            self.refresh_advector(displacement)
135
136        self.t += 0.5 * self.dt
137        if self.stream_function_time_dependent:
138            self.refresh_advector(displacement)
139        for mpdata in self.mpdatas.values():
140            mpdata.advance(1)
141        self.t += 0.5 * self.dt
class MPDATA_2D:
 19class MPDATA_2D:
 20    def __init__(
 21        self,
 22        *,
 23        advectees,
 24        stream_function,
 25        rhod_of_zZ,
 26        dt,
 27        grid,
 28        size,
 29        n_iters=2,
 30        infinite_gauge=True,
 31        nonoscillatory=True,
 32        third_order_terms=False
 33    ):
 34        self._grid = grid
 35        self.size = size
 36        self.dt = dt
 37        self.stream_function = stream_function
 38        self.stream_function_time_dependent = (
 39            "t" in inspect.signature(stream_function).parameters
 40        )
 41        self.asynchronous = False
 42        self.thread: (Thread, None) = None
 43        self.t = 0
 44        self.advectees = advectees
 45
 46        self._options = Options(
 47            n_iters=n_iters,
 48            infinite_gauge=infinite_gauge,
 49            nonoscillatory=nonoscillatory,
 50            third_order_terms=third_order_terms,
 51        )
 52
 53        self.g_factor = make_rhod(grid, rhod_of_zZ)
 54        self.g_factor_vec = (
 55            rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
 56            rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
 57        )
 58
 59    @cached_property
 60    def mpdatas(self):
 61        disable_threads_if_needed = {}
 62        if not conf.JIT_FLAGS["parallel"]:
 63            disable_threads_if_needed["n_threads"] = 1
 64
 65        stepper = Stepper(
 66            options=self._options,
 67            grid=self._grid,
 68            non_unit_g_factor=True,
 69            **disable_threads_if_needed
 70        )
 71
 72        advector_impl = VectorField(
 73            (
 74                np.full((self._grid[0] + 1, self._grid[1]), np.nan),
 75                np.full((self._grid[0], self._grid[1] + 1), np.nan),
 76            ),
 77            halo=self._options.n_halo,
 78            boundary_conditions=(Periodic(), Periodic()),
 79        )
 80
 81        g_factor_impl = ScalarField(
 82            self.g_factor.astype(dtype=self._options.dtype),
 83            halo=self._options.n_halo,
 84            boundary_conditions=(Periodic(), Periodic()),
 85        )
 86
 87        mpdatas = {}
 88        for k, v in self.advectees.items():
 89            advectee_impl = ScalarField(
 90                np.asarray(v, dtype=self._options.dtype),
 91                halo=self._options.n_halo,
 92                boundary_conditions=(Periodic(), Periodic()),
 93            )
 94            mpdatas[k] = Solver(
 95                stepper=stepper,
 96                advectee=advectee_impl,
 97                advector=advector_impl,
 98                g_factor=g_factor_impl,
 99            )
100        return mpdatas
101
102    def __getitem__(self, key: str):
103        if "mpdatas" in self.__dict__:
104            return self.mpdatas[key].advectee.get()
105        return self.advectees[key]
106
107    def __call__(self, displacement):
108        if self.asynchronous:
109            self.thread = Thread(target=self.step, args=())
110            self.thread.start()
111        else:
112            self.step(displacement)
113
114    def wait(self):
115        if self.asynchronous:
116            if self.thread is not None:
117                self.thread.join()
118
119    def refresh_advector(self, displacement):
120        for mpdata in self.mpdatas.values():
121            advector = nondivergent_vector_field_2d(
122                self._grid, self.size, self.dt, self.stream_function, t=self.t
123            )
124            for d in range(len(self._grid)):
125                np.testing.assert_array_less(np.abs(advector[d]), 1)
126                mpdata.advector.get_component(d)[:] = advector[d]
127            if displacement is not None:
128                for d in range(len(self._grid)):
129                    advector[d] /= self.g_factor_vec[d]
130                displacement.upload_courant_field(advector)
131            break  # the advector field is shared
132
133    def step(self, displacement):
134        if not self.stream_function_time_dependent and self.t == 0:
135            self.refresh_advector(displacement)
136
137        self.t += 0.5 * self.dt
138        if self.stream_function_time_dependent:
139            self.refresh_advector(displacement)
140        for mpdata in self.mpdatas.values():
141            mpdata.advance(1)
142        self.t += 0.5 * self.dt
MPDATA_2D( *, advectees, stream_function, rhod_of_zZ, dt, grid, size, n_iters=2, infinite_gauge=True, nonoscillatory=True, third_order_terms=False)
def __init__(
    self,
    *,
    advectees,
    stream_function,
    rhod_of_zZ,
    dt,
    grid,
    size,
    n_iters=2,
    infinite_gauge=True,
    nonoscillatory=True,
    third_order_terms=False
):
    """Store the MPDATA_2D configuration; PyMPDATA solvers are built on first use.

    Args (all keyword-only):
        advectees: mapping from field name to initial 2-D values.
        stream_function: callable used to derive the advector. It is
            considered time-dependent iff its signature has a parameter
            named ``t``.
        rhod_of_zZ: callable (keyword ``zZ``) used to build the g-factor on
            the scalar grid (via ``make_rhod``) and at the vector-field
            coordinates.
        dt: timestep length.
        grid: grid dimensions, forwarded to ``make_rhod`` and the coordinate
            helpers.
        size: domain size.
        n_iters, infinite_gauge, nonoscillatory, third_order_terms:
            MPDATA variant settings forwarded to PyMPDATA ``Options``.
    """
    self._grid = grid
    self.size = size
    self.dt = dt
    self.stream_function = stream_function
    self.stream_function_time_dependent = (
        "t" in inspect.signature(stream_function).parameters
    )
    # when True, stepping is delegated to a background thread stored below
    self.asynchronous = False
    # previously annotated as the tuple ``(Thread, None)``, which is not a type
    self.thread: "Thread | None" = None
    self.t = 0
    self.advectees = advectees

    self._options = Options(
        n_iters=n_iters,
        infinite_gauge=infinite_gauge,
        nonoscillatory=nonoscillatory,
        third_order_terms=third_order_terms,
    )

    self.g_factor = make_rhod(grid, rhod_of_zZ)
    # g-factor evaluated at the x- and z-vector-field coordinates
    self.g_factor_vec = (
        rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
        rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
    )
size
dt
stream_function
stream_function_time_dependent
asynchronous
thread: Thread | None
t
advectees
g_factor
g_factor_vec
mpdatas
 59    @cached_property
 60    def mpdatas(self):
 61        disable_threads_if_needed = {}
 62        if not conf.JIT_FLAGS["parallel"]:
 63            disable_threads_if_needed["n_threads"] = 1
 64
 65        stepper = Stepper(
 66            options=self._options,
 67            grid=self._grid,
 68            non_unit_g_factor=True,
 69            **disable_threads_if_needed
 70        )
 71
 72        advector_impl = VectorField(
 73            (
 74                np.full((self._grid[0] + 1, self._grid[1]), np.nan),
 75                np.full((self._grid[0], self._grid[1] + 1), np.nan),
 76            ),
 77            halo=self._options.n_halo,
 78            boundary_conditions=(Periodic(), Periodic()),
 79        )
 80
 81        g_factor_impl = ScalarField(
 82            self.g_factor.astype(dtype=self._options.dtype),
 83            halo=self._options.n_halo,
 84            boundary_conditions=(Periodic(), Periodic()),
 85        )
 86
 87        mpdatas = {}
 88        for k, v in self.advectees.items():
 89            advectee_impl = ScalarField(
 90                np.asarray(v, dtype=self._options.dtype),
 91                halo=self._options.n_halo,
 92                boundary_conditions=(Periodic(), Periodic()),
 93            )
 94            mpdatas[k] = Solver(
 95                stepper=stepper,
 96                advectee=advectee_impl,
 97                advector=advector_impl,
 98                g_factor=g_factor_impl,
 99            )
100        return mpdatas
def wait(self):
def wait(self):
    """Join the background stepping thread, if running in asynchronous mode.

    This is a no-op in synchronous mode or when no thread has been started.
    """
    pending = self.thread if self.asynchronous else None
    if pending is not None:
        pending.join()
def refresh_advector(self, displacement):
def refresh_advector(self, displacement):
    """Recompute the Courant-number field at ``self.t`` and install it.

    The advector field is shared by all solvers, so it is written once,
    through the first solver. When ``displacement`` is not None, the field is
    divided by ``g_factor_vec`` and uploaded via
    ``displacement.upload_courant_field``. Nothing happens if there are no
    solvers.

    Raises:
        AssertionError: if any Courant number has magnitude >= 1.
    """
    shared_solver = next(iter(self.mpdatas.values()), None)
    if shared_solver is None:
        return
    courant = nondivergent_vector_field_2d(
        self._grid, self.size, self.dt, self.stream_function, t=self.t
    )
    n_dims = len(self._grid)
    for dim in range(n_dims):
        np.testing.assert_array_less(np.abs(courant[dim]), 1)
        shared_solver.advector.get_component(dim)[:] = courant[dim]
    if displacement is None:
        return
    for dim in range(n_dims):
        courant[dim] /= self.g_factor_vec[dim]
    displacement.upload_courant_field(courant)
def step(self, displacement):
def step(self, displacement):
    """Advance every advectee by one timestep ``dt``.

    A time-independent advector is set up once, before the first step
    (``t == 0``). A time-dependent one is recomputed at the half-step time.
    """
    half_dt = 0.5 * self.dt
    if self.stream_function_time_dependent:
        self.t += half_dt
        self.refresh_advector(displacement)
    else:
        if self.t == 0:
            self.refresh_advector(displacement)
        self.t += half_dt
    for solver in self.mpdatas.values():
        solver.advance(1)
    self.t += half_dt