PyMPDATA_MPI.impl.boundary_condition_commons

Common functions for boundary conditions

  1# pylint: disable=too-many-positional-arguments,too-many-arguments
  2""" boundary_condition common functions """
  3
  4from functools import lru_cache
  5
  6import numba
  7import numba_mpi as mpi
  8from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER
  9
 10IRRELEVANT = 666
 11
 12
 13@lru_cache()
 14def make_scalar_boundary_condition(
 15    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
 16):
 17    """returns fill_halos() function for scalar boundary conditions.
 18    Provides default logic for scalar buffer filling. Notable arguments:
 19       :param get_peer: function for determining the direction of communication
 20       :type get_peer: function"""
 21
 22    @numba.njit(**jit_flags)
 23    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
 24        for i in i_rng:
 25            for k in k_rng:
 26                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
 27                    (i, INVALID_INDEX, k), psi, sign
 28                )
 29
 30    send_recv = _make_send_recv(
 31        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
 32    )
 33
 34    @numba.njit(**jit_flags)
 35    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
 36        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)
 37
 38    return fill_halos
 39
 40
 41@lru_cache()
 42def make_vector_boundary_condition(
 43    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
 44):
 45    """returns fill_halos() function for vector boundary conditions.
 46    Provides default logic for vector buffer filling. Notable arguments:
 47       :param get_peer: function for determining the direction of communication
 48       :type get_peer: function"""
 49
 50    @numba.njit(**jit_flags)
 51    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
 52        parallel = dim % len(components) == dimension_index
 53
 54        for i in i_rng:
 55            for k in k_rng:
 56                if parallel:
 57                    value = indexers.atv[dimension_index](
 58                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
 59                    )
 60                else:
 61                    value = indexers.atv[dimension_index](
 62                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
 63                    )
 64
 65                buf[i - i_rng.start, k - k_rng.start] = value
 66
 67    send_recv = _make_send_recv(
 68        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
 69    )
 70
 71    @numba.njit(**jit_flags)
 72    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
 73        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
 74            return
 75        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])
 76
 77    return fill_halos_loop_vector
 78
 79
 80def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):
 81
 82    @numba.njit(**jit_flags)
 83    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
 84        chunk_size = len(i_rng) * len(k_rng)
 85        if mpi_dim != OUTER:
 86            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
 87            chunk_index += numba.get_thread_id() * n_chunks
 88        else:
 89            n_chunks = len(buffer) // (chunk_size * 2)
 90            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
 91        return buffer.view(dtype)[
 92            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
 93        ].reshape((len(i_rng), len(k_rng)))
 94
 95    @numba.njit(**jit_flags)
 96    def fill_output(output, buffer, i_rng, j_rng, k_rng):
 97        for i in i_rng:
 98            for j in j_rng:
 99                for k in k_rng:
100                    set_value(
101                        output,
102                        i,
103                        j,
104                        k,
105                        buffer[i - i_rng.start, k - k_rng.start],
106                    )
107
108    @numba.njit(**jit_flags)
109    def _send(buf, peer, fill_buf_args):
110        tag = numba.get_thread_id()
111        fill_buf(buf, *fill_buf_args)
112        mpi.send(buf, dest=peer, tag=tag)
113
114    @numba.njit(**jit_flags)
115    def _recv(buf, peer):
116        th_id = numba.get_thread_id()
117        n_th = numba.get_num_threads()
118        tag = th_id if mpi_dim != OUTER else {0: n_th - 1, n_th - 1: 0}[th_id]
119        mpi.recv(buf, source=peer, tag=tag)
120
121    @numba.njit(**jit_flags)
122    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
123        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
124        peer, send_first = get_peer(sign)
125        fill_buf_args = (psi, i_rng, k_rng, sign, dim)
126
127        if send_first:
128            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
129            _recv(buf=buf, peer=peer)
130        else:
131            _recv(buf=buf, peer=peer)
132            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
133            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)
134
135        fill_output(output, buf, i_rng, j_rng, k_rng)
136
137    return _send_recv
# Placeholder passed as the ``dim`` argument by the scalar fill_halos(), whose
# fill_buf() ignores that parameter.
IRRELEVANT = 666
@lru_cache()
def make_scalar_boundary_condition(*, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Return a ``fill_halos()`` function for scalar boundary conditions.

    Provides the default logic for filling the scalar communication buffer
    and delegates the exchange itself to the function built by
    ``_make_send_recv``. The result is memoised with ``lru_cache``, so every
    argument has to be hashable.

    :param indexers: object whose ``ats[dimension_index]`` reads the scalar
        field and whose ``set`` writes the received values
    :param jit_flags: keyword arguments expanded into ``numba.njit``
    :param dimension_index: selects which entry of ``indexers.ats`` is used
    :param dtype: forwarded unchanged to ``_make_send_recv``
    :param get_peer: function for determining the direction of communication
    :type get_peer: function
    :param mpi_dim: forwarded unchanged to ``_make_send_recv``
    :return: jitted ``fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign)``
    """

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        # ``_dim`` is accepted only so that the signature matches the vector
        # variant; it is never read here.
        for i in i_rng:
            for k in k_rng:
                # buf is indexed from zero, hence the offsets by the range starts
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        # The sixth argument is ignored. ``psi`` is passed twice: once as the
        # data to send and once as the array the received values are written to.
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos

Returns the fill_halos() function for scalar boundary conditions. Provides the default logic for filling the scalar buffer. Notable arguments:

- get_peer (function): determines the direction of communication

@lru_cache()
def make_vector_boundary_condition(indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Return a ``fill_halos()`` function for vector boundary conditions.

    Provides the default logic for filling the vector communication buffer
    and delegates the exchange itself to the function built by
    ``_make_send_recv``. The result is memoised with ``lru_cache``, so every
    argument has to be hashable.

    :param indexers: object whose ``atv[dimension_index]`` reads the vector
        components and whose ``set`` writes the received values
    :param halo: multiplied by ``sign`` to form the offset passed to
        ``indexers.atv``
    :param jit_flags: keyword arguments expanded into ``numba.njit``
    :param dimension_index: selects which entry of ``indexers.atv`` is used
        and which component counts as "parallel"
    :param dtype: forwarded unchanged to ``_make_send_recv``
    :param get_peer: function for determining the direction of communication
    :type get_peer: function
    :param mpi_dim: forwarded unchanged to ``_make_send_recv``
    :return: jitted ``fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng,
        components, dim, _, sign)``
    """

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # True when ``dim`` (wrapped to the number of components) refers to
        # the component selected by ``dimension_index``.
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                # NOTE(review): the parallel branch passes one combined offset
                # (sign * halo + 0.5), the other branch passes two separate
                # ones (sign * halo, 0.5) -- presumably the half-cell shift
                # applies along a different axis; confirm against indexers.atv.
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                # buf is indexed from zero, hence the offsets by the range starts
                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # Skip the exchange entirely when either range is empty.
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        # All components are handed over for reading; only components[dim]
        # receives the exchanged values.
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector

Returns the fill_halos() function for vector boundary conditions. Provides the default logic for filling the vector buffer. Notable arguments:

- get_peer (function): determines the direction of communication