PyMPDATA_MPI.impl.boundary_condition_commons
boundary_condition common functions
# pylint: disable=too-many-positional-arguments,too-many-arguments
"""boundary_condition common functions"""

from functools import lru_cache

import numba
import numba_mpi as mpi
from mpi4py import MPI
from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER

IRRELEVANT = 666


@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns fill_halos() function for scalar boundary conditions.
    Provides default logic for scalar buffer filling. Notable arguments:
    :param get_peer: function for determining the direction of communication
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos


@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """returns fill_halos() function for vector boundary conditions.
    Provides default logic for vector buffer filling. Notable arguments:
    :param get_peer: function for determining the direction of communication
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector


def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):

    assert MPI.Query_thread() == MPI.THREAD_MULTIPLE

    @numba.njit(**jit_flags)
    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
        chunk_size = len(i_rng) * len(k_rng)
        if mpi_dim != OUTER:
            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
            chunk_index += numba.get_thread_id() * n_chunks
        else:
            n_chunks = len(buffer) // (chunk_size * 2)
            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
        return buffer.view(dtype)[
            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
        ].reshape((len(i_rng), len(k_rng)))

    @numba.njit(**jit_flags)
    def fill_output(output, buffer, i_rng, j_rng, k_rng):
        for i in i_rng:
            for j in j_rng:
                for k in k_rng:
                    set_value(
                        output,
                        i,
                        j,
                        k,
                        buffer[i - i_rng.start, k - k_rng.start],
                    )

    @numba.njit(**jit_flags)
    def _send(buf, peer, fill_buf_args):
        tag = numba.get_thread_id()
        fill_buf(buf, *fill_buf_args)
        mpi.send(buf, dest=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _recv(buf, peer):
        th_id = numba.get_thread_id()
        n_th = numba.get_num_threads()
        tag = th_id if mpi_dim != OUTER else {0: n_th - 1, n_th - 1: 0}[th_id]
        mpi.recv(buf, source=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
        peer, send_first = get_peer(sign)
        fill_buf_args = (psi, i_rng, k_rng, sign, dim)

        if send_first:
            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
            _recv(buf=buf, peer=peer)
        else:
            _recv(buf=buf, peer=peer)
            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)

        fill_output(output, buf, i_rng, j_rng, k_rng)

    return _send_recv
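A side note on get_buffer_chunk above: the buffer is shared, and each thread addresses its own contiguous window of it; along the OUTER dimension the split is into two halves, which (together with the tag mapping in _recv) suggests that only the first and the last thread of a rank exchange halos there. Below is a minimal NumPy sketch of the same chunk arithmetic, with the thread id and count passed explicitly as stand-ins for numba.get_thread_id() and numba.get_num_threads():

import numpy as np

def buffer_chunk(buffer, n_i, n_k, chunk_index, thread_id, n_threads, mpi_dim_is_outer):
    # mirrors get_buffer_chunk: each thread owns a contiguous run of chunks
    # in the shared buffer; along the OUTER dimension the buffer is split
    # into two halves (thread 0 vs. any other thread)
    chunk_size = n_i * n_k
    if not mpi_dim_is_outer:
        n_chunks = len(buffer) // (chunk_size * n_threads)
        chunk_index += thread_id * n_chunks
    else:
        n_chunks = len(buffer) // (chunk_size * 2)
        chunk_index += int(thread_id != 0) * n_chunks
    return buffer[chunk_index * chunk_size : (chunk_index + 1) * chunk_size].reshape(n_i, n_k)

buffer = np.zeros(24)
chunk = buffer_chunk(buffer, n_i=2, n_k=3, chunk_index=0,
                     thread_id=1, n_threads=2, mpi_dim_is_outer=False)
assert chunk.shape == (2, 3)  # thread 1's first chunk, elements 12..17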
IRRELEVANT = 666

placeholder passed as the dim argument where it is unused (the scalar case)
@lru_cache()
def make_scalar_boundary_condition(*, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
returns fill_halos() function for scalar boundary conditions.
Provides default logic for scalar buffer filling. Notable arguments:

:param get_peer: function for determining the direction of communication
:type get_peer: function
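The get_peer argument is characterized only by how _send_recv uses it: called with the sign of the halo being filled, it returns a (peer, send_first) pair. A minimal sketch of such a callable, assuming a periodic one-dimensional rank layout and an even-ranks-send-first ordering rule; both choices are illustrative assumptions, not the library's actual logic:

import numba
import numba_mpi as mpi

@numba.njit()
def get_peer(sign):
    # rank of the neighbour in the direction given by sign, plus a flag
    # telling whether this rank should send before receiving; alternating
    # the order between neighbouring ranks avoids send/send deadlocks
    rank = mpi.rank()
    size = mpi.size()
    peer = (rank + sign) % size  # periodic neighbour (assumption)
    send_first = rank % 2 == 0   # even ranks send first (assumption)
    return peer, send_first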
@lru_cache()
def make_vector_boundary_condition(indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
returns fill_halos() function for vector boundary conditions.
Provides default logic for vector buffer filling. Notable arguments:

:param get_peer: function for determining the direction of communication
:type get_peer: function
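Both factories delegate to the _send/_recv pair defined in _make_send_recv above, where a message's tag is the sending thread's id. A short sketch mirroring the receive-tag rule from _recv: for inner dimensions each thread pairs with the same-numbered thread on the peer rank, while along OUTER the first and last threads' tags are swapped so that each receive carries the tag the peer's sending thread used:

def recv_tag(thread_id, n_threads, mpi_dim_is_outer):
    # mirrors the tag choice in _recv: a receive must carry the tag that the
    # peer's sending thread used, i.e. that thread's own id (see _send)
    if not mpi_dim_is_outer:
        return thread_id
    return {0: n_threads - 1, n_threads - 1: 0}[thread_id]

assert recv_tag(2, 4, mpi_dim_is_outer=False) == 2
assert recv_tag(0, 4, mpi_dim_is_outer=True) == 3  # matches peer thread 3's send
assert recv_tag(3, 4, mpi_dim_is_outer=True) == 0  # matches peer thread 0's send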