PySDM_examples.Shima_et_al_2009.example_timing

 1import os
 2
 3from matplotlib import pyplot as plt
 4from PySDM_examples.Shima_et_al_2009.settings import Settings
 5
 6from PySDM.backends import Numba, ThrustRTC
 7from PySDM.builder import Builder
 8from PySDM.dynamics import Coalescence
 9from PySDM.environments import Box
10from PySDM.initialisation.sampling.spectral_sampling import ConstantMultiplicity
11from PySDM.products import WallTime
12
13
def run(settings, backend):
    """Build a box-model coalescence simulation and report its wall time.

    A ``Box`` environment (``settings.dv``, ``settings.dt``) is populated with
    ``settings.n_sd`` super-droplets drawn deterministically from
    ``settings.spectrum`` with constant multiplicity. It is then evolved with
    ``Coalescence`` using ``settings.kernel``.

    The simulation is advanced to each entry of ``settings.output_steps`` in
    turn. Entries are treated as cumulative step targets: each segment runs
    ``target - n_steps`` steps.

    Args:
        settings: settings object providing ``dv``, ``dt``, ``n_sd``,
            ``kernel``, ``spectrum`` and ``output_steps``.
        backend: a PySDM backend instance (``main`` passes ``ThrustRTC()`` or
            ``Numba()``).

    Returns:
        ``(states, wall_time)``. ``states`` is always an empty dict in this
        implementation; it is kept only for the caller's tuple shape.
        ``wall_time`` is the ``"wall time"`` product value read after the final
        segment, or ``None`` if ``settings.output_steps`` is empty.
    """
    builder = Builder(
        n_sd=settings.n_sd,
        backend=backend,
        environment=Box(dv=settings.dv, dt=settings.dt),
        dynamics=(Coalescence(collision_kernel=settings.kernel),),
    )
    volume, multiplicity = ConstantMultiplicity(
        settings.spectrum
    ).sample_deterministic(settings.n_sd)
    particulator = builder.build(
        {"volume": volume, "multiplicity": multiplicity},
        products=(WallTime(),),
    )

    # Only the reading after the last segment is kept. NOTE(review):
    # presumably earlier segments act as warm-up (e.g. JIT compilation) and
    # each get() measures time since the previous read. Confirm against
    # WallTime.
    wall_time = None
    for target_step in settings.output_steps:
        particulator.run(target_step - particulator.n_steps)
        wall_time = particulator.products["wall time"].get()

    return {}, wall_time
36
37
def main(plot: bool):
    """Time the Shima et al. (2009) coalescence box model on several backends.

    For each backend class (``ThrustRTC``, then ``Numba``) and each
    super-droplet count in 2**12, 2**15 and 2**18, ``run`` is called once and
    the wall time it returns is recorded. The results are drawn as wall time
    versus number of particles on base-2 log-log axes.

    Args:
        plot: if True, display the figure with ``plt.show()``. Otherwise the
            figure is only built, which is what the ``__main__`` guard does
            under CI.
    """
    settings = Settings()
    # Shorter runs when the CI environment variable is set.
    # NOTE(review): assumes Settings derives ``output_steps`` (read by ``run``)
    # from ``steps``. Confirm in Settings.
    settings.steps = [100, 3600] if "CI" not in os.environ else [1, 2]

    # The particle counts are the same for every backend and are also the
    # x-axis of the plot, so compute them once outside the backend loop.
    nsds = [2**n for n in range(12, 19, 3)]

    times = {}
    for backend_class in (ThrustRTC, Numba):
        key = backend_class.__name__
        times[key] = []
        for n_sd in nsds:
            # The same settings object is reused; only n_sd changes per run.
            settings.n_sd = n_sd
            _, wall_time = run(settings, backend_class())
            times[key].append(wall_time)

    for label, wall_times in times.items():
        plt.plot(nsds, wall_times, label=label, linestyle="--", marker="o")
    plt.ylabel("wall time [s]")
    plt.xlabel("number of particles")
    plt.grid()
    plt.legend()
    plt.loglog(base=2)
    if plot:
        plt.show()
61
62
if __name__ == "__main__":
    # Show the figure only for interactive runs; CI sets the CI variable.
    main(plot=os.environ.get("CI") is None)
def run(settings, backend):
def run(settings, backend):
    """Build a box-model coalescence simulation and report its wall time.

    A ``Box`` environment (``settings.dv``, ``settings.dt``) is populated with
    ``settings.n_sd`` super-droplets drawn deterministically from
    ``settings.spectrum`` with constant multiplicity. It is then evolved with
    ``Coalescence`` using ``settings.kernel``.

    The simulation is advanced to each entry of ``settings.output_steps`` in
    turn. Entries are treated as cumulative step targets: each segment runs
    ``target - n_steps`` steps.

    Args:
        settings: settings object providing ``dv``, ``dt``, ``n_sd``,
            ``kernel``, ``spectrum`` and ``output_steps``.
        backend: a PySDM backend instance (``main`` passes ``ThrustRTC()`` or
            ``Numba()``).

    Returns:
        ``(states, wall_time)``. ``states`` is always an empty dict in this
        implementation; it is kept only for the caller's tuple shape.
        ``wall_time`` is the ``"wall time"`` product value read after the final
        segment, or ``None`` if ``settings.output_steps`` is empty.
    """
    builder = Builder(
        n_sd=settings.n_sd,
        backend=backend,
        environment=Box(dv=settings.dv, dt=settings.dt),
        dynamics=(Coalescence(collision_kernel=settings.kernel),),
    )
    volume, multiplicity = ConstantMultiplicity(
        settings.spectrum
    ).sample_deterministic(settings.n_sd)
    particulator = builder.build(
        {"volume": volume, "multiplicity": multiplicity},
        products=(WallTime(),),
    )

    # Only the reading after the last segment is kept. NOTE(review):
    # presumably earlier segments act as warm-up (e.g. JIT compilation) and
    # each get() measures time since the previous read. Confirm against
    # WallTime.
    wall_time = None
    for target_step in settings.output_steps:
        particulator.run(target_step - particulator.n_steps)
        wall_time = particulator.products["wall time"].get()

    return {}, wall_time
def main(plot: bool):
def main(plot: bool):
    """Time the Shima et al. (2009) coalescence box model on several backends.

    For each backend class (``ThrustRTC``, then ``Numba``) and each
    super-droplet count in 2**12, 2**15 and 2**18, ``run`` is called once and
    the wall time it returns is recorded. The results are drawn as wall time
    versus number of particles on base-2 log-log axes.

    Args:
        plot: if True, display the figure with ``plt.show()``. Otherwise the
            figure is only built, which is what the ``__main__`` guard does
            under CI.
    """
    settings = Settings()
    # Shorter runs when the CI environment variable is set.
    # NOTE(review): assumes Settings derives ``output_steps`` (read by ``run``)
    # from ``steps``. Confirm in Settings.
    settings.steps = [100, 3600] if "CI" not in os.environ else [1, 2]

    # The particle counts are the same for every backend and are also the
    # x-axis of the plot, so compute them once outside the backend loop.
    nsds = [2**n for n in range(12, 19, 3)]

    times = {}
    for backend_class in (ThrustRTC, Numba):
        key = backend_class.__name__
        times[key] = []
        for n_sd in nsds:
            # The same settings object is reused; only n_sd changes per run.
            settings.n_sd = n_sd
            _, wall_time = run(settings, backend_class())
            times[key].append(wall_time)

    for label, wall_times in times.items():
        plt.plot(nsds, wall_times, label=label, linestyle="--", marker="o")
    plt.ylabel("wall time [s]")
    plt.xlabel("number of particles")
    plt.grid()
    plt.legend()
    plt.loglog(base=2)
    if plot:
        plt.show()