PyMPDATA_MPI.impl.boundary_condition_commons

common functions for boundary conditions

# pylint: disable=too-many-positional-arguments,too-many-arguments
"""common functions for boundary conditions"""

from functools import lru_cache

import numba
import numba_mpi as mpi
from mpi4py import MPI
from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER

# placeholder passed where an argument's value is irrelevant (e.g., dim in the scalar case)
IRRELEVANT = 666


@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns a fill_halos() function for scalar boundary conditions,
    providing the default logic for scalar buffer filling. Notable arguments:
       :param get_peer: function mapping the halo side (sign) to a
           (peer_rank, send_first) pair, i.e. determining the direction
           of communication
       :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos


@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns a fill_halos() function for vector boundary conditions,
    providing the default logic for vector buffer filling. Notable arguments:
       :param get_peer: function mapping the halo side (sign) to a
           (peer_rank, send_first) pair, i.e. determining the direction
           of communication
       :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # True when the exchanged vector component lies along the communication axis
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return  # empty halo region - nothing to exchange
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector


def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):
    # fill_halos() is called from within Numba parallel regions, so each
    # thread issues its own MPI calls - hence the THREAD_MULTIPLE requirement
    assert MPI.Query_thread() == MPI.THREAD_MULTIPLE

    @numba.njit(**jit_flags)
    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
        # carves a per-thread chunk out of the shared communication buffer
        chunk_size = len(i_rng) * len(k_rng)
        if mpi_dim != OUTER:
            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
            chunk_index += numba.get_thread_id() * n_chunks
        else:
            # along the outer dimension, only the first and last threads
            # reach the MPI boundary, so two thread slots suffice
            n_chunks = len(buffer) // (chunk_size * 2)
            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
        return buffer.view(dtype)[
            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
        ].reshape((len(i_rng), len(k_rng)))

    @numba.njit(**jit_flags)
    def fill_output(output, buffer, i_rng, j_rng, k_rng):
        # copies received halo data from the buffer into the output field
        for i in i_rng:
            for j in j_rng:
                for k in k_rng:
                    set_value(
                        output,
                        i,
                        j,
                        k,
                        buffer[i - i_rng.start, k - k_rng.start],
                    )

    @numba.njit(**jit_flags)
    def _send(buf, peer, fill_buf_args):
        tag = numba.get_thread_id()
        fill_buf(buf, *fill_buf_args)
        mpi.send(buf, dest=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _recv(buf, peer):
        th_id = numba.get_thread_id()
        n_th = numba.get_num_threads()
        # tags encode the sender's thread id; along the outer dimension the
        # first and last threads pair up across ranks, hence the tag swap
        tag = th_id if mpi_dim != OUTER else {0: n_th - 1, n_th - 1: 0}[th_id]
        mpi.recv(buf, source=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
        peer, send_first = get_peer(sign)
        fill_buf_args = (psi, i_rng, k_rng, sign, dim)

        # with blocking send/recv, the two peers must use opposite orderings
        # (as dictated by get_peer()) to avoid a deadlock
        if send_first:
            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
            _recv(buf=buf, peer=peer)
        else:
            _recv(buf=buf, peer=peer)
            # send from a separate chunk so the just-received data stays intact
            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)

        fill_output(output, buf, i_rng, j_rng, k_rng)

    return _send_recv
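
The get_peer argument to both factory functions is supplied by the concrete boundary-condition class; the code above only relies on it mapping the halo side (sign) to a (peer_rank, send_first) pair, with the two peers receiving opposite send_first values so that they take opposite branches in _send_recv(). A minimal sketch of such a function, assuming a periodic one-dimensional domain decomposition with ranks numbered consecutively along the decomposed dimension (illustrative only, not the library's actual implementation):

import numba
import numba_mpi as mpi

@numba.njit()
def get_peer(sign):
    # hypothetical get_peer(): sign = -1/+1 selects the neighbour on either
    # side; peer < rank holds on exactly one end of each pair, so the two
    # peers always take opposite if-branches in _send_recv(), avoiding deadlock
    rank = mpi.rank()
    peer = (rank + sign) % mpi.size()
    return peer, peer < rank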
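
For reference, the deadlock-avoidance pattern used in _send_recv() can be exercised in isolation. A self-contained sketch, assuming exactly two MPI ranks (run e.g. with "mpiexec -n 2 python example.py") and relying only on numba_mpi calls already used above:

import numpy as np
import numba_mpi as mpi

assert mpi.size() == 2  # assumption: exactly two ranks

msg = np.full(4, float(mpi.rank()))
buf = np.empty(4, dtype="float64")
peer = (mpi.rank() + 1) % 2
send_first = peer < mpi.rank()  # same symmetry-breaking rule as in the sketch above

# the two ranks take opposite branches, so each blocking send is matched
# by a blocking recv already posted (or about to be posted) on the peer
if send_first:
    mpi.send(msg, dest=peer, tag=0)
    mpi.recv(buf, source=peer, tag=0)
else:
    mpi.recv(buf, source=peer, tag=0)
    mpi.send(msg, dest=peer, tag=0)

assert (buf == peer).all()  # each rank received the peer's rank id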