PySDM_examples.Bartman_2020_MasterThesis.fig_4_adaptive_sdm

 1import os
 2
 3from matplotlib import pyplot as plt
 4from PySDM_examples.Shima_et_al_2009.example import run
 5from PySDM_examples.Shima_et_al_2009.settings import Settings
 6from PySDM_examples.Shima_et_al_2009.spectrum_plotter import SpectrumPlotter
 7
 8
 9def main(plot: bool = True, save: str = None):
10    n_sds = [13, 15, 17] if "CI" not in os.environ else [13, 15]
11    dts = [10, 5, 1, "adaptive"]
12    iters = 10
13    base_time = None
14
15    plt.ioff()
16    fig, axs = plt.subplots(
17        len(dts), len(n_sds), sharex=True, sharey=True, figsize=(10, 10)
18    )
19
20    for i, dt in enumerate(dts):
21        for j, n_sd in enumerate(n_sds):
22            outputs = []
23            exec_time = 0
24            for _ in range(iters):
25                settings = Settings()
26
27                settings.n_sd = 2**n_sd
28                settings.dt = dt if dt != "adaptive" else 10
29                settings.adaptive = dt == "adaptive"
30
31                states, exec_time = run(settings)
32                outputs.append(states)
33            mean_time = exec_time / iters
34            if base_time is None:
35                base_time = mean_time
36            norm_time = mean_time / base_time
37            mean_output = {}
38            for key in outputs[0].keys():
39                mean_output[key] = sum((output[key] for output in outputs)) / len(
40                    outputs
41                )
42
43            plotter = SpectrumPlotter(settings, legend=False)
44            plotter.fig = fig
45            plotter.ax = axs[i, j]
46            plotter.smooth = True
47            for step, vals in mean_output.items():
48                plotter.plot(vals, step * settings.dt)
49
50            plotter.ylabel = (
51                r"$\bf{dt: " + str(dt) + "}$\ndm/dlnr [g/m^3/(unit dr/r)]"
52                if j == 0
53                else None
54            )
55            plotter.xlabel = (
56                "particle radius [µm]\n" + r"$\bf{n_{sd}: 2^{" + str(n_sd) + "}}$"
57                if i == len(dts) - 1
58                else None
59            )
60            plotter.title = f"norm. time: {norm_time:.2f}; " + plotter.title
61            plotter.finished = False
62            plotter.finish()
63    if save is not None:
64        n_sd = settings.n_sd
65        plotter.save(save + "/" + f"{n_sd}_shima_fig_2" + "." + plotter.format)
66    if plot:
67        plotter.show()
68
69
if __name__ == "__main__":
    # skip interactive display when running under CI; always save to cwd
    in_ci = "CI" in os.environ
    main(plot=not in_ci, save=".")
def main(plot: bool = True, save: str = None):
    """Plot averaged spectra for fixed-dt and adaptive-dt runs in a panel grid.

    One row per timestep setting in ``dts`` and one column per super-droplet
    count exponent in ``n_sds``. Each panel shows the outputs of ``iters``
    repeated ``run(settings)`` calls, averaged entry-wise. Its title is prefixed
    with the mean execution time normalised by the mean time of the first
    (top-left) panel.

    Args:
        plot: when true, call ``plotter.show()`` at the end.
        save: directory to save the figure into; nothing is saved when None.
    """
    # smaller grid when running under CI
    n_sds = [13, 15, 17] if "CI" not in os.environ else [13, 15]
    dts = [10, 5, 1, "adaptive"]
    iters = 10
    base_time = None

    plt.ioff()
    fig, axs = plt.subplots(
        len(dts), len(n_sds), sharex=True, sharey=True, figsize=(10, 10)
    )

    for i, dt in enumerate(dts):
        for j, n_sd in enumerate(n_sds):
            outputs = []
            exec_time = 0
            for _ in range(iters):
                settings = Settings()

                settings.n_sd = 2**n_sd
                # "adaptive" runs use dt=10 with settings.adaptive enabled
                # (NOTE(review): presumably the adaptive scheme refines this
                # step internally — confirm in Settings/run)
                settings.dt = dt if dt != "adaptive" else 10
                settings.adaptive = dt == "adaptive"

                states, run_time = run(settings)
                outputs.append(states)
                # accumulate over repetitions; overwriting here made the
                # "mean" below equal to last_run_time / iters
                exec_time += run_time
            mean_time = exec_time / iters
            if base_time is None:
                # the first panel is the reference for the normalised times
                base_time = mean_time
            norm_time = mean_time / base_time
            # entry-wise mean of every output across the repeated runs
            mean_output = {
                key: sum(output[key] for output in outputs) / len(outputs)
                for key in outputs[0]
            }

            plotter = SpectrumPlotter(settings, legend=False)
            plotter.fig = fig
            plotter.ax = axs[i, j]
            plotter.smooth = True
            for step, vals in mean_output.items():
                plotter.plot(vals, step * settings.dt)

            # y label only on the left-most column, x label only on the bottom row
            plotter.ylabel = (
                r"$\bf{dt: " + str(dt) + "}$\ndm/dlnr [g/m^3/(unit dr/r)]"
                if j == 0
                else None
            )
            plotter.xlabel = (
                "particle radius [µm]\n" + r"$\bf{n_{sd}: 2^{" + str(n_sd) + "}}$"
                if i == len(dts) - 1
                else None
            )
            plotter.title = f"norm. time: {norm_time:.2f}; " + plotter.title
            # NOTE(review): flag reset so finish() applies the labels/title set
            # above — presumably finish() is skipped when already finished
            plotter.finished = False
            plotter.finish()
    if save is not None:
        # file name uses the super-droplet count of the last configuration run
        n_sd = settings.n_sd
        plotter.save(save + "/" + f"{n_sd}_shima_fig_2" + "." + plotter.format)
    if plot:
        plotter.show()