PyMPDATA_examples.Jarecka_et_al_2015.simulation

  1import numba
  2import numpy as np
  3from PyMPDATA_examples.Jarecka_et_al_2015 import formulae
  4
  5from PyMPDATA import ScalarField, Solver, Stepper, VectorField
  6from PyMPDATA.boundary_conditions import Constant
  7from PyMPDATA.impl.enumerations import INNER, OUTER
  8from PyMPDATA.impl.formulae_divide import make_divide_or_zero
  9
 10
 11def make_hooks(*, traversals, options, grid_step, time_step):
 12
 13    divide_or_zero = make_divide_or_zero(options, traversals)
 14    interpolate = formulae.make_interpolate(options, traversals)
 15    rhs_x = formulae.make_rhs(grid_step, time_step, OUTER, options, traversals)
 16    rhs_y = formulae.make_rhs(grid_step, time_step, INNER, options, traversals)
 17
 18    @numba.experimental.jitclass([])
 19    class AnteStep:  # pylint:disable=too-few-public-methods
 20        def __init__(self):
 21            pass
 22
 23        def call(
 24            self,
 25            traversals_data,
 26            advectees,
 27            advector,
 28            _,
 29            index,
 30            todo_outer,
 31            todo_mid3d,
 32            todo_inner,
 33        ):
 34            if index == 0:
 35                divide_or_zero(
 36                    *todo_outer.field,
 37                    *todo_mid3d.field,
 38                    *todo_inner.field,
 39                    *advectees[1].field,
 40                    *todo_mid3d.field,
 41                    *advectees[2].field,
 42                    *advectees[0].field,
 43                    time_step,
 44                    grid_step
 45                )
 46                interpolate(traversals_data, todo_outer, todo_inner, advector)
 47            elif index == 1:
 48                rhs_x(traversals_data, advectees[index], advectees[0])
 49            else:
 50                rhs_y(traversals_data, advectees[index], advectees[0])
 51
 52    @numba.experimental.jitclass([])
 53    class PostStep:  # pylint:disable=too-few-public-methods
 54        def __init__(self):
 55            pass
 56
 57        def call(self, traversals_data, advectees, _, index):
 58            if index == 0:
 59                pass
 60            if index == 1:
 61                rhs_x(traversals_data, advectees[index], advectees[0])
 62            else:
 63                rhs_y(traversals_data, advectees[index], advectees[0])
 64
 65    return AnteStep(), PostStep()
 66
 67
 68class Simulation:
 69    # pylint: disable=too-few-public-methods
 70    def __init__(self, settings):
 71        self.settings = settings
 72        s = settings
 73
 74        halo = settings.options.n_halo
 75        grid = (s.nx, s.ny)
 76        bcs = [Constant(value=0)] * len(grid)
 77
 78        self.advector = VectorField(
 79            (np.zeros((s.nx + 1, s.ny)), np.zeros((s.nx, s.ny + 1))), halo, bcs
 80        )
 81
 82        xi, yi = np.indices(grid, dtype=float)
 83        xi -= (s.nx - 1) / 2
 84        yi -= (s.ny - 1) / 2
 85        x = xi * s.dx
 86        y = yi * s.dy
 87        h0 = formulae.amplitude(x, y, s.lx0, s.ly0)
 88
 89        self.advectees = {
 90            "h": ScalarField(h0, halo, bcs),
 91            "uh": ScalarField(np.zeros(grid), halo, bcs),
 92            "vh": ScalarField(np.zeros(grid), halo, bcs),
 93        }
 94
 95        stepper = Stepper(options=s.options, grid=grid)
 96
 97        self.ante_step, self.post_step = make_hooks(
 98            traversals=stepper.traversals,
 99            options=settings.options,
100            grid_step=(s.dx, None, s.dy),
101            time_step=s.dt,
102        )
103
104        self.solver = Solver(stepper, self.advectees, self.advector)
105
106    def run(self):
107        s = self.settings
108        output = []
109        for it in range(s.nt + 1):
110            if it != 0:
111                self.solver.advance(
112                    1, ante_step=self.ante_step, post_step=self.post_step
113                )
114            if it % s.outfreq == 0:
115                output.append(
116                    {
117                        k: self.solver.advectee[k].get().copy()
118                        for k in self.advectees.keys()  # pylint:disable=consider-iterating-dictionary
119                    }
120                )
121        return output
def make_hooks(*, traversals, options, grid_step, time_step):
def make_hooks(*, traversals, options, grid_step, time_step):
    """Build the pair of hook objects handed to ``Solver.advance``.

    The helper kernels are created once here and captured by two numba
    jitclasses whose ``call`` methods dispatch on the advectee ``index``.

    Args:
        traversals: traversals object forwarded to every kernel factory.
        options: options object forwarded to every kernel factory.
        grid_step: indexable of grid spacings; it is indexed with ``OUTER``
            and ``INNER`` inside ``formulae.make_rhs`` calls and is also
            passed whole to ``divide_or_zero``.
        time_step: time step forwarded to the kernels.

    Returns:
        A tuple ``(ante_step, post_step)`` of jitclass instances, to be
        passed as the ``ante_step`` / ``post_step`` arguments of the solver.
    """
    divide_or_zero = make_divide_or_zero(options, traversals)
    interpolate = formulae.make_interpolate(options, traversals)
    rhs_x = formulae.make_rhs(grid_step, time_step, OUTER, options, traversals)
    rhs_y = formulae.make_rhs(grid_step, time_step, INNER, options, traversals)

    @numba.experimental.jitclass([])
    class AnteStep:  # pylint:disable=too-few-public-methods
        # Hook invoked before an advectee is advanced.
        def __init__(self):
            pass

        def call(
            self,
            traversals_data,
            advectees,
            advector,
            _,
            index,
            todo_outer,
            todo_mid3d,
            todo_inner,
        ):
            if index == 0:
                # Fill the scratch ("todo") fields from advectees 1 and 2
                # divided by advectee 0, then interpolate the scratch fields
                # into the advector.
                # NOTE(review): presumably this turns momenta (uh, vh) into
                # velocities/Courant numbers using depth h - confirm against
                # make_divide_or_zero and formulae.make_interpolate.
                divide_or_zero(
                    *todo_outer.field,
                    *todo_mid3d.field,
                    *todo_inner.field,
                    *advectees[1].field,
                    *todo_mid3d.field,
                    *advectees[2].field,
                    *advectees[0].field,
                    time_step,
                    grid_step,
                )
                interpolate(traversals_data, todo_outer, todo_inner, advector)
            elif index == 1:
                rhs_x(traversals_data, advectees[index], advectees[0])
            else:
                rhs_y(traversals_data, advectees[index], advectees[0])

    @numba.experimental.jitclass([])
    class PostStep:  # pylint:disable=too-few-public-methods
        # Hook invoked after an advectee has been advanced.
        def __init__(self):
            pass

        def call(self, traversals_data, advectees, _, index):
            # The three branches must be mutually exclusive: with a separate
            # ``if`` for index 1, index 0 used to fall through to the ``else``
            # and had rhs_y applied to advectee 0 itself. This now mirrors
            # the dispatch in AnteStep, where index 0 receives no RHS term.
            if index == 0:
                pass
            elif index == 1:
                rhs_x(traversals_data, advectees[index], advectees[0])
            else:
                rhs_y(traversals_data, advectees[index], advectees[0])

    return AnteStep(), PostStep()
class Simulation:
 69class Simulation:
 70    # pylint: disable=too-few-public-methods
 71    def __init__(self, settings):
 72        self.settings = settings
 73        s = settings
 74
 75        halo = settings.options.n_halo
 76        grid = (s.nx, s.ny)
 77        bcs = [Constant(value=0)] * len(grid)
 78
 79        self.advector = VectorField(
 80            (np.zeros((s.nx + 1, s.ny)), np.zeros((s.nx, s.ny + 1))), halo, bcs
 81        )
 82
 83        xi, yi = np.indices(grid, dtype=float)
 84        xi -= (s.nx - 1) / 2
 85        yi -= (s.ny - 1) / 2
 86        x = xi * s.dx
 87        y = yi * s.dy
 88        h0 = formulae.amplitude(x, y, s.lx0, s.ly0)
 89
 90        self.advectees = {
 91            "h": ScalarField(h0, halo, bcs),
 92            "uh": ScalarField(np.zeros(grid), halo, bcs),
 93            "vh": ScalarField(np.zeros(grid), halo, bcs),
 94        }
 95
 96        stepper = Stepper(options=s.options, grid=grid)
 97
 98        self.ante_step, self.post_step = make_hooks(
 99            traversals=stepper.traversals,
100            options=settings.options,
101            grid_step=(s.dx, None, s.dy),
102            time_step=s.dt,
103        )
104
105        self.solver = Solver(stepper, self.advectees, self.advector)
106
107    def run(self):
108        s = self.settings
109        output = []
110        for it in range(s.nt + 1):
111            if it != 0:
112                self.solver.advance(
113                    1, ante_step=self.ante_step, post_step=self.post_step
114                )
115            if it % s.outfreq == 0:
116                output.append(
117                    {
118                        k: self.solver.advectee[k].get().copy()
119                        for k in self.advectees.keys()  # pylint:disable=consider-iterating-dictionary
120                    }
121                )
122        return output
Simulation(settings)
 71    def __init__(self, settings):
 72        self.settings = settings
 73        s = settings
 74
 75        halo = settings.options.n_halo
 76        grid = (s.nx, s.ny)
 77        bcs = [Constant(value=0)] * len(grid)
 78
 79        self.advector = VectorField(
 80            (np.zeros((s.nx + 1, s.ny)), np.zeros((s.nx, s.ny + 1))), halo, bcs
 81        )
 82
 83        xi, yi = np.indices(grid, dtype=float)
 84        xi -= (s.nx - 1) / 2
 85        yi -= (s.ny - 1) / 2
 86        x = xi * s.dx
 87        y = yi * s.dy
 88        h0 = formulae.amplitude(x, y, s.lx0, s.ly0)
 89
 90        self.advectees = {
 91            "h": ScalarField(h0, halo, bcs),
 92            "uh": ScalarField(np.zeros(grid), halo, bcs),
 93            "vh": ScalarField(np.zeros(grid), halo, bcs),
 94        }
 95
 96        stepper = Stepper(options=s.options, grid=grid)
 97
 98        self.ante_step, self.post_step = make_hooks(
 99            traversals=stepper.traversals,
100            options=settings.options,
101            grid_step=(s.dx, None, s.dy),
102            time_step=s.dt,
103        )
104
105        self.solver = Solver(stepper, self.advectees, self.advector)
settings
advector
advectees
solver
def run(self):
107    def run(self):
108        s = self.settings
109        output = []
110        for it in range(s.nt + 1):
111            if it != 0:
112                self.solver.advance(
113                    1, ante_step=self.ante_step, post_step=self.post_step
114                )
115            if it % s.outfreq == 0:
116                output.append(
117                    {
118                        k: self.solver.advectee[k].get().copy()
119                        for k in self.advectees.keys()  # pylint:disable=consider-iterating-dictionary
120                    }
121                )
122        return output