# GridFire validation-suite output parser and plotter.
# (Original listing metadata: 351 lines, 14 KiB, Python.)

import argparse
import json
import os
import sys
import math
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.integrate import cumulative_trapezoid
from astropy import constants as const
from astropy import units as u
from enum import Enum
from typing import List, Dict, Any, Tuple
from fourdst.atomic import species as spdict
# Matplotlib style sheet applied when --use-external-style is given.
# The path is relative, so it is resolved against the current working
# directory at run time, not against the location of this script.
EXTERNAL_STYLE_PATH: str = "../ManuscriptFigures/utils/pub.mplstyle"
class PlotVariable(Enum):
    """Quantities that can be plotted for a validation run.

    The member value is the lowercase tag used in figure titles and output
    filenames; the CLI selects members by (case-insensitive) name.
    """

    COMPOSITION = "composition"  # per-species values from each step's "Composition"
    EPS = "eps"                  # energy generation ("eps" field of each step)
    DT = "dt"                    # timestep size ("dt" field of each step)
class OutputFormat(Enum):
    """Where figures go: an interactive window, or a file of the given type.

    For the file formats, the value doubles as the file extension and as the
    ``format=`` argument passed to ``savefig``.
    """

    INTERACTIVE = "interactive"  # show with plt.show(); nothing written to disk
    PDF = "pdf"
    PNG = "png"
    JPEG = "jpeg"
def discover_runs(base_dir: str) -> List[str]:
    """List the names of runs that have a successful GridFire output file.

    A run ``NAME`` is reported when ``<base_dir>/GridFire/Ok/NAME_OKAY.json``
    exists. Returns the names sorted alphabetically, or ``[]`` when the
    ``GridFire/Ok`` directory is absent.
    """
    ok_dir = os.path.join(base_dir, "GridFire", "Ok")
    if not os.path.exists(ok_dir):
        return []
    suffix = "_OKAY.json"
    names = {entry.replace(suffix, "") for entry in os.listdir(ok_dir) if entry.endswith(suffix)}
    return sorted(names)
def load_run_data(base_dir: str, run_name: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load the GridFire and PyNucastro JSON outputs for one run.

    GridFire data comes from ``GridFire/Ok/<run>_OKAY.json`` if present,
    otherwise from ``GridFire/Err/<run>_FAIL.json``. PyNucastro data comes
    from ``pynucastro/<run>_pynucastro.json``. Any file that does not exist
    yields ``{}`` for its slot. JSON decode errors are not caught.

    Returns:
        ``(gridfire_data, pynucastro_data)``.
    """
    def _read(path):
        with open(path, 'r') as handle:
            return json.load(handle)

    gf_root = os.path.join(base_dir, "GridFire")
    gf_candidates = (
        os.path.join(gf_root, "Ok", f"{run_name}_OKAY.json"),
        os.path.join(gf_root, "Err", f"{run_name}_FAIL.json"),
    )
    # First existing candidate wins; the success file takes precedence.
    gf_data = next((_read(p) for p in gf_candidates if os.path.exists(p)), {})

    pynuc_file = os.path.join(base_dir, "pynucastro", f"{run_name}_pynucastro.json")
    pynuc_data = _read(pynuc_file) if os.path.exists(pynuc_file) else {}
    return gf_data, pynuc_data
def get_pynuc_eps(steps: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
    """Estimate an energy-generation rate from successive PyNucastro snapshots.

    For each pair of consecutive steps with a positive time gap, computes
    ``-N_A * c**2 * sum_i m_i * (Y_i(t) - Y_i(t_prev)) / (t - t_prev)``, where
    ``m_i`` is the species mass from ``spdict`` converted from atomic mass
    units to grams. Constants are in CGS. The composition values are
    presumably molar abundances Y_i, which would make the result
    erg g^-1 s^-1 (TODO confirm against the PyNucastro writer).

    Species missing from ``spdict`` are ignored; a species present in only one
    of the two snapshots is treated as 0 in the other. The first step is
    reported with rate 0.0, and steps with a non-positive time gap produce no
    output sample.

    Args:
        steps: Sequence of dicts, each with a time ``"t"`` and a
            ``"Composition"`` mapping of species name -> value.

    Returns:
        ``(times, rates)`` as NumPy arrays of equal length.
    """
    light_speed_sq = (const.c.cgs.value)**2
    avogadro = const.N_A.cgs.value
    grams_per_amu = const.u.cgs.value

    times: List[float] = []
    rates: List[float] = []
    prev_t = None
    prev_comp: Dict[str, Any] = {}

    for entry in steps:
        t_now = entry["t"]
        comp_now = entry["Composition"]

        if prev_t is None:
            # Nothing to difference against yet.
            rates.append(0.0)
            times.append(t_now)
        else:
            gap = t_now - prev_t
            if gap > 0:
                mass_rate = 0.0
                for name in set(comp_now) | set(prev_comp):
                    if name not in spdict:
                        continue
                    delta = comp_now.get(name, 0.0) - prev_comp.get(name, 0.0)
                    mass_rate += spdict[name].mass() * grams_per_amu * (delta / gap)
                rates.append(-mass_rate * avogadro * light_speed_sq)
                times.append(t_now)

        prev_t = t_now
        prev_comp = comp_now.copy()

    return np.array(times), np.array(rates)
def _setup_axes(ax_main: plt.Axes, ax_diff: plt.Axes, var: PlotVariable, fig_opts: dict):
    """Apply the scales and labels shared by every main/residual axis pair.

    Both axes get ``fig_opts['x_scale']``; only the main axis gets
    ``fig_opts['y_scale']``. The residual axis keeps its default y-scale.
    The main-axis y-label depends on ``var``; the residual axis is always
    labelled as a log10 difference.
    """
    x_scale = fig_opts['x_scale']
    for axis in (ax_main, ax_diff):
        axis.set_xscale(x_scale)
    ax_main.set_yscale(fig_opts['y_scale'])
    ax_diff.set_xlabel("Time (s)")

    y_labels = {
        PlotVariable.EPS: "Cumulative Energy (eps)",
        PlotVariable.DT: "Timestep Size (dt)",
        PlotVariable.COMPOSITION: "Mass Fraction (X_i)",
    }
    if var in y_labels:
        ax_main.set_ylabel(y_labels[var])
    ax_diff.set_ylabel(r"$\Delta \log_{10}$")
def _plot_composition_panel(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str,
                            gf_steps: List[Dict[str, Any]], t_gf: np.ndarray,
                            pynuc_steps: List[Dict[str, Any]]):
    """Plot the three species with the largest final GridFire value, plus PyNucastro overlays.

    When ``pynuc_steps`` is non-empty, each PyNucastro curve is drawn dashed in
    the GridFire curve's colour. The absolute log10 residual is drawn on
    ``ax_diff``: PyNucastro is linearly interpolated onto the GridFire times
    and filled with its first/last values outside its range. Missing species
    are treated as 1e-30.
    """
    final_comp = gf_steps[-1]["Composition"]
    top_species = sorted(final_comp, key=final_comp.get, reverse=True)[:3]
    # The PyNucastro time axis is the same for every species; build it once.
    t_pynuc = np.array([s["t"] for s in pynuc_steps]) if pynuc_steps else None

    for spec in top_species:
        y_gf = np.array([s["Composition"].get(spec, 1e-30) for s in gf_steps])
        line, = ax_main.plot(t_gf, y_gf, label=f"{run_name} {spec} (GF)")
        if t_pynuc is None:
            continue
        y_pynuc = np.array([s["Composition"].get(spec, 1e-30) for s in pynuc_steps])
        ax_main.plot(t_pynuc, y_pynuc, '--', color=line.get_color(), label=f"{run_name} {spec} (PyNuc)")
        if len(t_pynuc) < 2:
            continue  # interp1d needs at least two samples
        f_interp = interp1d(t_pynuc, y_pynuc, kind='linear', bounds_error=False,
                            fill_value=(y_pynuc[0], y_pynuc[-1]))
        y_pynuc_interp = f_interp(t_gf)
        log_diff = np.abs(np.log10(np.maximum(y_gf, 1e-30))
                          - np.log10(np.maximum(y_pynuc_interp, 1e-30)))
        # Include the run name so residuals stay distinguishable when runs are merged.
        ax_diff.plot(t_gf, log_diff, color=line.get_color(), linestyle=':',
                     label=f"Δ {run_name} {spec}")


def _plot_eps_panel(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str,
                    gf_steps: List[Dict[str, Any]], t_gf: np.ndarray,
                    pynuc_steps: List[Dict[str, Any]], eps_t_min: float):
    """Plot GridFire ``eps`` and, if available, PyNucastro cumulative energy and their residual.

    The PyNucastro curve is the trapezoidal time-integral of the rate from
    ``get_pynuc_eps``. The signed residual ``log10|GF| - log10|PyNuc|`` is
    evaluated on 1000 log-spaced times covering the overlap of both series,
    starting no earlier than ``eps_t_min``. It is skipped when that window is
    empty or non-positive.
    """
    y_gf = np.array([s["eps"] for s in gf_steps])
    line, = ax_main.plot(t_gf, y_gf, label=f"{run_name} (GF)")
    if not pynuc_steps:
        return
    s_t, s_e = get_pynuc_eps(pynuc_steps)
    if len(s_t) < 2:
        return
    s_cumE = cumulative_trapezoid(s_e, s_t, initial=0)
    ax_main.plot(s_t, s_cumE, '--', color=line.get_color(), label=f"{run_name} (PyNuc)")

    finite = np.isfinite(s_cumE)
    t_pn, e_pn = s_t[finite], s_cumE[finite]
    if len(t_pn) < 2 or len(t_gf) < 2:
        return  # interp1d needs at least two samples per series

    # Restrict the window to times both series cover: interp1d defaults to
    # bounds_error=True, so any out-of-range time would raise ValueError.
    t_lo = max(eps_t_min, t_pn.min(), t_gf.min())
    t_hi = min(t_pn.max(), t_gf.max())
    if not (t_lo > 0 and t_hi > t_lo):
        return  # no positive overlap to place a log-spaced grid on
    # Clip guards against 10**log10(x) round-off pushing endpoints out of range.
    t_safe = np.clip(np.logspace(np.log10(t_lo), np.log10(t_hi), 1000), t_lo, t_hi)

    f_pynuc_interp = interp1d(t_pn, e_pn)
    f_gf_interp = interp1d(t_gf, y_gf)
    pynuc_safe = np.maximum(np.abs(f_pynuc_interp(t_safe)), 1e-30)
    gf_safe = np.maximum(np.abs(f_gf_interp(t_safe)), 1e-30)
    log_diff = np.log10(gf_safe) - np.log10(pynuc_safe)
    ax_diff.plot(t_safe, log_diff, color=line.get_color(), linestyle=':',
                 label=f"Δ {run_name} eps")
    if eps_t_min > 0:
        ax_main.set_xlim(left=eps_t_min)


def _plot_single_run(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str, var: PlotVariable,
                     base_dir: str, compare_pynuc: bool, eps_t_min: float = 1e8):
    """Plot one run's ``var`` onto ``ax_main`` (values) and ``ax_diff`` (log residuals).

    Runs with no GridFire data, no steps, or ``Metadata.Status == "Error"`` are
    skipped silently. PyNucastro overlays and residuals are drawn only when
    ``compare_pynuc`` is set and that run's PyNucastro file was found; a warning
    is printed otherwise.

    Args:
        ax_main: Axis for the main curves.
        ax_diff: Axis for residual curves.
        run_name: Run identifier used to locate files and label curves.
        var: Which quantity to plot.
        base_dir: Root directory of the validation output.
        compare_pynuc: Whether to overlay and compare PyNucastro data.
        eps_t_min: Earliest time (s) used for the EPS residual window and the
            EPS main-axis left x-limit.
    """
    gf_data, pynuc_data = load_run_data(base_dir, run_name)
    if not gf_data or gf_data.get("Metadata", {}).get("Status") == "Error":
        return
    gf_steps = gf_data.get("Steps", [])
    if not gf_steps:
        return
    if compare_pynuc and not pynuc_data:
        print(f"Warning: PyNucastro comparison requested but data not found for '{run_name}'.")
    pynuc_steps = pynuc_data.get("Steps", []) if (compare_pynuc and pynuc_data) else []

    t_gf = np.array([s["t"] for s in gf_steps])
    if var == PlotVariable.COMPOSITION:
        _plot_composition_panel(ax_main, ax_diff, run_name, gf_steps, t_gf, pynuc_steps)
    elif var == PlotVariable.EPS:
        _plot_eps_panel(ax_main, ax_diff, run_name, gf_steps, t_gf, pynuc_steps, eps_t_min)
    elif var == PlotVariable.DT:
        y_gf = np.array([s["dt"] for s in gf_steps])
        ax_main.plot(t_gf, y_gf, label=f"{run_name} (GF)")
def _finalize_plot(fig: plt.Figure, ax_main: plt.Axes, ax_diff: plt.Axes, format_opt: OutputFormat, filename_base: str, is_subfigure: bool = False):
    """Add legends and, for a standalone figure, lay it out and save it.

    Legends are added only to axes that hold labelled artists; calling
    ``legend()`` on an empty axis (e.g. for a skipped or failed run) makes
    Matplotlib warn "No artists with labels found".

    For subfigures (``is_subfigure=True``) only legends are handled; the caller
    owns the parent figure. Otherwise the figure is laid out with
    ``tight_layout``. Unless ``format_opt`` is INTERACTIVE, it is saved as
    ``<filename_base>.<ext>`` and then closed. Interactive figures stay open so
    a later ``plt.show()`` can display them.
    """
    if ax_main.get_legend_handles_labels()[0]:
        ax_main.legend(loc='best', fontsize='small')
    if ax_diff.get_legend_handles_labels()[0]:
        ax_diff.legend(loc='best', fontsize='x-small')
    if is_subfigure:
        return

    fig.tight_layout()
    if format_opt == OutputFormat.INTERACTIVE:
        return
    out_name = f"{filename_base}.{format_opt.value}"
    fig.savefig(out_name, format=format_opt.value, bbox_inches='tight')
    print(f"Saved figure: {out_name}")
    plt.close(fig)
def _apply_style(fig_opts: dict):
    """Apply the requested Matplotlib style sheet, then the base figure size and DPI."""
    if fig_opts['use_ext_style']:
        try:
            plt.style.use(EXTERNAL_STYLE_PATH)
            print(f"Using external style sheet: {EXTERNAL_STYLE_PATH}")
        except Exception as err:
            print(f"Warning: Failed to load external style sheet. Error: {err}")
    elif fig_opts['style']:
        style_name = fig_opts['style']
        try:
            plt.style.use(style_name)
        except OSError:
            print(f"Warning: Style '{style_name}' not found. Using default.")
    plt.rcParams.update({
        "figure.figsize": fig_opts['figsize'],
        "figure.dpi": fig_opts['dpi'],
    })


def _plot_merged_figure(runs: List[str], var: PlotVariable, base_dir: str,
                        compare_pynuc: bool, format_opt: OutputFormat, fig_opts: dict):
    """Draw every run on one shared main/residual axis pair and finalize that figure."""
    fig, (top, bottom) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]})
    top.set_title(f"Comparison of {var.value.upper()} (Merged Runs)")
    _setup_axes(top, bottom, var, fig_opts)
    for name in runs:
        _plot_single_run(top, bottom, name, var, base_dir, compare_pynuc)
    _finalize_plot(fig, top, bottom, format_opt, f"ValidationPlot_Merged_{var.value}")


def _plot_grid_figure(runs: List[str], var: PlotVariable, base_dir: str,
                      compare_pynuc: bool, format_opt: OutputFormat, fig_opts: dict):
    """Draw one subfigure per run on a near-square grid and save it unless interactive."""
    n_runs = len(runs)
    n_cols = math.ceil(math.sqrt(n_runs))
    n_rows = math.ceil(n_runs / n_cols)
    width, height = fig_opts['figsize']

    fig = plt.figure(figsize=(width * n_cols, height * n_rows), layout='constrained')
    fig.suptitle(f"{var.value.upper()} Comparison", fontsize=16, fontweight='bold')
    grid = fig.subfigures(n_rows, n_cols)
    # A 1x1 grid is returned as a single SubFigure rather than an array.
    panels = grid.flatten() if hasattr(grid, 'flatten') else [grid]

    for panel, run_name in zip(panels, runs):
        panel.suptitle(f"{run_name}", fontsize=12)
        top, bottom = panel.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]})
        _setup_axes(top, bottom, var, fig_opts)
        _plot_single_run(top, bottom, run_name, var, base_dir, compare_pynuc)
        _finalize_plot(fig, top, bottom, format_opt, "", is_subfigure=True)
    # Hide grid cells left over when the run count does not fill the grid.
    for spare in panels[n_runs:]:
        spare.set_visible(False)

    if format_opt != OutputFormat.INTERACTIVE:
        out_name = f"ValidationPlot_Grid_{var.value}.{format_opt.value}"
        fig.savefig(out_name, format=format_opt.value, bbox_inches='tight')
        print(f"Saved grid figure: {out_name}")
        plt.close(fig)


def plot_data(runs: List[str], plot_vars: List[PlotVariable], base_dir: str,
              compare_pynuc: bool, format_opt: OutputFormat, fig_opts: dict):
    """Produce one figure per requested variable for the given runs.

    ``fig_opts['merge_runs']`` chooses a single merged axis pair per variable;
    otherwise a grid of per-run subfigures is used. Figures are saved in
    ``format_opt``, or all shown at the end when the format is INTERACTIVE.
    Prints a message and returns if ``runs`` is empty.
    """
    if not runs:
        print("No valid runs to plot.")
        return
    _apply_style(fig_opts)

    draw = _plot_merged_figure if fig_opts['merge_runs'] else _plot_grid_figure
    for var in plot_vars:
        draw(runs, var, base_dir, compare_pynuc, format_opt, fig_opts)

    if format_opt == OutputFormat.INTERACTIVE:
        plt.show()
def _enum_arg(enum_cls):
    """Return an argparse ``type`` converter mapping a case-insensitive member name to ``enum_cls``.

    Unknown names raise ``argparse.ArgumentTypeError``, which argparse reports
    as a usage error. A bare ``KeyError`` would escape argparse as a traceback.
    """
    def convert(text: str):
        try:
            return enum_cls[text.upper()]
        except KeyError:
            valid = ", ".join(member.name.lower() for member in enum_cls)
            raise argparse.ArgumentTypeError(
                f"invalid choice: {text!r} (choose from {valid})") from None
    return convert


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the validation plotter."""
    plot_names = ", ".join(m.name.lower() for m in PlotVariable)
    format_names = ", ".join(m.name.lower() for m in OutputFormat)

    parser = argparse.ArgumentParser(description="GridFire Validation Suite Output Parser and Plotter")
    parser.add_argument("-d", "--data-dir", type=str, default="GF_Validation_Output",
                        help="Path to the directory containing the JSON output folders.")
    parser.add_argument("--runs", nargs="+", type=str, required=False,
                        help="Which validation runs to analyze. Use 'all' to process all available runs.")
    parser.add_argument("--plot", nargs="*", type=_enum_arg(PlotVariable), default=[], metavar="VAR",
                        help=f"Variables to plot ({plot_names}). Leave empty to skip plotting.")
    parser.add_argument("--compare-pynucastro", action="store_true",
                        help="Include pynucastro data and calculate log residuals.")
    parser.add_argument("--merge-runs", action="store_true",
                        help="Merge all specified runs onto a single figure per variable. (Default: Grid layout of subfigures)")
    parser.add_argument("--x-scale", type=str, choices=["log", "linear"], default="log",
                        help="Scale for the x-axis (time). Default is 'log'.")
    parser.add_argument("--y-scale", type=str, choices=["log", "linear"], default="log",
                        help="Scale for the y-axis (main plots). Default is 'log'.")
    parser.add_argument("--format", type=_enum_arg(OutputFormat), default=OutputFormat.INTERACTIVE, metavar="FORMAT",
                        help=f"Output format for the plots ({format_names}). Default is interactive window.")
    parser.add_argument("--use-external-style", action="store_true",
                        help="Load the custom style sheet defined in EXTERNAL_STYLE_PATH.")
    parser.add_argument("--style", type=str, default=None,
                        help="Built-in Matplotlib stylesheet name (e.g., 'seaborn-v0_8-whitegrid'). Ignored if --use-external-style is set.")
    parser.add_argument("--figsize", nargs=2, type=float, default=[8.0, 6.0],
                        metavar=("WIDTH", "HEIGHT"), help="Base figure size in inches per subfigure (e.g., --figsize 10 8).")
    parser.add_argument("--dpi", type=int, default=150, help="DPI resolution for saved figures.")
    parser.add_argument("--list", action="store_true", default=False, help="list available runs")
    return parser


def _select_runs(requested: List[str], available: List[str]) -> List[str]:
    """Resolve the ``--runs`` selection against the discovered runs.

    Any case-insensitive ``all`` selects every available run. Otherwise only
    the requested names that were discovered are kept, and a warning lists
    the rest.
    """
    if any(r.lower() == "all" for r in requested):
        return available
    selected = [r for r in requested if r in available]
    missing = [r for r in requested if r.lower() != "all" and r not in selected]
    if missing:
        print(f"Warning: The following runs were skipped because they failed or weren't found: {', '.join(missing)}")
    return selected


def main():
    """Command-line entry point: discover runs, optionally list them, and plot the selection.

    Exits with status 1 when no successful runs are found in the data
    directory. Exits with a usage error (status 2) when ``--runs`` is omitted
    without ``--list``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    available_runs = discover_runs(args.data_dir)
    if not available_runs:
        print(f"Error: No successful run data found in {args.data_dir} (Checked GridFire/Ok/).")
        sys.exit(1)
    if args.list:
        for run in available_runs:
            print(f"==> {run}")
        sys.exit(0)
    # --runs is optional so that --list works alone; without it there is
    # nothing to select (previously this crashed iterating None).
    if not args.runs:
        parser.error("--runs is required unless --list is given (use '--runs all' for every run).")

    target_runs = _select_runs(args.runs, available_runs)
    if args.plot and target_runs:
        fig_opts = {
            "figsize": tuple(args.figsize),
            "dpi": args.dpi,
            "style": args.style,
            "use_ext_style": args.use_external_style,
            "merge_runs": args.merge_runs,
            "x_scale": args.x_scale,
            "y_scale": args.y_scale,
        }
        plot_data(target_runs, args.plot, args.data_dir, args.compare_pynucastro, args.format, fig_opts)
    elif args.plot:
        print("Error: No valid runs matched your selection.")
# Script entry point; main() returns None, so the exit status is 0 on normal completion.
if __name__ == "__main__":
    sys.exit(main())