PyMPDATA_MPI.domain_decomposition

MPI-aware domain decomposition utilities

 1"""MPI-aware domain decomposition utilities"""
 2
 3import numpy as np
 4
 5from PyMPDATA.impl.domain_decomposition import make_subdomain
 6
 7subdomain = make_subdomain(jit_flags={})
 8
 9
def mpi_indices(*, grid, rank, size, mpi_dim):
    """Returns a mapping from rank-local indices to domain-wide indices
    (a subdomain-aware equivalent of ``np.indices``).

    ``grid`` is the shape of the whole domain. Along axis ``mpi_dim`` only
    the slab assigned to ``rank`` (out of ``size`` ranks, as computed by
    ``subdomain``) is covered, and the indices along that axis are shifted
    by the slab start so that they refer to domain-wide positions. All other
    axes span the full ``grid`` extent. The result is a float array of shape
    ``(len(grid), *local_shape)``, as produced by ``np.indices``.
    """
    start, stop = subdomain(grid[mpi_dim], rank, size)
    local_shape = list(grid)
    # Ranks lying past the end of the domain can be handed stop < start;
    # treat them as owning an empty slab (as already happens when
    # start == span) instead of letting np.indices reject a negative size.
    local_shape[mpi_dim] = max(stop - start, 0)
    xyi = np.indices(tuple(local_shape), dtype=float)
    xyi[mpi_dim] += start
    return xyi
@numba.njit(**jit_flags)
def subdomain(span, rank, size):
    """Returns the half-open ``[start, stop)`` bounds of the slab of a
    ``span``-long dimension assigned to ``rank`` out of ``size`` ranks.

    Every rank gets ``ceil(span / size)`` cells, except trailing ranks,
    which get whatever remains of the domain (possibly nothing, in which
    case ``start == stop == span``).

    Raises ValueError if ``rank >= size``.
    """
    if rank >= size:
        raise ValueError("rank >= size")

    n_max = math.ceil(span / size)
    # Clamp both ends to the domain: without clamping ``start``, ranks whose
    # nominal offset n_max * rank exceeds span would get stop < start,
    # i.e. a negative-width slab.
    start = min(n_max * rank, span)
    stop = min(start + n_max, span)
    return start, stop
def mpi_indices(*, grid, rank, size, mpi_dim):
    """Returns a mapping from rank-local indices to domain-wide indices
    (a subdomain-aware equivalent of ``np.indices``).

    ``grid`` is the shape of the whole domain. Along axis ``mpi_dim`` only
    the slab assigned to ``rank`` (out of ``size`` ranks, as computed by
    ``subdomain``) is covered, and the indices along that axis are shifted
    by the slab start so that they refer to domain-wide positions. All other
    axes span the full ``grid`` extent. The result is a float array of shape
    ``(len(grid), *local_shape)``, as produced by ``np.indices``.
    """
    start, stop = subdomain(grid[mpi_dim], rank, size)
    local_shape = list(grid)
    # Ranks lying past the end of the domain can be handed stop < start;
    # treat them as owning an empty slab (as already happens when
    # start == span) instead of letting np.indices reject a negative size.
    local_shape[mpi_dim] = max(stop - start, 0)
    xyi = np.indices(tuple(local_shape), dtype=float)
    xyi[mpi_dim] += start
    return xyi

Returns a mapping from rank-local indices to domain-wide indices (a subdomain-aware equivalent of `np.indices`).