PyMPDATA.boundary_conditions.polar

Polar boundary condition for use with spherical coordinates.

 1"""polar boundary condition for use in with spherical coordinates"""
 2
 3from functools import lru_cache
 4
 5import numba
 6
 7from PyMPDATA.impl.enumerations import ARG_FOCUS, SIGN_LEFT, SIGN_RIGHT
 8from PyMPDATA.impl.traversals_common import (
 9    make_fill_halos_loop,
10    make_fill_halos_loop_vector,
11)
12
13
class Polar:
    """Polar boundary condition for spherical (longitude-latitude) grids.

    Instances are meant to be passed in the ``boundary_conditions`` tuple to
    the ``ScalarField`` and ``VectorField`` ``__init__`` methods.
    """

    def __init__(self, grid, longitude_idx, latitude_idx):
        """
        Args:
            grid: indexable collection of per-dimension sizes.
            longitude_idx: index of the longitude dimension in ``grid``.
            latitude_idx: index of the latitude dimension in ``grid``.

        Raises:
            AssertionError: if the number of longitudes is odd (the halo
                filler shifts longitude by exactly half of them), or if the
                SIGN_LEFT/SIGN_RIGHT conventions differ from the assumed ones.
        """
        # make_scalar builds halo offsets as ``sign * step``; that arithmetic
        # relies on these two sign conventions
        assert SIGN_RIGHT == -1
        assert SIGN_LEFT == +1

        self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
        assert self.nlon % 2 == 0

        self.nlon_half = self.nlon // 2
        self.lon_idx, self.lat_idx = longitude_idx, latitude_idx

    def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
        """Build the Numba-compiled scalar halo-filling callable.

        A new ``fill_halos`` kernel is created on every call; this method is
        not lru-cached (unlike ``make_vector``).

        Args:
            indexers: provides ``ats`` (per-dimension readers) and ``set``.
            halo: halo width.
            _: unused.
            jit_flags: keyword arguments forwarded to ``numba.njit``.
            dimension_index: selects the ``ats`` reader for this dimension.
        """
        lon_idx, lat_idx = self.lon_idx, self.lat_idx
        nlon_half = self.nlon_half
        # last halo index before the domain / first halo index past the domain
        low_edge = halo - 1
        high_edge = self.nlat + halo
        ats = indexers.ats[dimension_index]
        set_value = indexers.set

        @numba.njit(**jit_flags)
        def fill_halos(psi, _, sign):
            focus = psi[ARG_FOCUS]
            lon = focus[lon_idx]
            lat = focus[lat_idx]
            # distance of this halo cell from the domain edge (0 = adjacent)
            if lat > low_edge:
                beyond = lat - high_edge
            else:
                beyond = low_edge - lat
            # mirror across the edge: 1 for the adjacent halo cell, 3 next, ...
            step = 2 * beyond + 1
            # shift by half the longitudes (presumably the opposite meridian)
            # NOTE(review): ``lon`` is compared to ``nlon_half`` with no
            # ``halo`` offset, unlike the latitude edges; confirm intended.
            shift = -nlon_half if lon > nlon_half else nlon_half
            return ats(*psi, sign * step, shift)

        return make_fill_halos_loop(jit_flags, set_value, fill_halos)

    @staticmethod
    def make_vector(indexers, _, __, jit_flags, dimension_index):
        """Return the lru-cached, Numba-compiled vector halo-filling callable
        (built by ``_make_vector_polar``)."""
        atv, set_value = indexers.atv, indexers.set
        return _make_vector_polar(atv, set_value, jit_flags, dimension_index)
61
62
@lru_cache()
def _make_vector_polar(_atv, set_value, jit_flags, dimension_index):
    """Build (and memoise) the vector halo-filling loop for the polar condition.

    Both the parallel- and normal-component fillers are placeholders that
    return 0 (see TODO #120). ``_atv`` is accepted but currently unused.
    """

    @numba.njit(**jit_flags)
    def parallel_stub(_a, _b, _c):
        return 0  # TODO #120

    @numba.njit(**jit_flags)
    def normal_stub(_a, _b, _c, _d):
        return 0  # TODO #120

    return make_fill_halos_loop_vector(
        jit_flags, set_value, parallel_stub, normal_stub, dimension_index
    )
class Polar:
class Polar:
    """Polar boundary condition for spherical (longitude-latitude) grids.

    Instances are meant to be passed in the ``boundary_conditions`` tuple to
    the ``ScalarField`` and ``VectorField`` ``__init__`` methods.
    """

    def __init__(self, grid, longitude_idx, latitude_idx):
        """
        Args:
            grid: indexable collection of per-dimension sizes.
            longitude_idx: index of the longitude dimension in ``grid``.
            latitude_idx: index of the latitude dimension in ``grid``.

        Raises:
            AssertionError: if the number of longitudes is odd (the halo
                filler shifts longitude by exactly half of them), or if the
                SIGN_LEFT/SIGN_RIGHT conventions differ from the assumed ones.
        """
        # make_scalar builds halo offsets as ``sign * step``; that arithmetic
        # relies on these two sign conventions
        assert SIGN_RIGHT == -1
        assert SIGN_LEFT == +1

        self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
        assert self.nlon % 2 == 0

        self.nlon_half = self.nlon // 2
        self.lon_idx, self.lat_idx = longitude_idx, latitude_idx

    def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
        """Build the Numba-compiled scalar halo-filling callable.

        A new ``fill_halos`` kernel is created on every call; this method is
        not lru-cached (unlike ``make_vector``).

        Args:
            indexers: provides ``ats`` (per-dimension readers) and ``set``.
            halo: halo width.
            _: unused.
            jit_flags: keyword arguments forwarded to ``numba.njit``.
            dimension_index: selects the ``ats`` reader for this dimension.
        """
        lon_idx, lat_idx = self.lon_idx, self.lat_idx
        nlon_half = self.nlon_half
        # last halo index before the domain / first halo index past the domain
        low_edge = halo - 1
        high_edge = self.nlat + halo
        ats = indexers.ats[dimension_index]
        set_value = indexers.set

        @numba.njit(**jit_flags)
        def fill_halos(psi, _, sign):
            focus = psi[ARG_FOCUS]
            lon = focus[lon_idx]
            lat = focus[lat_idx]
            # distance of this halo cell from the domain edge (0 = adjacent)
            if lat > low_edge:
                beyond = lat - high_edge
            else:
                beyond = low_edge - lat
            # mirror across the edge: 1 for the adjacent halo cell, 3 next, ...
            step = 2 * beyond + 1
            # shift by half the longitudes (presumably the opposite meridian)
            # NOTE(review): ``lon`` is compared to ``nlon_half`` with no
            # ``halo`` offset, unlike the latitude edges; confirm intended.
            shift = -nlon_half if lon > nlon_half else nlon_half
            return ats(*psi, sign * step, shift)

        return make_fill_halos_loop(jit_flags, set_value, fill_halos)

    @staticmethod
    def make_vector(indexers, _, __, jit_flags, dimension_index):
        """Return the lru-cached, Numba-compiled vector halo-filling callable
        (built by ``_make_vector_polar``)."""
        atv, set_value = indexers.atv, indexers.set
        return _make_vector_polar(atv, set_value, jit_flags, dimension_index)

Class whose instances are to be passed in the `boundary_conditions` tuple to the `ScalarField` and `VectorField` `__init__` methods.

Polar(grid, longitude_idx, latitude_idx)
def __init__(self, grid, longitude_idx, latitude_idx):
    """
    Args:
        grid: indexable collection of per-dimension sizes.
        longitude_idx: index of the longitude dimension in ``grid``.
        latitude_idx: index of the latitude dimension in ``grid``.

    Raises:
        AssertionError: if the number of longitudes is odd (the halo filler
            shifts longitude by exactly half of them), or if the
            SIGN_LEFT/SIGN_RIGHT conventions differ from the assumed ones.
    """
    # halo offsets are built as ``sign * step``; that arithmetic relies on
    # these two sign conventions
    assert SIGN_RIGHT == -1
    assert SIGN_LEFT == +1

    self.nlon, self.nlat = grid[longitude_idx], grid[latitude_idx]
    assert self.nlon % 2 == 0

    self.nlon_half = self.nlon // 2
    self.lon_idx, self.lat_idx = longitude_idx, latitude_idx
nlon
nlat
nlon_half
lon_idx
lat_idx
def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
def make_scalar(self, indexers, halo, _, jit_flags, dimension_index):
    """Build the Numba-compiled scalar halo-filling callable.

    A new ``fill_halos`` kernel is created on every call; this method is not
    lru-cached (unlike ``make_vector``).

    Args:
        indexers: provides ``ats`` (per-dimension readers) and ``set``.
        halo: halo width.
        _: unused.
        jit_flags: keyword arguments forwarded to ``numba.njit``.
        dimension_index: selects the ``ats`` reader for this dimension.
    """
    lon_idx, lat_idx = self.lon_idx, self.lat_idx
    nlon_half = self.nlon_half
    # last halo index before the domain / first halo index past the domain
    low_edge = halo - 1
    high_edge = self.nlat + halo
    ats = indexers.ats[dimension_index]
    set_value = indexers.set

    @numba.njit(**jit_flags)
    def fill_halos(psi, _, sign):
        focus = psi[ARG_FOCUS]
        lon = focus[lon_idx]
        lat = focus[lat_idx]
        # distance of this halo cell from the domain edge (0 = adjacent)
        if lat > low_edge:
            beyond = lat - high_edge
        else:
            beyond = low_edge - lat
        # mirror across the edge: 1 for the adjacent halo cell, 3 next, ...
        step = 2 * beyond + 1
        # shift by half the longitudes (presumably the opposite meridian)
        # NOTE(review): ``lon`` is compared to ``nlon_half`` with no ``halo``
        # offset, unlike the latitude edges; confirm intended.
        shift = -nlon_half if lon > nlon_half else nlon_half
        return ats(*psi, sign * step, shift)

    return make_fill_halos_loop(jit_flags, set_value, fill_halos)

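Returns a Numba-compiled scalar halo-filling callable. A new kernel is built on each call; unlike the vector variant, this method does not go through an lru-cached factory.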

@staticmethod
def make_vector(indexers, _, __, jit_flags, dimension_index):
@staticmethod
def make_vector(indexers, _, __, jit_flags, dimension_index):
    """Return the lru-cached, Numba-compiled vector halo-filling callable
    (built by ``_make_vector_polar``)."""
    atv, set_value = indexers.atv, indexers.set
    return _make_vector_polar(atv, set_value, jit_flags, dimension_index)

Returns the lru-cached, Numba-compiled vector halo-filling callable.