PyMPDATA_MPI.impl.boundary_condition_commons

Common functions for scalar- and vector-field MPI boundary conditions (halo exchange)

# pylint: disable=too-many-arguments
"""common functions for scalar- and vector-field MPI boundary conditions (halo exchange)"""

from functools import lru_cache

import numba
import numba_mpi as mpi
from mpi4py import MPI

from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER

# sentinel passed where the "dim" argument is irrelevant (the scalar case)
IRRELEVANT = 666


@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns a fill_halos() function for scalar boundary conditions,
    providing the default logic for scalar buffer filling. Notable arguments:
       :param get_peer: callable returning the peer rank and a flag telling
           which side sends first
       :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        # copy the halo-adjacent scalar-field values into the send buffer
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos

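The get_peer argument encapsulates the MPI topology: called with the sign of the
halo (left or right), it returns the rank of the neighbouring worker and a flag
telling which side of the pair sends first. A minimal sketch of such a callable
for a periodic one-dimensional domain decomposition (a hypothetical helper, not
part of this module; the package's boundary-condition classes provide the
actual peer logic):

    import numba
    import numba_mpi as mpi

    def make_periodic_get_peer(jit_flags, size):
        """hypothetical example of a get_peer() factory"""

        @numba.njit(**jit_flags)
        def get_peer(sign):
            # neighbour rank in a periodic ring; sign encodes left/right
            peer = (mpi.rank() + sign) % size
            # exactly one side of each pair sends first (deadlock avoidance)
            send_first = peer < mpi.rank()
            return peer, send_first

        return get_peer
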
@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns a fill_halos() function for vector boundary conditions,
    providing the default logic for vector buffer filling. Notable arguments:
       :param get_peer: callable returning the peer rank and a flag telling
           which side sends first
       :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # is the exchanged component staggered along the exchange dimension?
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                if parallel:
                    # parallel component: single half-integer (face) offset
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    # perpendicular component: integer offset along the
                    # exchange dimension, half-integer along its own
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # nothing to exchange if the halo slab is empty
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector


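The _make_send_recv() factory below asserts that MPI has been initialised at
the THREAD_MULTIPLE support level, since halo exchanges are issued concurrently
from multiple Numba threads. A sketch of explicitly requesting that level
through mpi4py (recent mpi4py versions request it by default, so this is a
guard rather than a requirement):

    import mpi4py

    mpi4py.rc.initialize = False  # postpone the automatic MPI_Init on import
    from mpi4py import MPI

    MPI.Init_thread(required=MPI.THREAD_MULTIPLE)
    assert MPI.Query_thread() == MPI.THREAD_MULTIPLE
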
def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):

    assert MPI.Query_thread() == MPI.THREAD_MULTIPLE

    @numba.njit(**jit_flags)
    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
        # carve a per-thread chunk out of the shared buffer; along the outer
        # dimension only the first and the last thread take part in the
        # exchange, so two chunk groups suffice there
        chunk_size = len(i_rng) * len(k_rng)
        if mpi_dim != OUTER:
            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
            chunk_index += numba.get_thread_id() * n_chunks
        else:
            n_chunks = len(buffer) // (chunk_size * 2)
            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
        return buffer.view(dtype)[
            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
        ].reshape((len(i_rng), len(k_rng)))

    @numba.njit(**jit_flags)
    def fill_output(output, buffer, i_rng, j_rng, k_rng):
        # scatter the received buffer into the halo region of the output field
        for i in i_rng:
            for j in j_rng:
                for k in k_rng:
                    set_value(
                        output,
                        i,
                        j,
                        k,
                        buffer[i - i_rng.start, k - k_rng.start],
                    )

    @numba.njit(**jit_flags)
    def _send(buf, peer, fill_buf_args):
        tag = numba.get_thread_id()
        fill_buf(buf, *fill_buf_args)
        mpi.send(buf, dest=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _recv(buf, peer):
        th_id = numba.get_thread_id()
        n_th = numba.get_num_threads()
        # along the outer dimension the first thread exchanges data with the
        # peer's last thread (and vice versa), hence the swapped tags
        tag = th_id if mpi_dim != OUTER else {0: n_th - 1, n_th - 1: 0}[th_id]
        mpi.recv(buf, source=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
        peer, send_first = get_peer(sign)
        fill_buf_args = (psi, i_rng, k_rng, sign, dim)

        # blocking point-to-point calls: get_peer() orders the pair so that
        # one side sends while the other receives; the receive-first branch
        # sends from a second chunk so that filling the send buffer does not
        # clobber the halo values just received
        if send_first:
            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
            _recv(buf=buf, peer=peer)
        else:
            _recv(buf=buf, peer=peer)
            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)

        fill_output(output, buf, i_rng, j_rng, k_rng)

    return _send_recv
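
The send-first/receive-first handshake is the standard way of pairing blocking
point-to-point calls without deadlock. A self-contained sketch of the same
pattern in plain mpi4py (a hypothetical ring exchange, not library code):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    halo = np.full(4, rank, dtype="float64")  # payload for the right neighbour
    recv = np.empty_like(halo)

    peer_right = (rank + 1) % size
    peer_left = (rank - 1) % size
    send_first = rank % 2 == 0  # parity rule breaks the cyclic wait

    if send_first:
        comm.Send(halo, dest=peer_right)
        comm.Recv(recv, source=peer_left)
    else:
        comm.Recv(recv, source=peer_left)
        comm.Send(halo, dest=peer_right)

    # each element of recv now equals the left neighbour's rank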