PySDM_examples.Szumowski_et_al_1998.mpdata_2d
"""2D MPDATA advection wrapper used by the Szumowski et al. (1998) example."""

import inspect
from functools import cached_property
from threading import Thread

import numpy as np
from PyMPDATA import Options, ScalarField, Solver, Stepper, VectorField
from PyMPDATA.boundary_conditions import Periodic
from PySDM_examples.Szumowski_et_al_1998.fields import (
    nondivergent_vector_field_2d,
    x_vec_coord,
    z_vec_coord,
)

from PySDM.backends.impl_numba import conf
from PySDM.impl.arakawa_c import make_rhod


class MPDATA_2D:
    """Advect several scalar fields on a doubly periodic 2D grid with PyMPDATA.

    Every advectee gets its own ``Solver``. All solvers share one ``Stepper``,
    one advector (Courant-number) ``VectorField`` derived from
    ``stream_function``, and one g-factor ``ScalarField`` built from
    ``rhod_of_zZ``. Solvers are created lazily on first access to
    ``mpdatas``.
    """

    def __init__(
        self,
        *,
        advectees,
        stream_function,
        rhod_of_zZ,
        dt,
        grid,
        size,
        n_iters=2,
        infinite_gauge=True,
        nonoscillatory=True,
        third_order_terms=False
    ):
        """Store the configuration; no solver is built until ``mpdatas``.

        Args:
            advectees: mapping of field name to initial values (anything
                ``np.asarray`` accepts); one solver is built per entry.
            stream_function: callable passed to
                ``nondivergent_vector_field_2d``; if its signature has a
                ``t`` parameter the advector is recomputed on every step,
                otherwise only on the first one.
            rhod_of_zZ: callable taking keyword ``zZ``; used for the scalar
                g-factor (via ``make_rhod``) and for its values at the
                vector-field coordinates (presumably dry-air density versus
                normalised height -- TODO confirm).
            dt: timestep; ``step`` advances ``t`` by this amount.
            grid: cell counts in the two dimensions.
            size: forwarded to ``nondivergent_vector_field_2d`` (presumably
                the physical domain extent).
            n_iters, infinite_gauge, nonoscillatory, third_order_terms:
                forwarded to PyMPDATA ``Options``.
        """
        self._grid = grid
        self.size = size
        self.dt = dt
        self.stream_function = stream_function
        # a stream function that accepts ``t`` is re-evaluated every step
        self.stream_function_time_dependent = (
            "t" in inspect.signature(stream_function).parameters
        )
        # when True, ``__call__`` runs ``step`` on a background thread and
        # callers must ``wait()`` before reading results
        self.asynchronous = False
        self.thread: "Thread | None" = None
        self.t = 0
        self.advectees = advectees

        self._options = Options(
            n_iters=n_iters,
            infinite_gauge=infinite_gauge,
            nonoscillatory=nonoscillatory,
            third_order_terms=third_order_terms,
        )

        self.g_factor = make_rhod(grid, rhod_of_zZ)
        # per-component divisors used to turn the advector into the field
        # uploaded to ``displacement`` in ``refresh_advector``; ``[-1]``
        # picks the last coordinate array (presumably the vertical one)
        self.g_factor_vec = (
            rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
            rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
        )

    @cached_property
    def mpdatas(self):
        """Build one PyMPDATA ``Solver`` per advectee on first access.

        The ``Stepper``, the advector ``VectorField`` and the g-factor
        ``ScalarField`` are shared by all solvers. The advector starts out
        NaN-filled and is populated by ``refresh_advector``. All fields use
        periodic boundaries in both dimensions.
        """
        disable_threads_if_needed = {}
        if not conf.JIT_FLAGS["parallel"]:
            # JIT parallel mode is off: run the stepper single-threaded
            disable_threads_if_needed["n_threads"] = 1

        stepper = Stepper(
            options=self._options,
            grid=self._grid,
            non_unit_g_factor=True,
            **disable_threads_if_needed
        )

        advector_impl = VectorField(
            (
                np.full((self._grid[0] + 1, self._grid[1]), np.nan),
                np.full((self._grid[0], self._grid[1] + 1), np.nan),
            ),
            halo=self._options.n_halo,
            boundary_conditions=(Periodic(), Periodic()),
        )

        g_factor_impl = ScalarField(
            self.g_factor.astype(dtype=self._options.dtype),
            halo=self._options.n_halo,
            boundary_conditions=(Periodic(), Periodic()),
        )

        mpdatas = {}
        for k, v in self.advectees.items():
            advectee_impl = ScalarField(
                np.asarray(v, dtype=self._options.dtype),
                halo=self._options.n_halo,
                boundary_conditions=(Periodic(), Periodic()),
            )
            mpdatas[k] = Solver(
                stepper=stepper,
                advectee=advectee_impl,
                advector=advector_impl,
                g_factor=g_factor_impl,
            )
        return mpdatas

    def __getitem__(self, key: str):
        """Return field ``key``.

        Once the solvers exist, this is the solver's current state.
        Before that, it is the initial value given in ``advectees``.
        """
        if "mpdatas" in self.__dict__:
            return self.mpdatas[key].advectee.get()
        return self.advectees[key]

    def __call__(self, displacement):
        """Advance by one timestep, on a background thread if ``asynchronous``.

        ``displacement`` (may be ``None``) is forwarded to ``step``.
        """
        if self.asynchronous:
            # ``step`` requires ``displacement``; it must be forwarded to the
            # worker thread or the thread fails immediately with a TypeError
            self.thread = Thread(target=self.step, args=(displacement,))
            self.thread.start()
        else:
            self.step(displacement)

    def wait(self):
        """Block until a background step (if any) finishes; no-op otherwise."""
        if self.asynchronous:
            if self.thread is not None:
                self.thread.join()

    def refresh_advector(self, displacement):
        """Recompute the Courant field from the stream function at ``self.t``.

        The advector ``VectorField`` is shared by every solver, so it is
        written through the first one only. If ``displacement`` is not
        ``None``, it also receives the field divided component-wise by
        ``g_factor_vec``. Does nothing when there are no advectees.

        Raises:
            AssertionError: if any Courant number has magnitude >= 1.
        """
        solver = next(iter(self.mpdatas.values()), None)
        if solver is None:
            return
        advector = nondivergent_vector_field_2d(
            self._grid, self.size, self.dt, self.stream_function, t=self.t
        )
        n_dims = len(self._grid)
        # validate every component before writing any, so a failing check
        # cannot leave the shared advector partially updated
        for d in range(n_dims):
            np.testing.assert_array_less(np.abs(advector[d]), 1)
        for d in range(n_dims):
            solver.advector.get_component(d)[:] = advector[d]
        if displacement is not None:
            # the slice assignment above copied the data, so in-place
            # division here does not affect the solver's advector
            for d in range(n_dims):
                advector[d] /= self.g_factor_vec[d]
            displacement.upload_courant_field(advector)

    def step(self, displacement):
        """Advance all advectees by one timestep ``dt``.

        A time-independent advector is computed once, before the first step.
        A time-dependent one is recomputed at the mid-point ``t + dt/2`` of
        every step.
        """
        if not self.stream_function_time_dependent and self.t == 0:
            self.refresh_advector(displacement)

        self.t += 0.5 * self.dt
        if self.stream_function_time_dependent:
            self.refresh_advector(displacement)
        for mpdata in self.mpdatas.values():
            mpdata.advance(1)
        self.t += 0.5 * self.dt
class
MPDATA_2D:
class MPDATA_2D:
    """Advect several scalar fields on a doubly periodic 2D grid with PyMPDATA.

    Every advectee gets its own ``Solver``. All solvers share one ``Stepper``,
    one advector (Courant-number) ``VectorField`` derived from
    ``stream_function``, and one g-factor ``ScalarField`` built from
    ``rhod_of_zZ``. Solvers are created lazily on first access to
    ``mpdatas``.
    """

    def __init__(
        self,
        *,
        advectees,
        stream_function,
        rhod_of_zZ,
        dt,
        grid,
        size,
        n_iters=2,
        infinite_gauge=True,
        nonoscillatory=True,
        third_order_terms=False
    ):
        """Store the configuration; no solver is built until ``mpdatas``.

        Args:
            advectees: mapping of field name to initial values (anything
                ``np.asarray`` accepts); one solver is built per entry.
            stream_function: callable passed to
                ``nondivergent_vector_field_2d``; if its signature has a
                ``t`` parameter the advector is recomputed on every step,
                otherwise only on the first one.
            rhod_of_zZ: callable taking keyword ``zZ``; used for the scalar
                g-factor (via ``make_rhod``) and for its values at the
                vector-field coordinates (presumably dry-air density versus
                normalised height -- TODO confirm).
            dt: timestep; ``step`` advances ``t`` by this amount.
            grid: cell counts in the two dimensions.
            size: forwarded to ``nondivergent_vector_field_2d`` (presumably
                the physical domain extent).
            n_iters, infinite_gauge, nonoscillatory, third_order_terms:
                forwarded to PyMPDATA ``Options``.
        """
        self._grid = grid
        self.size = size
        self.dt = dt
        self.stream_function = stream_function
        # a stream function that accepts ``t`` is re-evaluated every step
        self.stream_function_time_dependent = (
            "t" in inspect.signature(stream_function).parameters
        )
        # when True, ``__call__`` runs ``step`` on a background thread and
        # callers must ``wait()`` before reading results
        self.asynchronous = False
        self.thread: "Thread | None" = None
        self.t = 0
        self.advectees = advectees

        self._options = Options(
            n_iters=n_iters,
            infinite_gauge=infinite_gauge,
            nonoscillatory=nonoscillatory,
            third_order_terms=third_order_terms,
        )

        self.g_factor = make_rhod(grid, rhod_of_zZ)
        # per-component divisors used to turn the advector into the field
        # uploaded to ``displacement`` in ``refresh_advector``; ``[-1]``
        # picks the last coordinate array (presumably the vertical one)
        self.g_factor_vec = (
            rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
            rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
        )

    @cached_property
    def mpdatas(self):
        """Build one PyMPDATA ``Solver`` per advectee on first access.

        The ``Stepper``, the advector ``VectorField`` and the g-factor
        ``ScalarField`` are shared by all solvers. The advector starts out
        NaN-filled and is populated by ``refresh_advector``. All fields use
        periodic boundaries in both dimensions.
        """
        disable_threads_if_needed = {}
        if not conf.JIT_FLAGS["parallel"]:
            # JIT parallel mode is off: run the stepper single-threaded
            disable_threads_if_needed["n_threads"] = 1

        stepper = Stepper(
            options=self._options,
            grid=self._grid,
            non_unit_g_factor=True,
            **disable_threads_if_needed
        )

        advector_impl = VectorField(
            (
                np.full((self._grid[0] + 1, self._grid[1]), np.nan),
                np.full((self._grid[0], self._grid[1] + 1), np.nan),
            ),
            halo=self._options.n_halo,
            boundary_conditions=(Periodic(), Periodic()),
        )

        g_factor_impl = ScalarField(
            self.g_factor.astype(dtype=self._options.dtype),
            halo=self._options.n_halo,
            boundary_conditions=(Periodic(), Periodic()),
        )

        mpdatas = {}
        for k, v in self.advectees.items():
            advectee_impl = ScalarField(
                np.asarray(v, dtype=self._options.dtype),
                halo=self._options.n_halo,
                boundary_conditions=(Periodic(), Periodic()),
            )
            mpdatas[k] = Solver(
                stepper=stepper,
                advectee=advectee_impl,
                advector=advector_impl,
                g_factor=g_factor_impl,
            )
        return mpdatas

    def __getitem__(self, key: str):
        """Return field ``key``.

        Once the solvers exist, this is the solver's current state.
        Before that, it is the initial value given in ``advectees``.
        """
        if "mpdatas" in self.__dict__:
            return self.mpdatas[key].advectee.get()
        return self.advectees[key]

    def __call__(self, displacement):
        """Advance by one timestep, on a background thread if ``asynchronous``.

        ``displacement`` (may be ``None``) is forwarded to ``step``.
        """
        if self.asynchronous:
            # ``step`` requires ``displacement``; it must be forwarded to the
            # worker thread or the thread fails immediately with a TypeError
            self.thread = Thread(target=self.step, args=(displacement,))
            self.thread.start()
        else:
            self.step(displacement)

    def wait(self):
        """Block until a background step (if any) finishes; no-op otherwise."""
        if self.asynchronous:
            if self.thread is not None:
                self.thread.join()

    def refresh_advector(self, displacement):
        """Recompute the Courant field from the stream function at ``self.t``.

        The advector ``VectorField`` is shared by every solver, so it is
        written through the first one only. If ``displacement`` is not
        ``None``, it also receives the field divided component-wise by
        ``g_factor_vec``. Does nothing when there are no advectees.

        Raises:
            AssertionError: if any Courant number has magnitude >= 1.
        """
        solver = next(iter(self.mpdatas.values()), None)
        if solver is None:
            return
        advector = nondivergent_vector_field_2d(
            self._grid, self.size, self.dt, self.stream_function, t=self.t
        )
        n_dims = len(self._grid)
        # validate every component before writing any, so a failing check
        # cannot leave the shared advector partially updated
        for d in range(n_dims):
            np.testing.assert_array_less(np.abs(advector[d]), 1)
        for d in range(n_dims):
            solver.advector.get_component(d)[:] = advector[d]
        if displacement is not None:
            # the slice assignment above copied the data, so in-place
            # division here does not affect the solver's advector
            for d in range(n_dims):
                advector[d] /= self.g_factor_vec[d]
            displacement.upload_courant_field(advector)

    def step(self, displacement):
        """Advance all advectees by one timestep ``dt``.

        A time-independent advector is computed once, before the first step.
        A time-dependent one is recomputed at the mid-point ``t + dt/2`` of
        every step.
        """
        if not self.stream_function_time_dependent and self.t == 0:
            self.refresh_advector(displacement)

        self.t += 0.5 * self.dt
        if self.stream_function_time_dependent:
            self.refresh_advector(displacement)
        for mpdata in self.mpdatas.values():
            mpdata.advance(1)
        self.t += 0.5 * self.dt
MPDATA_2D( *, advectees, stream_function, rhod_of_zZ, dt, grid, size, n_iters=2, infinite_gauge=True, nonoscillatory=True, third_order_terms=False)
def __init__(
    self,
    *,
    advectees,
    stream_function,
    rhod_of_zZ,
    dt,
    grid,
    size,
    n_iters=2,
    infinite_gauge=True,
    nonoscillatory=True,
    third_order_terms=False
):
    """Store the configuration; no solver is built until ``mpdatas``.

    Args:
        advectees: mapping of field name to initial values (anything
            ``np.asarray`` accepts); one solver is built per entry.
        stream_function: callable passed to ``nondivergent_vector_field_2d``;
            if its signature has a ``t`` parameter the advector is
            recomputed on every step, otherwise only on the first one.
        rhod_of_zZ: callable taking keyword ``zZ``; used for the scalar
            g-factor (via ``make_rhod``) and for its values at the
            vector-field coordinates (presumably dry-air density versus
            normalised height -- TODO confirm).
        dt: timestep; ``step`` advances ``t`` by this amount.
        grid: cell counts in the two dimensions.
        size: forwarded to ``nondivergent_vector_field_2d`` (presumably the
            physical domain extent).
        n_iters, infinite_gauge, nonoscillatory, third_order_terms:
            forwarded to PyMPDATA ``Options``.
    """
    self._grid = grid
    self.size = size
    self.dt = dt
    self.stream_function = stream_function
    # a stream function that accepts ``t`` is re-evaluated every step
    self.stream_function_time_dependent = (
        "t" in inspect.signature(stream_function).parameters
    )
    # when True, ``__call__`` runs ``step`` on a background thread and
    # callers must ``wait()`` before reading results
    self.asynchronous = False
    # a proper optional type; the former ``(Thread, None)`` was a tuple,
    # not a valid annotation
    self.thread: "Thread | None" = None
    self.t = 0
    self.advectees = advectees

    self._options = Options(
        n_iters=n_iters,
        infinite_gauge=infinite_gauge,
        nonoscillatory=nonoscillatory,
        third_order_terms=third_order_terms,
    )

    self.g_factor = make_rhod(grid, rhod_of_zZ)
    # per-component divisors used to turn the advector into the field
    # uploaded to ``displacement`` in ``refresh_advector``; ``[-1]`` picks
    # the last coordinate array (presumably the vertical one)
    self.g_factor_vec = (
        rhod_of_zZ(zZ=x_vec_coord(grid)[-1]),
        rhod_of_zZ(zZ=z_vec_coord(grid)[-1]),
    )
mpdatas
@cached_property
def mpdatas(self):
    """Lazily create a PyMPDATA ``Solver`` for each entry of ``advectees``.

    One ``Stepper``, one advector ``VectorField`` (NaN-filled until
    ``refresh_advector`` writes it) and one g-factor ``ScalarField`` are
    shared by every solver. Each field has periodic boundary conditions in
    both dimensions. Returns a dict keyed like ``advectees``.
    """
    options = self._options
    halo = options.n_halo
    n0, n1 = self._grid[0], self._grid[1]

    def periodic():
        # a fresh pair of boundary-condition objects for every field
        return (Periodic(), Periodic())

    stepper_kwargs = {
        "options": options,
        "grid": self._grid,
        "non_unit_g_factor": True,
    }
    if not conf.JIT_FLAGS["parallel"]:
        # JIT parallel mode is off: force a single thread
        stepper_kwargs["n_threads"] = 1
    stepper = Stepper(**stepper_kwargs)

    shared_advector = VectorField(
        (
            np.full((n0 + 1, n1), np.nan),
            np.full((n0, n1 + 1), np.nan),
        ),
        halo=halo,
        boundary_conditions=periodic(),
    )
    shared_g_factor = ScalarField(
        self.g_factor.astype(dtype=options.dtype),
        halo=halo,
        boundary_conditions=periodic(),
    )

    def build_solver(initial):
        field = ScalarField(
            np.asarray(initial, dtype=options.dtype),
            halo=halo,
            boundary_conditions=periodic(),
        )
        return Solver(
            stepper=stepper,
            advectee=field,
            advector=shared_advector,
            g_factor=shared_g_factor,
        )

    return {name: build_solver(initial) for name, initial in self.advectees.items()}
def
refresh_advector(self, displacement):
def refresh_advector(self, displacement):
    """Recompute the Courant field from the stream function at ``self.t``.

    The advector ``VectorField`` is shared by every solver, so it is written
    through the first one only. If ``displacement`` is not ``None``, it also
    receives the field divided component-wise by ``g_factor_vec``. Does
    nothing when there are no advectees.

    Raises:
        AssertionError: if any Courant number has magnitude >= 1.
    """
    solver = next(iter(self.mpdatas.values()), None)
    if solver is None:
        return
    advector = nondivergent_vector_field_2d(
        self._grid, self.size, self.dt, self.stream_function, t=self.t
    )
    n_dims = len(self._grid)
    # validate every component before writing any, so a failing check
    # cannot leave the shared advector partially updated
    for d in range(n_dims):
        np.testing.assert_array_less(np.abs(advector[d]), 1)
    for d in range(n_dims):
        solver.advector.get_component(d)[:] = advector[d]
    if displacement is not None:
        # the slice assignment above copied the data, so in-place division
        # here does not affect the solver's advector
        for d in range(n_dims):
            advector[d] /= self.g_factor_vec[d]
        displacement.upload_courant_field(advector)
def
step(self, displacement):
def step(self, displacement):
    """Move every advectee forward by one timestep ``dt``.

    If the stream function ignores time, the advector is computed once,
    at ``t == 0`` before the first step. Otherwise it is recomputed every
    step at the half-step time ``t + dt/2``. ``t`` ends up advanced by
    ``dt``.
    """
    half_dt = 0.5 * self.dt
    if self.stream_function_time_dependent:
        self.t += half_dt
        self.refresh_advector(displacement)
    else:
        if self.t == 0:
            self.refresh_advector(displacement)
        self.t += half_dt
    for solver in self.mpdatas.values():
        solver.advance(1)
    self.t += half_dt