PyMPDATA_MPI.impl.boundary_condition_commons
Common functions shared by the boundary conditions
# pylint: disable=too-many-arguments
"""Common functions shared by the MPI boundary conditions.

The factories below build numba-jitted ``fill_halos`` functions. Each one
packs boundary values into a buffer, exchanges that buffer with a
neighbouring MPI rank (chosen by ``get_peer``) and writes the received values
into the output array.
"""

from functools import lru_cache

import numba
import numba_mpi as mpi
from mpi4py import MPI

from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER

# Placeholder passed as the ``dim`` argument by the scalar ``fill_halos``;
# the scalar ``fill_buf`` ignores it (its parameter is named ``_dim``).
IRRELEVANT = 666


@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Return a jitted ``fill_halos()`` function for scalar boundary conditions.

    Provides the default logic for scalar buffer filling: each buffer entry
    is read with ``indexers.ats[dimension_index]`` (called with ``sign`` as its
    last argument), the buffer is exchanged with the peer rank, and the
    received values are written into ``psi``.

    Notable arguments:
       :param get_peer: function for determining the direction of communication
       :type get_peer: function
    """

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        # buf is indexed relative to the start of both ranges
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        # psi is both the data source for packing and the output written to
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos


@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Return a jitted ``fill_halos()`` function for vector boundary conditions.

    Provides the default logic for vector buffer filling: each buffer entry
    is read with ``indexers.atv[dimension_index]``, the buffer is exchanged
    with the peer rank, and the received values are written into
    ``components[dim]``.

    Notable arguments:
       :param get_peer: function for determining the direction of communication
       :type get_peer: function
    """

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # the requested component is "parallel" when it matches the dimension
        # this boundary condition is attached to
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                # NOTE(review): the +0.5 offsets presumably address staggered
                # (half-cell) positions of the vector field - confirm against
                # PyMPDATA's atv indexers.
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # nothing to exchange when either packed range is empty
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector


def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):
    """Build the jitted ``_send_recv()`` that exchanges one slab with a peer rank.

    :param set_value: numba-callable setter, called as
        ``set_value(output, i, j, k, value)``
    :param jit_flags: keyword arguments forwarded to ``numba.njit``
    :param fill_buf: numba-callable packer, called as
        ``fill_buf(buf, psi, i_rng, k_rng, sign, dim)``
    :param dtype: dtype used to reinterpret the raw buffer (``buffer.view(dtype)``)
    :param get_peer: numba-callable returning ``(peer, send_first)`` for ``sign``
    :param mpi_dim: compared against ``OUTER`` to select the buffer-chunk and
        message-tag layout
    :raises RuntimeError: if MPI does not provide ``MPI.THREAD_MULTIPLE``
    """

    # Message tags below are derived from numba thread ids, suggesting MPI is
    # called from several threads, which requires THREAD_MULTIPLE. An explicit
    # check (instead of ``assert``) keeps the guard active under ``python -O``.
    if MPI.Query_thread() != MPI.THREAD_MULTIPLE:
        raise RuntimeError(
            "MPI must be initialized with MPI.THREAD_MULTIPLE thread support"
        )

    @numba.njit(**jit_flags)
    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
        """Return chunk ``chunk_index`` of this thread's region of ``buffer``
        as a ``(len(i_rng), len(k_rng))`` view of type ``dtype``."""
        chunk_size = len(i_rng) * len(k_rng)
        if mpi_dim != OUTER:
            # buffer is split evenly between all numba threads
            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
            chunk_index += numba.get_thread_id() * n_chunks
        else:
            # buffer is split in two halves: thread 0 uses the first one,
            # every other thread uses the second one
            n_chunks = len(buffer) // (chunk_size * 2)
            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
        return buffer.view(dtype)[
            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
        ].reshape((len(i_rng), len(k_rng)))

    @numba.njit(**jit_flags)
    def fill_output(output, buffer, i_rng, j_rng, k_rng):
        """Write ``buffer[i, k]`` into ``output`` at every ``j`` in ``j_rng``."""
        for i in i_rng:
            for j in j_rng:
                for k in k_rng:
                    set_value(
                        output,
                        i,
                        j,
                        k,
                        buffer[i - i_rng.start, k - k_rng.start],
                    )

    @numba.njit(**jit_flags)
    def _send(buf, peer, fill_buf_args):
        """Pack ``buf`` with ``fill_buf`` and send it to ``peer``."""
        tag = numba.get_thread_id()
        fill_buf(buf, *fill_buf_args)
        mpi.send(buf, dest=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _recv(buf, peer):
        """Receive into ``buf`` from ``peer`` using the matching thread tag."""
        th_id = numba.get_thread_id()
        n_th = numba.get_num_threads()
        # Along OUTER the tags of threads 0 and n_th - 1 are swapped. Any other
        # thread id is not a key of this mapping.
        # NOTE(review): presumably only those two threads communicate along
        # OUTER - confirm.
        tag = th_id if mpi_dim != OUTER else {0: n_th - 1, n_th - 1: 0}[th_id]
        mpi.recv(buf, source=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
        """Exchange one slab with the peer chosen by ``get_peer(sign)`` and
        write the received values into ``output``."""
        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
        peer, send_first = get_peer(sign)
        fill_buf_args = (psi, i_rng, k_rng, sign, dim)

        if send_first:
            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
            _recv(buf=buf, peer=peer)
        else:
            _recv(buf=buf, peer=peer)
            # chunk 0 now holds the received data; pack the outgoing data into
            # a separate chunk so it is not overwritten
            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)

        fill_output(output, buf, i_rng, j_rng, k_rng)

    return _send_recv
IRRELEVANT =
666
@lru_cache()
def
make_scalar_boundary_condition(*, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Build the jitted ``fill_halos()`` used by scalar boundary conditions.

    The packing kernel reads every (row, column) of the requested ranges with
    ``indexers.ats[dimension_index]``, passing ``sign`` as its last argument.
    The exchange itself is delegated to the routine built by
    ``_make_send_recv``, which also writes the received values back into ``psi``.

    Notable arguments:
       :param get_peer: function for determining the direction of communication
       :type get_peer: function
    """

    @numba.njit(**jit_flags)
    def pack(dst, field, rows, cols, sign, _dim):
        row0, col0 = rows.start, cols.start
        for row in rows:
            for col in cols:
                dst[row - row0, col - col0] = indexers.ats[dimension_index](
                    (row, INVALID_INDEX, col), field, sign
                )

    exchange = _make_send_recv(
        indexers.set, jit_flags, pack, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        # scalars have no component index, hence the IRRELEVANT placeholder;
        # psi serves both as the source and as the destination
        exchange(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos
Returns a `fill_halos()` function for scalar boundary conditions, providing the default logic for filling the scalar halo buffer. Notable arguments: :param get_peer: function for determining the direction of communication :type get_peer: function
@lru_cache()
def
make_vector_boundary_condition(indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Build the jitted ``fill_halos()`` used by vector boundary conditions.

    The packing kernel reads values with ``indexers.atv[dimension_index]``.
    When the requested component ``dim % len(components)`` equals
    ``dimension_index``, it passes the single offset ``sign * halo + 0.5``;
    otherwise it passes the pair ``sign * halo, 0.5``. The received values
    end up in ``components[dim]``.

    Notable arguments:
       :param get_peer: function for determining the direction of communication
       :type get_peer: function
    """

    @numba.njit(**jit_flags)
    def pack(dst, comps, rows, cols, sign, dim):
        row0, col0 = rows.start, cols.start
        if dim % len(comps) == dimension_index:
            for row in rows:
                for col in cols:
                    dst[row - row0, col - col0] = indexers.atv[dimension_index](
                        (row, INVALID_INDEX, col), comps, sign * halo + 0.5
                    )
        else:
            for row in rows:
                for col in cols:
                    dst[row - row0, col - col0] = indexers.atv[dimension_index](
                        (row, INVALID_INDEX, col), comps, sign * halo, 0.5
                    )

    exchange = _make_send_recv(
        indexers.set, jit_flags, pack, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # exchange only when both packed ranges are non-empty
        if i_rng.start != i_rng.stop and k_rng.start != k_rng.stop:
            exchange(
                buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim]
            )

    return fill_halos_loop_vector
Returns a `fill_halos()` function for vector boundary conditions, providing the default logic for filling the vector halo buffer. Notable arguments: :param get_peer: function for determining the direction of communication :type get_peer: function