PyMPDATA_MPI.domain_decomposition

MPI-aware domain decomposition utilities

 1""" MPI-aware domain decomposition utilities """
 2
 3import numpy as np
 4from PyMPDATA.impl.domain_decomposition import make_subdomain
 5
 6subdomain = make_subdomain(jit_flags={})
 7
 8
 9def mpi_indices(*, grid, rank, size, mpi_dim):
10    """returns a mapping from rank-local indices to domain-wide indices,
11    (subdomain-aware equivalent of np.indices)"""
12    start, stop = subdomain(grid[mpi_dim], rank, size)
13    indices_arg = list(grid)
14    indices_arg[mpi_dim] = stop - start
15    xyi = np.indices(tuple(indices_arg), dtype=float)
16    xyi[mpi_dim] += start
17    return xyi
@numba.njit(**jit_flags)
def subdomain(span, rank, size):
    """Return the half-open range ``(start, stop)`` of the ``span`` indices
    assigned to ``rank`` when splitting them across ``size`` ranks.

    Each rank gets ``ceil(span / size)`` consecutive indices. Trailing ranks
    get whatever is left, which may be nothing. Both bounds are clamped to
    ``span``, so ``start <= stop`` always holds.

    Raises ValueError if ``rank`` is not in ``[0, size)``.
    """
    if rank >= size:
        raise ValueError("rank >= size")
    # also guarantees size >= 1 below, so the division cannot hit zero
    if rank < 0:
        raise ValueError("rank < 0")

    n_max = math.ceil(span / size)
    # clamp so ranks past the end of a short domain get an empty range at the
    # end instead of a negative-length one
    start = min(n_max * rank, span)
    stop = min(start + n_max, span)
    return start, stop
def mpi_indices(*, grid, rank, size, mpi_dim):
    """Returns a mapping from rank-local indices to domain-wide indices
    (a subdomain-aware equivalent of ``np.indices``).

    Parameters
    ----------
    grid : sequence of int
        Shape of the whole (undecomposed) domain.
    rank : int
        Rank whose part of the domain is requested.
    size : int
        Number of ranks the ``mpi_dim`` dimension is split across.
    mpi_dim : int
        Index of the ``grid`` dimension that is split across ranks.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(grid), *local_shape)``. ``local_shape``
        equals ``grid`` except along ``mpi_dim``, where it is the length of
        this rank's slab (possibly 0). The ``mpi_dim`` component is offset by
        the slab start, so its values index into the whole domain.

    Raises
    ------
    ValueError
        Propagated from ``subdomain`` for an invalid ``rank``/``size`` pair.
    """
    start, stop = subdomain(grid[mpi_dim], rank, size)
    local_shape = list(grid)
    # subdomain() may hand trailing ranks of a short domain stop < start;
    # treat that as an empty slab rather than letting np.indices reject a
    # negative dimension.
    local_shape[mpi_dim] = max(stop - start, 0)
    xyi = np.indices(tuple(local_shape), dtype=float)
    xyi[mpi_dim] += start
    return xyi

Returns a mapping from rank-local indices to domain-wide indices (a subdomain-aware equivalent of `np.indices`).