PyMPDATA_MPI.mpi_polar

polar boundary condition logic

 1""" polar boundary condition logic """
 2
 3from functools import lru_cache
 4
 5import numba
 6import numba_mpi as mpi
 7from PyMPDATA.boundary_conditions import Polar
 8from PyMPDATA.impl.enumerations import INNER, OUTER
 9
10from PyMPDATA_MPI.impl import MPIBoundaryCondition
11
12
class MPIPolar(MPIBoundaryCondition):
    """Class whose instances are to be passed in the `boundary_conditions` tuple
    to the `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` __init__ methods"""

    def __init__(self, mpi_grid, grid, mpi_dim):
        """Configure the polar condition along the MPI-decomposed dimension.

        Args:
            mpi_grid: sizes whose `mpi_dim` entry divides `grid[mpi_dim]`
                (presumably the per-worker subdomain shape - TODO confirm).
            grid: domain sizes (presumably the global grid shape).
            mpi_dim: index of the dimension split across MPI workers.

        Raises:
            ValueError: if the resulting worker pool is neither a single
                worker nor a positive even number of workers.
        """
        self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
        self.__mpi_size_one = self.worker_pool_size == 1

        # With several workers, make_get_peer pairs each rank with the rank
        # half a pool away; that pairing is one-to-one only for a positive,
        # even pool size. Checked explicitly instead of with `assert`, which
        # is stripped when Python runs with -O.
        if not self.__mpi_size_one and (
            self.worker_pool_size < 2 or self.worker_pool_size % 2 != 0
        ):
            raise ValueError(
                "MPIPolar requires a worker pool of size 1 or a positive even"
                f" size, got {self.worker_pool_size}"
            )

        super().__init__(
            size=self.worker_pool_size,
            # a single worker gets the serial PyMPDATA Polar condition as its
            # base; otherwise no base is passed (presumably the MPI machinery
            # in MPIBoundaryCondition takes over - TODO confirm)
            base=(
                Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
                if self.__mpi_size_one
                else None
            ),
            mpi_dim=mpi_dim,
        )

    @staticmethod
    def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
        """Return the Numba-compiled vector halo-filling callable.

        Delegates unchanged to `PyMPDATA.boundary_conditions.Polar.make_vector`;
        any caching (described as lru-cached) happens in that implementation.
        """
        return Polar.make_vector(
            indexers, halo, dtype, jit_flags, dimension_index
        )

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """Return a Numba-compiled `get_peer` callable for a pool of `size` ranks.

        The compiled `get_peer(_)` ignores its argument and returns
        `(peer, in_lower_half)`, where `peer = (rank + size // 2) % size` is the
        rank half a pool away from `numba_mpi.rank()` and `in_lower_half` is
        `peer < size // 2`. Results are memoised by `lru_cache` per
        `(jit_flags, size)`, so `jit_flags` must be hashable.
        """
        half = size // 2

        @numba.njit(**jit_flags)
        def get_peer(_):
            # Numba freezes closure variables (`size`, `half`) as compile-time
            # constants of this compiled function
            peer = (mpi.rank() + half) % size
            return peer, peer < half

        return get_peer
class MPIPolar(MPIBoundaryCondition):
    """Class whose instances are to be passed in the `boundary_conditions` tuple
    to the `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` __init__ methods"""

    def __init__(self, mpi_grid, grid, mpi_dim):
        """Configure the polar condition along the MPI-decomposed dimension.

        Args:
            mpi_grid: sizes whose `mpi_dim` entry divides `grid[mpi_dim]`
                (presumably the per-worker subdomain shape - TODO confirm).
            grid: domain sizes (presumably the global grid shape).
            mpi_dim: index of the dimension split across MPI workers.

        Raises:
            ValueError: if the resulting worker pool is neither a single
                worker nor a positive even number of workers.
        """
        self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
        self.__mpi_size_one = self.worker_pool_size == 1

        # With several workers, make_get_peer pairs each rank with the rank
        # half a pool away; that pairing is one-to-one only for a positive,
        # even pool size. Checked explicitly instead of with `assert`, which
        # is stripped when Python runs with -O.
        if not self.__mpi_size_one and (
            self.worker_pool_size < 2 or self.worker_pool_size % 2 != 0
        ):
            raise ValueError(
                "MPIPolar requires a worker pool of size 1 or a positive even"
                f" size, got {self.worker_pool_size}"
            )

        super().__init__(
            size=self.worker_pool_size,
            # a single worker gets the serial PyMPDATA Polar condition as its
            # base; otherwise no base is passed (presumably the MPI machinery
            # in MPIBoundaryCondition takes over - TODO confirm)
            base=(
                Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
                if self.__mpi_size_one
                else None
            ),
            mpi_dim=mpi_dim,
        )

    @staticmethod
    def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
        """Return the Numba-compiled vector halo-filling callable.

        Delegates unchanged to `PyMPDATA.boundary_conditions.Polar.make_vector`;
        any caching (described as lru-cached) happens in that implementation.
        """
        return Polar.make_vector(
            indexers, halo, dtype, jit_flags, dimension_index
        )

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """Return a Numba-compiled `get_peer` callable for a pool of `size` ranks.

        The compiled `get_peer(_)` ignores its argument and returns
        `(peer, in_lower_half)`, where `peer = (rank + size // 2) % size` is the
        rank half a pool away from `numba_mpi.rank()` and `in_lower_half` is
        `peer < size // 2`. Results are memoised by `lru_cache` per
        `(jit_flags, size)`, so `jit_flags` must be hashable.
        """
        half = size // 2

        @numba.njit(**jit_flags)
        def get_peer(_):
            # Numba freezes closure variables (`size`, `half`) as compile-time
            # constants of this compiled function
            peer = (mpi.rank() + half) % size
            return peer, peer < half

        return get_peer

Class whose instances are to be passed in the `boundary_conditions` tuple to the `PyMPDATA.scalar_field.ScalarField` and `PyMPDATA.vector_field.VectorField` `__init__` methods.

MPIPolar(mpi_grid, grid, mpi_dim)
def __init__(self, mpi_grid, grid, mpi_dim):
    """Configure the polar condition along the MPI-decomposed dimension.

    Args:
        mpi_grid: sizes whose `mpi_dim` entry divides `grid[mpi_dim]`
            (presumably the per-worker subdomain shape - TODO confirm).
        grid: domain sizes (presumably the global grid shape).
        mpi_dim: index of the dimension split across MPI workers.

    Raises:
        ValueError: if the resulting worker pool is neither a single
            worker nor a positive even number of workers.
    """
    self.worker_pool_size = grid[mpi_dim] // mpi_grid[mpi_dim]
    self.__mpi_size_one = self.worker_pool_size == 1

    # With several workers, make_get_peer pairs each rank with the rank
    # half a pool away; that pairing is one-to-one only for a positive,
    # even pool size. Checked explicitly instead of with `assert`, which
    # is stripped when Python runs with -O.
    if not self.__mpi_size_one and (
        self.worker_pool_size < 2 or self.worker_pool_size % 2 != 0
    ):
        raise ValueError(
            "MPIPolar requires a worker pool of size 1 or a positive even"
            f" size, got {self.worker_pool_size}"
        )

    super().__init__(
        size=self.worker_pool_size,
        # a single worker gets the serial PyMPDATA Polar condition as its
        # base; otherwise no base is passed (presumably the MPI machinery
        # in MPIBoundaryCondition takes over - TODO confirm)
        base=(
            Polar(grid=grid, longitude_idx=OUTER, latitude_idx=INNER)
            if self.__mpi_size_one
            else None
        ),
        mpi_dim=mpi_dim,
    )
worker_pool_size
@staticmethod
def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
@staticmethod
def make_vector(indexers, halo, dtype, jit_flags, dimension_index):
    """Return the Numba-compiled vector halo-filling callable.

    Delegates unchanged to `PyMPDATA.boundary_conditions.Polar.make_vector`;
    any caching (described as lru-cached) happens in that implementation.
    """
    return Polar.make_vector(
        indexers, halo, dtype, jit_flags, dimension_index
    )

Returns the (lru-cached) Numba-compiled vector halo-filling callable by delegating to `PyMPDATA.boundary_conditions.Polar.make_vector`.

@staticmethod
@lru_cache
def make_get_peer(jit_flags, size):
@staticmethod
@lru_cache
def make_get_peer(jit_flags, size):
    """Return a Numba-compiled `get_peer` callable for a pool of `size` ranks.

    The compiled `get_peer(_)` ignores its argument and returns
    `(peer, in_lower_half)`, where `peer = (rank + size // 2) % size` is the
    rank half a pool away from `numba_mpi.rank()` and `in_lower_half` is
    `peer < size // 2`. Results are memoised by `lru_cache` per
    `(jit_flags, size)`, so `jit_flags` must be hashable.
    """
    half = size // 2

    @numba.njit(**jit_flags)
    def get_peer(_):
        # Numba freezes closure variables (`size`, `half`) as compile-time
        # constants of this compiled function
        peer = (mpi.rank() + half) % size
        return peer, peer < half

    return get_peer

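Returns a Numba-compiled `get_peer` callable (lru-cached per `jit_flags` and `size`). It maps the current MPI rank to the rank half a worker pool away and reports whether that peer lies in the lower half of the pool.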