PyMPDATA_MPI.mpi_polar
Polar boundary condition logic.
"""Polar boundary condition logic."""

from functools import lru_cache

import numba
import numba_mpi as mpi
from PyMPDATA.boundary_conditions import Polar
from PyMPDATA.impl.enumerations import INNER, OUTER

from PyMPDATA_MPI.impl import MPIBoundaryCondition


class MPIPolar(MPIBoundaryCondition):
    """Class whose instances are to be passed in the `boundary_conditions` tuple
    to the `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` `__init__` methods."""

    def __init__(self, mpi_grid, grid, mpi_dim):
        """Set up the polar boundary condition along ``mpi_dim``.

        Parameters
        ----------
        mpi_grid : sequence of int
            Shape indexed by ``mpi_dim``; ``grid[mpi_dim] // mpi_grid[mpi_dim]``
            is taken as the worker pool size (presumably the per-worker
            subdomain shape -- TODO confirm against callers).
        grid : sequence of int
            Shape indexed by ``mpi_dim`` (presumably the global domain shape);
            also handed to ``Polar`` when the pool has a single worker.
        mpi_dim : int
            Index of the dimension along which the worker pool is laid out.

        Raises
        ------
        ValueError
            If more than one worker is used and the worker count is odd.
        """
        self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
        self.__mpi_size_one = self.worker_pool_size == 1

        # make_get_peer pairs each rank with the rank size // 2 away; that
        # pairing is one-to-one (each subdomain has exactly one peer) only for
        # an even pool size. An explicit raise, unlike ``assert``, keeps this
        # check active under ``python -O``.
        if not self.__mpi_size_one and self.worker_pool_size % 2 != 0:
            raise ValueError(
                "MPIPolar requires an even number of workers along mpi_dim"
                f" (got {self.worker_pool_size})"
            )

        super().__init__(
            size=self.worker_pool_size,
            # a plain Polar condition is supplied as ``base`` only for a
            # single-worker pool; otherwise ``base`` is None
            base=(
                Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
                if self.__mpi_size_one
                else None
            ),
            mpi_dim=mpi_dim,
        )

    @staticmethod
    def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
        """Return the vector halo-filling callable.

        All arguments are forwarded unchanged to ``Polar.make_vector``; any
        caching or Numba compilation happens there.
        """
        return Polar.make_vector(indexers, halo, dtype, jit_flags, dimension_index)

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """Return a Numba-compiled peer lookup, cached per ``(jit_flags, size)``.

        The returned callable ignores its argument and returns
        ``(peer, peer < size // 2)``, where ``peer`` is the calling MPI rank
        shifted by ``size // 2`` modulo ``size``. ``jit_flags`` must be
        hashable because of ``lru_cache``.
        """

        @numba.njit(**jit_flags)
        def get_peer(_):
            rank = mpi.rank()
            peer = (rank + size // 2) % size
            return peer, peer < size // 2

        return get_peer
class MPIPolar(MPIBoundaryCondition):
    """Class whose instances are to be passed in the `boundary_conditions` tuple
    to the `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` `__init__` methods."""

    def __init__(self, mpi_grid, grid, mpi_dim):
        """Set up the polar boundary condition along ``mpi_dim``.

        Parameters
        ----------
        mpi_grid : sequence of int
            Shape indexed by ``mpi_dim``; ``grid[mpi_dim] // mpi_grid[mpi_dim]``
            is taken as the worker pool size (presumably the per-worker
            subdomain shape -- TODO confirm against callers).
        grid : sequence of int
            Shape indexed by ``mpi_dim`` (presumably the global domain shape);
            also handed to ``Polar`` when the pool has a single worker.
        mpi_dim : int
            Index of the dimension along which the worker pool is laid out.

        Raises
        ------
        ValueError
            If more than one worker is used and the worker count is odd.
        """
        self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
        self.__mpi_size_one = self.worker_pool_size == 1

        # make_get_peer pairs each rank with the rank size // 2 away; that
        # pairing is one-to-one (each subdomain has exactly one peer) only for
        # an even pool size. An explicit raise, unlike ``assert``, keeps this
        # check active under ``python -O``.
        if not self.__mpi_size_one and self.worker_pool_size % 2 != 0:
            raise ValueError(
                "MPIPolar requires an even number of workers along mpi_dim"
                f" (got {self.worker_pool_size})"
            )

        super().__init__(
            size=self.worker_pool_size,
            # a plain Polar condition is supplied as ``base`` only for a
            # single-worker pool; otherwise ``base`` is None
            base=(
                Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
                if self.__mpi_size_one
                else None
            ),
            mpi_dim=mpi_dim,
        )

    @staticmethod
    def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
        """Return the vector halo-filling callable.

        All arguments are forwarded unchanged to ``Polar.make_vector``; any
        caching or Numba compilation happens there.
        """
        return Polar.make_vector(indexers, halo, dtype, jit_flags, dimension_index)

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """Return a Numba-compiled peer lookup, cached per ``(jit_flags, size)``.

        The returned callable ignores its argument and returns
        ``(peer, peer < size // 2)``, where ``peer`` is the calling MPI rank
        shifted by ``size // 2`` modulo ``size``. ``jit_flags`` must be
        hashable because of ``lru_cache``.
        """

        @numba.njit(**jit_flags)
        def get_peer(_):
            rank = mpi.rank()
            peer = (rank + size // 2) % size
            return peer, peer < size // 2

        return get_peer
Class whose instances are to be passed in the `boundary_conditions` tuple to the `PyMPDATA.scalar_field.ScalarField` and `PyMPDATA.vector_field.VectorField` `__init__` methods.
MPIPolar(mpi_grid, grid, mpi_dim)
def __init__(self, mpi_grid, grid, mpi_dim):
    """Set up the polar boundary condition along ``mpi_dim``.

    Parameters
    ----------
    mpi_grid : sequence of int
        Shape indexed by ``mpi_dim``; ``grid[mpi_dim] // mpi_grid[mpi_dim]``
        is taken as the worker pool size (presumably the per-worker
        subdomain shape -- TODO confirm against callers).
    grid : sequence of int
        Shape indexed by ``mpi_dim`` (presumably the global domain shape);
        also handed to ``Polar`` when the pool has a single worker.
    mpi_dim : int
        Index of the dimension along which the worker pool is laid out.

    Raises
    ------
    ValueError
        If more than one worker is used and the worker count is odd.
    """
    self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
    self.__mpi_size_one = self.worker_pool_size == 1

    # make_get_peer pairs each rank with the rank size // 2 away; that
    # pairing is one-to-one (each subdomain has exactly one peer) only for
    # an even pool size. An explicit raise, unlike ``assert``, keeps this
    # check active under ``python -O``.
    if not self.__mpi_size_one and self.worker_pool_size % 2 != 0:
        raise ValueError(
            "MPIPolar requires an even number of workers along mpi_dim"
            f" (got {self.worker_pool_size})"
        )

    super().__init__(
        size=self.worker_pool_size,
        # a plain Polar condition is supplied as ``base`` only for a
        # single-worker pool; otherwise ``base`` is None
        base=(
            Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
            if self.__mpi_size_one
            else None
        ),
        mpi_dim=mpi_dim,
    )
@staticmethod
def
make_vector(indexers, halo, dtype, jit_flags, dimension_index):
@staticmethod
def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
    """Return the vector-field halo-filling callable.

    There is no MPI-specific logic here: every argument is passed on, in
    order, to ``Polar.make_vector`` and its result is handed back as-is
    (caching and compilation, if any, are done by ``Polar.make_vector``).
    """
    polar_vector_filler = Polar.make_vector(
        indexers,
        halo,
        dtype,
        jit_flags,
        dimension_index,
    )
    return polar_vector_filler
Returns the (LRU-cached) Numba-compiled vector halo-filling callable.
@staticmethod
@lru_cache
def
make_get_peer(jit_flags, size):
@staticmethod
@lru_cache
def make_get_peer(jit_flags, size):
    """Build, once per ``(jit_flags, size)`` pair, a jitted peer lookup.

    ``jit_flags`` is unpacked into ``numba.njit`` and, like ``size``, must be
    hashable because results are memoised with ``lru_cache``.

    The returned function ignores its single argument and yields
    ``(peer, peer < size // 2)``: ``peer`` is the calling MPI rank offset by
    half the pool size, wrapped modulo ``size``.
    """
    half_pool = size // 2

    @numba.njit(**jit_flags)
    def get_peer(_):
        partner = (mpi.rank() + half_pool) % size
        return partner, partner < half_pool

    return get_peer
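Returns a Numba-compiled callable that gives the calling rank's peer, which is the rank `size // 2` positions away (modulo `size`). It also returns a flag telling whether that peer lies in the first half of the ranks.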