PyMPDATA_MPI.impl.boundary_condition_commons
Common functions shared by the boundary-condition implementations.
# pylint: disable=too-many-positional-arguments,too-many-arguments
"""boundary_condition common functions: buffer filling and MPI send/receive
logic shared by the scalar and vector boundary conditions"""

from functools import lru_cache

import numba
import numba_mpi as mpi
from PyMPDATA.impl.enumerations import INVALID_INDEX, OUTER

# placeholder passed as ``dim`` for scalar fields: the scalar ``fill_buf``
# ignores its ``_dim`` argument
IRRELEVANT = 666


@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Returns the fill_halos() function for scalar boundary conditions.

    Provides the default logic for filling the scalar send buffer; the
    exchange with the peer process is delegated to the function built by
    _make_send_recv(). Because of lru_cache, all arguments must be hashable
    (including ``jit_flags``, which is also unpacked with ``**``).
    Notable arguments:

    :param get_peer: function for determining the direction of communication,
        called as ``peer, send_first = get_peer(sign)``
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        # buf is a (len(i_rng), len(k_rng)) view; values are read from psi
        # through indexers.ats (presumably psi shifted by ``sign`` along
        # ``dimension_index`` - TODO confirm against PyMPDATA indexers)
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        # scalars have no component index, hence IRRELEVANT as ``dim``;
        # psi is both the source of the sent values and the output written to
        # NOTE(review): unlike the vector variant there is no early return for
        # empty i_rng/k_rng - confirm this asymmetry is intended
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos


@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Returns the fill_halos() function for vector boundary conditions.

    Provides the default logic for filling the vector send buffer; the
    exchange with the peer process is delegated to the function built by
    _make_send_recv(). Because of lru_cache, all arguments must be hashable
    (including ``jit_flags``, which is also unpacked with ``**``).
    Notable arguments:

    :param get_peer: function for determining the direction of communication,
        called as ``peer, send_first = get_peer(sign)``
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # the component whose index (modulo the number of components) equals
        # dimension_index is read with a single argument ``sign * halo + 0.5``,
        # the other components with two arguments ``sign * halo`` and ``0.5``
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # nothing to exchange when either range is empty
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        # all components are used to fill the buffer; only components[dim]
        # receives the values coming from the peer
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector


def _make_send_recv(set_value, jit_flags, fill_buf, dtype, get_peer, mpi_dim):
    """Returns a jitted _send_recv() that fills a (len(i_rng), len(k_rng))
    buffer with ``fill_buf``, exchanges it with the peer returned by
    ``get_peer(sign)`` using numba_mpi send/recv, and writes the received
    values into ``output`` via ``set_value``."""

    @numba.njit(**jit_flags)
    def get_buffer_chunk(buffer, i_rng, k_rng, chunk_index):
        # the raw buffer is split into per-thread regions of n_chunks chunks so
        # that concurrently running threads never share a chunk; in OUTER mode
        # there are only two regions: thread 0 and all remaining threads
        chunk_size = len(i_rng) * len(k_rng)
        if mpi_dim != OUTER:
            n_chunks = len(buffer) // (chunk_size * numba.get_num_threads())
            chunk_index += numba.get_thread_id() * n_chunks
        else:
            n_chunks = len(buffer) // (chunk_size * 2)
            chunk_index += int(numba.get_thread_id() != 0) * n_chunks
        # reinterpret the raw buffer as ``dtype`` and expose the chunk as 2D
        return buffer.view(dtype)[
            chunk_index * chunk_size : (chunk_index + 1) * chunk_size
        ].reshape((len(i_rng), len(k_rng)))

    @numba.njit(**jit_flags)
    def fill_output(output, buffer, i_rng, j_rng, k_rng):
        # the received value does not depend on j: the same value is written
        # along the whole j_rng
        for i in i_rng:
            for j in j_rng:
                for k in k_rng:
                    set_value(
                        output,
                        i,
                        j,
                        k,
                        buffer[i - i_rng.start, k - k_rng.start],
                    )

    @numba.njit(**jit_flags)
    def _send(buf, peer, fill_buf_args):
        # messages are tagged with the id of the sending thread
        tag = numba.get_thread_id()
        fill_buf(buf, *fill_buf_args)
        mpi.send(buf, dest=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _recv(buf, peer):
        th_id = numba.get_thread_id()
        n_th = numba.get_num_threads()
        if mpi_dim != OUTER:
            # matched with the peer's thread of the same id
            tag = th_id
        else:
            # in OUTER mode thread 0 is matched with the peer's last thread and
            # vice versa; computed arithmetically rather than via a dict
            # lookup to avoid allocating a typed dict on every receive
            tag = n_th - 1 - th_id
        mpi.recv(buf, source=peer, tag=tag)

    @numba.njit(**jit_flags)
    def _send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, dim, output):
        buf = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=0)
        # presumably get_peer returns complementary send_first flags on the
        # two communicating processes so the exchange cannot deadlock - TODO confirm
        peer, send_first = get_peer(sign)
        fill_buf_args = (psi, i_rng, k_rng, sign, dim)

        if send_first:
            # buf is free to be reused for receiving once it has been sent
            _send(buf=buf, peer=peer, fill_buf_args=fill_buf_args)
            _recv(buf=buf, peer=peer)
        else:
            # buf already holds the received data, so the outgoing values are
            # staged in a separate chunk to avoid overwriting them
            _recv(buf=buf, peer=peer)
            tmp = get_buffer_chunk(buffer, i_rng, k_rng, chunk_index=1)
            _send(buf=tmp, peer=peer, fill_buf_args=fill_buf_args)

        fill_output(output, buf, i_rng, j_rng, k_rng)

    return _send_recv
# placeholder passed as ``dim`` for scalar fields, whose buffer-filling
# function ignores the dimension argument
IRRELEVANT = 666
@lru_cache()
def make_scalar_boundary_condition(*, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_scalar_boundary_condition(
    *, indexers, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Returns the fill_halos() function for scalar boundary conditions.

    Provides the default logic for filling the scalar send buffer; the
    exchange with the peer process is delegated to the function built by
    _make_send_recv(). Because of lru_cache, all arguments must be hashable
    (including ``jit_flags``, which is also unpacked with ``**``).
    Notable arguments:

    :param get_peer: function for determining the direction of communication,
        called as ``peer, send_first = get_peer(sign)``
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, psi, i_rng, k_rng, sign, _dim):
        # buf is indexed relative to the range starts; values are read from
        # psi through indexers.ats (presumably psi shifted by ``sign`` along
        # ``dimension_index`` - TODO confirm against PyMPDATA indexers)
        for i in i_rng:
            for k in k_rng:
                buf[i - i_rng.start, k - k_rng.start] = indexers.ats[dimension_index](
                    (i, INVALID_INDEX, k), psi, sign
                )

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos(buffer, i_rng, j_rng, k_rng, psi, _, sign):
        # scalars have no component index, hence IRRELEVANT as ``dim``;
        # psi is both the source of the sent values and the output written to
        send_recv(buffer, psi, i_rng, j_rng, k_rng, sign, IRRELEVANT, psi)

    return fill_halos
Returns the fill_halos() function for scalar boundary conditions. Provides the default logic for filling the scalar send buffer.
Notable arguments:
:param get_peer: function for determining the direction of communication (called as `peer, send_first = get_peer(sign)`)
:type get_peer: function
@lru_cache()
def make_vector_boundary_condition(indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim):
@lru_cache()
def make_vector_boundary_condition(
    indexers, halo, jit_flags, dimension_index, dtype, get_peer, mpi_dim
):
    """Returns the fill_halos() function for vector boundary conditions.

    Provides the default logic for filling the vector send buffer; the
    exchange with the peer process is delegated to the function built by
    _make_send_recv(). Because of lru_cache, all arguments must be hashable
    (including ``jit_flags``, which is also unpacked with ``**``).
    Notable arguments:

    :param get_peer: function for determining the direction of communication,
        called as ``peer, send_first = get_peer(sign)``
    :type get_peer: function"""

    @numba.njit(**jit_flags)
    def fill_buf(buf, components, i_rng, k_rng, sign, dim):
        # the component whose index (modulo the number of components) equals
        # dimension_index is read with a single argument ``sign * halo + 0.5``,
        # the other components with two arguments ``sign * halo`` and ``0.5``
        parallel = dim % len(components) == dimension_index

        for i in i_rng:
            for k in k_rng:
                if parallel:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo + 0.5
                    )
                else:
                    value = indexers.atv[dimension_index](
                        (i, INVALID_INDEX, k), components, sign * halo, 0.5
                    )

                buf[i - i_rng.start, k - k_rng.start] = value

    send_recv = _make_send_recv(
        indexers.set, jit_flags, fill_buf, dtype, get_peer, mpi_dim
    )

    @numba.njit(**jit_flags)
    def fill_halos_loop_vector(buffer, i_rng, j_rng, k_rng, components, dim, _, sign):
        # nothing to exchange when either range is empty
        if i_rng.start == i_rng.stop or k_rng.start == k_rng.stop:
            return
        # all components are used to fill the buffer; only components[dim]
        # receives the values coming from the peer
        send_recv(buffer, components, i_rng, j_rng, k_rng, sign, dim, components[dim])

    return fill_halos_loop_vector
Returns the fill_halos() function for vector boundary conditions. Provides the default logic for filling the vector send buffer.
Notable arguments:
:param get_peer: function for determining the direction of communication (called as `peer, send_first = get_peer(sign)`)
:type get_peer: function