PyMPDATA_MPI.mpi_periodic

periodic/cyclic boundary condition logic

 1""" periodic/cyclic boundary condition logic """
 2
 3from functools import lru_cache
 4
 5import numba
 6import numba_mpi as mpi
 7from PyMPDATA.boundary_conditions import Periodic
 8from PyMPDATA.impl.enumerations import SIGN_LEFT, SIGN_RIGHT
 9
10from PyMPDATA_MPI.impl import MPIBoundaryCondition
11from PyMPDATA_MPI.impl.boundary_condition_commons import make_vector_boundary_condition
12
13
class MPIPeriodic(MPIBoundaryCondition):
    """boundary condition class whose instances are to be passed in the
    boundary_conditions tuple to the
    `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` __init__ methods"""

    def __init__(self, size, mpi_dim):
        # `size` is taken as an argument instead of being read from mpi.size()
        # because non-default MPI communicators are not supported, see
        # https://github.com/numba-mpi/numba-mpi/issues/64

        # get_peers() below indexes a 3-tuple directly with the sign value,
        # which only works for exactly these constants
        assert (SIGN_RIGHT, SIGN_LEFT) == (-1, +1)

        super().__init__(size=size, base=Periodic, mpi_dim=mpi_dim)

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
        """returns (lru-cached) Numba-compiled vector halo-filling callable;
        with a single worker the plain `Periodic` implementation is used"""
        if self.worker_pool_size != 1:
            get_peer = self.make_get_peer(jit_flags, self.worker_pool_size)
            return make_vector_boundary_condition(
                indexers,
                halo,
                jit_flags,
                dimension_index,
                dtype,
                get_peer,
                self.mpi_dim,
            )
        return Periodic.make_vector(indexers, halo, dtype, jit_flags, dimension_index)

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """returns (lru-cached) numba-compiled callable which maps a side sign
        to a (peer rank, is-left-side flag) pair."""

        @numba.njit(**jit_flags)
        def get_peers(sign):
            this_rank = mpi.rank()
            # slot 0 is a placeholder; index +1 (SIGN_LEFT) picks the left
            # neighbour, index -1 (SIGN_RIGHT) picks the last entry, i.e. the
            # right neighbour; the modulo wraps ranks around periodically
            neighbours = (-1, (this_rank - 1) % size, (this_rank + 1) % size)
            return neighbours[sign], SIGN_LEFT == sign

        return get_peers
class MPIPeriodic(MPIBoundaryCondition):
    """boundary condition class whose instances are to be passed in the
    boundary_conditions tuple to the
    `PyMPDATA.scalar_field.ScalarField` and
    `PyMPDATA.vector_field.VectorField` __init__ methods"""

    def __init__(self, size, mpi_dim):
        # `size` is taken as an argument instead of being read from mpi.size()
        # because non-default MPI communicators are not supported, see
        # https://github.com/numba-mpi/numba-mpi/issues/64

        # get_peers() below indexes a 3-tuple directly with the sign value,
        # which only works for exactly these constants
        assert (SIGN_RIGHT, SIGN_LEFT) == (-1, +1)

        super().__init__(size=size, base=Periodic, mpi_dim=mpi_dim)

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
        """returns (lru-cached) Numba-compiled vector halo-filling callable;
        with a single worker the plain `Periodic` implementation is used"""
        if self.worker_pool_size != 1:
            get_peer = self.make_get_peer(jit_flags, self.worker_pool_size)
            return make_vector_boundary_condition(
                indexers,
                halo,
                jit_flags,
                dimension_index,
                dtype,
                get_peer,
                self.mpi_dim,
            )
        return Periodic.make_vector(indexers, halo, dtype, jit_flags, dimension_index)

    @staticmethod
    @lru_cache
    def make_get_peer(jit_flags, size):
        """returns (lru-cached) numba-compiled callable which maps a side sign
        to a (peer rank, is-left-side flag) pair."""

        @numba.njit(**jit_flags)
        def get_peers(sign):
            this_rank = mpi.rank()
            # slot 0 is a placeholder; index +1 (SIGN_LEFT) picks the left
            # neighbour, index -1 (SIGN_RIGHT) picks the last entry, i.e. the
            # right neighbour; the modulo wraps ranks around periodically
            neighbours = (-1, (this_rank - 1) % size, (this_rank + 1) % size)
            return neighbours[sign], SIGN_LEFT == sign

        return get_peers

Class whose instances are to be passed in the boundary_conditions tuple to the PyMPDATA.scalar_field.ScalarField and PyMPDATA.vector_field.VectorField __init__ methods.

MPIPeriodic(size, mpi_dim)
20    def __init__(self, size, mpi_dim):
21        # passing size insead of using mpi.size() because lack of support for non-default
22        # MPI communicators. https://github.com/numba-mpi/numba-mpi/issues/64
23        assert SIGN_RIGHT == -1
24        assert SIGN_LEFT == +1
25
26        super().__init__(size=size, base=Periodic, mpi_dim=mpi_dim)
def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
29    def make_vector(self, indexers, halo, dtype, jit_flags, dimension_index):
30        """returns (lru-cached) Numba-compiled vector halo-filling callable"""
31        if self.worker_pool_size == 1:
32            return Periodic.make_vector(
33                indexers, halo, dtype, jit_flags, dimension_index
34            )
35        return make_vector_boundary_condition(
36            indexers,
37            halo,
38            jit_flags,
39            dimension_index,
40            dtype,
41            self.make_get_peer(jit_flags, self.worker_pool_size),
42            self.mpi_dim,
43        )

returns (lru-cached) Numba-compiled vector halo-filling callable

@staticmethod
@lru_cache
def make_get_peer(jit_flags, size):
45    @staticmethod
46    @lru_cache
47    def make_get_peer(jit_flags, size):
48        """returns (lru-cached) numba-compiled callable."""
49
50        @numba.njit(**jit_flags)
51        def get_peers(sign):
52            rank = mpi.rank()
53            left_peer = (rank - 1) % size
54            right_peer = (rank + 1) % size
55            peers = (-1, left_peer, right_peer)
56            return peers[sign], SIGN_LEFT == sign
57
58        return get_peers

returns (lru-cached) numba-compiled callable.