PyMPDATA.boundary_conditions.polar
polar boundary condition for use with spherical coordinates
"""Polar boundary condition for use with spherical coordinates."""

from functools import lru_cache

import numba

from PyMPDATA.impl.enumerations import ARG_FOCUS, SIGN_LEFT, SIGN_RIGHT
from PyMPDATA.impl.traversals_common import (
    make_fill_halos_loop,
    make_fill_halos_loop_vector,
)


class Polar:
    """Polar boundary condition for a grid with a longitude and a latitude dimension.

    Instances are meant to be passed in the ``boundary_conditions`` tuple
    given to the ``ScalarField`` and ``VectorField`` ``__init__`` methods.
    The vector halo fillers are currently stubs returning 0 (TODO #120).
    """

    def __init__(self, grid, longitude_idx, latitude_idx):
        """
        grid: grid sizes indexed by dimension
        longitude_idx: position of the longitude dimension within ``grid``
        latitude_idx: position of the latitude dimension within ``grid``
        """
        # the offsets built in make_scalar are multiplied by ``sign`` and
        # rely on these particular sign values
        assert SIGN_RIGHT == -1
        assert SIGN_LEFT == +1

        self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
        # nlon_half must be exactly half of the longitude count
        assert self.nlon % 2 == 0
        self.nlon_half = self.nlon // 2

        self.lon_idx, self.lat_idx = longitude_idx, latitude_idx

    def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
        """Return the Numba-compiled scalar halo-filling callable.

        For a halo cell at latitude index ``lat`` the first offset is
        ``2 * (halo - 1 - lat) + 1`` when ``lat <= halo - 1`` and
        ``2 * (lat - (nlat + halo)) + 1`` otherwise, multiplied by ``sign``.
        The second offset is ``-(nlon // 2)`` when the longitude index
        exceeds ``nlon // 2`` and ``+(nlon // 2)`` otherwise. Both offsets are
        passed to ``indexers.ats[dimension_index]``; presumably this reads
        the mirror cell on the opposite side of the pole (TODO confirm
        against ``ats`` semantics).
        """
        lon_axis = self.lon_idx
        lat_axis = self.lat_idx
        half_nlon = self.nlon_half
        # assuming interior cells start at index ``halo``, these are the
        # halo indices adjacent to the left and right latitude edges
        last_left_halo = halo - 1
        first_right_halo = self.nlat + halo
        shifted_read = indexers.ats[dimension_index]
        store = indexers.set

        @numba.njit(**jit_flags)
        def fill_halos(psi, _, sign):
            focus = psi[ARG_FOCUS]
            lat = focus[lat_axis]
            if lat > last_left_halo:
                distance = lat - first_right_halo
            else:
                distance = last_left_halo - lat
            # shift by half of the longitudes: backwards past the midpoint,
            # forwards otherwise
            lon_shift = -half_nlon if focus[lon_axis] > half_nlon else half_nlon
            return shifted_read(*psi, sign * (2 * distance + 1), lon_shift)

        return make_fill_halos_loop(jit_flags, store, fill_halos)

    @staticmethod
    def make_vector(indexers, _, __, jit_flags, dimension_index):
        """Return the (lru-cached) Numba-compiled vector halo-filling callable.

        Delegates to the ``lru_cache``-decorated ``_make_vector_polar``.
        """
        atv, set_value = indexers.atv, indexers.set
        return _make_vector_polar(atv, set_value, jit_flags, dimension_index)


@lru_cache()
def _make_vector_polar(_atv, set_value, jit_flags, dimension_index):
    """Build and cache the vector halo-filling loop used by ``Polar``.

    Both per-cell fillers are stubs that return 0 (TODO #120). ``_atv`` is
    accepted but not used yet. All arguments form the ``lru_cache`` key, so
    they (including ``jit_flags``) must be hashable.
    """

    @numba.njit(**jit_flags)
    def fill_halos_parallel(_1, _2, _3):
        return 0  # TODO #120

    @numba.njit(**jit_flags)
    def fill_halos_normal(_1, _2, _3, _4):
        return 0  # TODO #120

    return make_fill_halos_loop_vector(
        jit_flags,
        set_value,
        fill_halos_parallel,
        fill_halos_normal,
        dimension_index,
    )
class
Polar:
class Polar:
    """Polar boundary condition for a grid with a longitude and a latitude dimension.

    Instances are meant to be passed in the ``boundary_conditions`` tuple
    given to the ``ScalarField`` and ``VectorField`` ``__init__`` methods.
    The vector halo fillers are currently stubs returning 0 (TODO #120).
    """

    def __init__(self, grid, longitude_idx, latitude_idx):
        """
        grid: grid sizes indexed by dimension
        longitude_idx: position of the longitude dimension within ``grid``
        latitude_idx: position of the latitude dimension within ``grid``
        """
        # the offsets built in make_scalar are multiplied by ``sign`` and
        # rely on these particular sign values
        assert SIGN_RIGHT == -1
        assert SIGN_LEFT == +1

        self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
        # nlon_half must be exactly half of the longitude count
        assert self.nlon % 2 == 0
        self.nlon_half = self.nlon // 2

        self.lon_idx, self.lat_idx = longitude_idx, latitude_idx

    def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
        """Return the Numba-compiled scalar halo-filling callable.

        For a halo cell at latitude index ``lat`` the first offset is
        ``2 * (halo - 1 - lat) + 1`` when ``lat <= halo - 1`` and
        ``2 * (lat - (nlat + halo)) + 1`` otherwise, multiplied by ``sign``.
        The second offset is ``-(nlon // 2)`` when the longitude index
        exceeds ``nlon // 2`` and ``+(nlon // 2)`` otherwise. Both offsets are
        passed to ``indexers.ats[dimension_index]``; presumably this reads
        the mirror cell on the opposite side of the pole (TODO confirm
        against ``ats`` semantics).
        """
        lon_axis = self.lon_idx
        lat_axis = self.lat_idx
        half_nlon = self.nlon_half
        # assuming interior cells start at index ``halo``, these are the
        # halo indices adjacent to the left and right latitude edges
        last_left_halo = halo - 1
        first_right_halo = self.nlat + halo
        shifted_read = indexers.ats[dimension_index]
        store = indexers.set

        @numba.njit(**jit_flags)
        def fill_halos(psi, _, sign):
            focus = psi[ARG_FOCUS]
            lat = focus[lat_axis]
            if lat > last_left_halo:
                distance = lat - first_right_halo
            else:
                distance = last_left_halo - lat
            # shift by half of the longitudes: backwards past the midpoint,
            # forwards otherwise
            lon_shift = -half_nlon if focus[lon_axis] > half_nlon else half_nlon
            return shifted_read(*psi, sign * (2 * distance + 1), lon_shift)

        return make_fill_halos_loop(jit_flags, store, fill_halos)

    @staticmethod
    def make_vector(indexers, _, __, jit_flags, dimension_index):
        """Return the (lru-cached) Numba-compiled vector halo-filling callable.

        Delegates to the ``lru_cache``-decorated ``_make_vector_polar``.
        """
        atv, set_value = indexers.atv, indexers.set
        return _make_vector_polar(atv, set_value, jit_flags, dimension_index)
class whose instances are to be passed in the boundary_conditions tuple to the
ScalarField
and VectorField
__init__ methods
Polar(grid, longitude_idx, latitude_idx)
def __init__(self, grid, longitude_idx, latitude_idx):
    """
    grid: grid sizes indexed by dimension
    longitude_idx: position of the longitude dimension within ``grid``
    latitude_idx: position of the latitude dimension within ``grid``
    """
    # the offsets built in make_scalar are multiplied by ``sign`` and
    # rely on these particular sign values
    assert SIGN_RIGHT == -1
    assert SIGN_LEFT == +1

    self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
    # nlon_half must be exactly half of the longitude count
    assert self.nlon % 2 == 0
    self.nlon_half = self.nlon // 2

    self.lon_idx, self.lat_idx = longitude_idx, latitude_idx
def
make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
    """Return the Numba-compiled scalar halo-filling callable.

    For a halo cell at latitude index ``lat`` the first offset is
    ``2 * (halo - 1 - lat) + 1`` when ``lat <= halo - 1`` and
    ``2 * (lat - (nlat + halo)) + 1`` otherwise, multiplied by ``sign``.
    The second offset is ``-(nlon // 2)`` when the longitude index
    exceeds ``nlon // 2`` and ``+(nlon // 2)`` otherwise. Both offsets are
    passed to ``indexers.ats[dimension_index]``; presumably this reads
    the mirror cell on the opposite side of the pole (TODO confirm
    against ``ats`` semantics).
    """
    lon_axis = self.lon_idx
    lat_axis = self.lat_idx
    half_nlon = self.nlon_half
    # assuming interior cells start at index ``halo``, these are the
    # halo indices adjacent to the left and right latitude edges
    last_left_halo = halo - 1
    first_right_halo = self.nlat + halo
    shifted_read = indexers.ats[dimension_index]
    store = indexers.set

    @numba.njit(**jit_flags)
    def fill_halos(psi, _, sign):
        focus = psi[ARG_FOCUS]
        lat = focus[lat_axis]
        if lat > last_left_halo:
            distance = lat - first_right_halo
        else:
            distance = last_left_halo - lat
        # shift by half of the longitudes: backwards past the midpoint,
        # forwards otherwise
        lon_shift = -half_nlon if focus[lon_axis] > half_nlon else half_nlon
        return shifted_read(*psi, sign * (2 * distance + 1), lon_shift)

    return make_fill_halos_loop(jit_flags, store, fill_halos)
returns (lru-cached) Numba-compiled scalar halo-filling callable
@staticmethod
def
make_vector(indexers, _, __, jit_flags, dimension_index):
@staticmethod
def make_vector(indexers, _, __, jit_flags, dimension_index):
    """Return the (lru-cached) Numba-compiled vector halo-filling callable.

    Delegates to the ``lru_cache``-decorated ``_make_vector_polar``.
    """
    atv, set_value = indexers.atv, indexers.set
    return _make_vector_polar(atv, set_value, jit_flags, dimension_index)
returns (lru-cached) Numba-compiled vector halo-filling callable