"""GridFire validation-suite output parser and plotter.

Discovers validation runs under ``<data-dir>/GridFire/Ok``, loads each run's
GridFire JSON output (falling back to ``GridFire/Err``) and, optionally, the
matching pynucastro JSON from ``<data-dir>/pynucastro``, then plots
composition, energy (eps) or timestep (dt) histories with a residual panel
underneath the main panel.
"""

import argparse
import json
import math
import os
import sys
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from astropy import constants as const
from astropy import units as u  # noqa: F401  (unused; kept from original imports)
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import interp1d

from fourdst.atomic import species as spdict

# Passed straight to plt.style.use(); as a relative filesystem path it is
# resolved against the current working directory, not this script's location.
EXTERNAL_STYLE_PATH = "../ManuscriptFigures/utils/pub.mplstyle"

# Left edge (s, must be > 0) of the window used for the GridFire-vs-pynucastro
# cumulative-energy residual, and the left x-limit applied to EPS comparisons.
EPS_COMPARISON_T_MIN = 1e8

# Number of log-spaced sample times used for the eps residual curve.
EPS_COMPARISON_SAMPLES = 1000

# Floor applied before log10 (and the abundance assumed for a species missing
# from a step) so that zeros do not become -inf.
LOG_FLOOR = 1e-30

# How many species (ranked by final GridFire abundance) the composition plot shows.
N_TOP_SPECIES = 3


class PlotVariable(Enum):
    """Quantities that can be plotted."""

    COMPOSITION = "composition"
    EPS = "eps"
    DT = "dt"


class OutputFormat(Enum):
    """Plot destinations: an interactive window, or a saved file of this type."""

    INTERACTIVE = "interactive"
    PDF = "pdf"
    PNG = "png"
    JPEG = "jpeg"


# Main-panel y-axis label for each plotted variable.
_Y_LABELS = {
    PlotVariable.EPS: "Cumulative Energy (eps)",
    PlotVariable.DT: "Timestep Size (dt)",
    PlotVariable.COMPOSITION: "Mass Fraction (X_i)",
}


def _enum_arg(enum_cls: type) -> Callable[[str], Enum]:
    """Build an argparse ``type=`` converter from a case-insensitive member name.

    Unknown names raise ``argparse.ArgumentTypeError`` so argparse prints a
    normal usage error; a bare ``KeyError`` is not caught by argparse and
    would escape as a traceback.
    """

    def _convert(text: str) -> Enum:
        try:
            return enum_cls[text.upper()]
        except KeyError:
            valid = ", ".join(member.name.lower() for member in enum_cls)
            raise argparse.ArgumentTypeError(
                f"invalid choice: {text!r} (choose from {valid})"
            ) from None

    return _convert


def _enum_metavar(enum_cls: type) -> str:
    """Return a ``{a,b,c}`` metavar listing the lowercase member names."""
    return "{" + ",".join(member.name.lower() for member in enum_cls) + "}"


def discover_runs(base_dir: str) -> List[str]:
    """Return the sorted names of runs that have a ``<name>_OKAY.json`` file.

    Only ``<base_dir>/GridFire/Ok`` is scanned, so runs that only produced a
    ``GridFire/Err`` file are not listed. Returns ``[]`` if the directory
    does not exist.
    """
    gf_ok_dir = os.path.join(base_dir, "GridFire", "Ok")
    if not os.path.isdir(gf_ok_dir):
        return []
    suffix = "_OKAY.json"
    runs = {
        fname[: -len(suffix)]
        for fname in os.listdir(gf_ok_dir)
        if fname.endswith(suffix)
    }
    return sorted(runs)


def _read_json(path: str) -> Any:
    """Load and return the JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_run_data(base_dir: str, run_name: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load the GridFire and pynucastro JSON output for one run.

    The GridFire result is read from ``GridFire/Ok/<run>_OKAY.json`` if it
    exists, otherwise from ``GridFire/Err/<run>_FAIL.json``. The pynucastro
    result is read from ``pynucastro/<run>_pynucastro.json``.

    Returns:
        ``(gf_data, pynuc_data)``; either is ``{}`` when its file is absent.
    """
    gf_ok_path = os.path.join(base_dir, "GridFire", "Ok", f"{run_name}_OKAY.json")
    gf_err_path = os.path.join(base_dir, "GridFire", "Err", f"{run_name}_FAIL.json")
    pynuc_path = os.path.join(base_dir, "pynucastro", f"{run_name}_pynucastro.json")

    gf_data: Dict[str, Any] = {}
    if os.path.exists(gf_ok_path):
        gf_data = _read_json(gf_ok_path)
    elif os.path.exists(gf_err_path):
        gf_data = _read_json(gf_err_path)

    pynuc_data: Dict[str, Any] = _read_json(pynuc_path) if os.path.exists(pynuc_path) else {}
    return gf_data, pynuc_data


def get_pynuc_eps(steps: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray]:
    """Estimate the energy-generation-rate history of a pynucastro run.

    Uses a first-order finite difference between consecutive accepted steps::

        eps = -N_A * c**2 * sum_i m_i * dY_i / dt

    where m_i is ``spdict[i].mass()`` converted from amu to grams (cgs
    constants). Species missing from a step count as 0. Species not found in
    ``spdict`` are ignored.

    NOTE(review): this assumes ``step["Composition"]`` holds molar abundances
    Y_i, so eps comes out per unit mass per unit time. Confirm against the
    pynucastro output writer.

    Args:
        steps: Step records, each with a time ``"t"`` and a ``"Composition"``
            mapping of species name -> abundance.

    Returns:
        ``(times, eps)`` arrays. The first step is reported with eps = 0.
        Steps whose time does not advance past the previous accepted step are
        dropped, and do not become the new baseline.
    """
    c_sq = const.c.cgs.value ** 2
    n_a = const.N_A.cgs.value
    amu_to_g = const.u.cgs.value

    # Species mass in grams (None if the species is unknown), computed once per species.
    mass_g_by_species: Dict[str, Optional[float]] = {}

    def _species_mass_g(sp: str) -> Optional[float]:
        if sp not in mass_g_by_species:
            mass_g_by_species[sp] = spdict[sp].mass() * amu_to_g if sp in spdict else None
        return mass_g_by_species[sp]

    times: List[float] = []
    eps: List[float] = []
    last_t = None
    last_y: Dict[str, float] = {}

    for step in steps:
        t = step["t"]
        current_y = step["Composition"]

        if last_t is None:
            times.append(t)
            eps.append(0.0)
            last_t, last_y = t, current_y.copy()
            continue

        dt = t - last_t
        if dt <= 0:
            continue

        dm_dt = 0.0
        for sp in set(current_y) | set(last_y):
            mass_g = _species_mass_g(sp)
            if mass_g is None:
                continue
            dy = current_y.get(sp, 0.0) - last_y.get(sp, 0.0)
            dm_dt += mass_g * (dy / dt)

        eps.append(-dm_dt * n_a * c_sq)
        times.append(t)
        last_t, last_y = t, current_y.copy()

    return np.array(times), np.array(eps)


def _setup_axes(ax_main: plt.Axes, ax_diff: plt.Axes, var: PlotVariable, fig_opts: dict):
    """Apply axis labels and the scales from ``fig_opts`` to a main/residual axes pair."""
    ax_diff.set_xlabel("Time (s)")
    ax_main.set_xscale(fig_opts['x_scale'])
    ax_diff.set_xscale(fig_opts['x_scale'])
    ax_main.set_yscale(fig_opts['y_scale'])

    label = _Y_LABELS.get(var)
    if label is not None:
        ax_main.set_ylabel(label)
    ax_diff.set_ylabel(r"$\Delta \log_{10}$")


def _plot_composition(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str,
                      gf_steps: List[Dict[str, Any]], t_gf: np.ndarray,
                      pynuc_steps: List[Dict[str, Any]]):
    """Plot the most abundant species (by final GF composition) and their |Δlog10| residuals."""
    final_comp = gf_steps[-1]["Composition"]
    top_species = sorted(final_comp, key=final_comp.get, reverse=True)[:N_TOP_SPECIES]
    t_pynuc = np.array([s["t"] for s in pynuc_steps]) if pynuc_steps else None

    for spec in top_species:
        y_gf = np.array([s["Composition"].get(spec, LOG_FLOOR) for s in gf_steps])
        line, = ax_main.plot(t_gf, y_gf, label=f"{run_name} {spec} (GF)")
        if not pynuc_steps:
            continue

        y_pynuc = np.array([s["Composition"].get(spec, LOG_FLOOR) for s in pynuc_steps])
        ax_main.plot(t_pynuc, y_pynuc, '--', color=line.get_color(),
                     label=f"{run_name} {spec} (PyNuc)")
        if len(t_pynuc) <= 1:
            continue

        # Hold the end values outside pynucastro's time range instead of raising.
        f_interp = interp1d(t_pynuc, y_pynuc, kind='linear', bounds_error=False,
                            fill_value=(y_pynuc[0], y_pynuc[-1]))
        y_pynuc_interp = f_interp(t_gf)
        log_diff = np.abs(np.log10(np.maximum(y_gf, LOG_FLOOR))
                          - np.log10(np.maximum(y_pynuc_interp, LOG_FLOOR)))
        ax_diff.plot(t_gf, log_diff, color=line.get_color(), linestyle=':', label=f"Δ {spec}")


def _plot_eps_residual(ax_diff: plt.Axes, t_gf: np.ndarray, y_gf: np.ndarray,
                       t_pynuc: np.ndarray, cum_e_pynuc: np.ndarray, color):
    """Plot log10|E_GF| - log10|E_PyNuc| on log-spaced times where both series overlap.

    The window runs from the latest of EPS_COMPARISON_T_MIN and the two series'
    start times, to the earlier of their end times. Non-finite pynucastro
    values are excluded. Nothing is drawn if the window is empty or either
    series has fewer than two usable points.
    """
    finite = np.isfinite(cum_e_pynuc)
    if np.count_nonzero(finite) < 2 or len(t_gf) < 2:
        return
    t_p = t_pynuc[finite]
    e_p = cum_e_pynuc[finite]

    t_lo = max(EPS_COMPARISON_T_MIN, t_p.min(), t_gf.min())
    t_hi = min(t_p.max(), t_gf.max())
    if not t_hi > t_lo:
        return

    t_safe = np.logspace(np.log10(t_lo), np.log10(t_hi), EPS_COMPARISON_SAMPLES)
    # 10**log10(x) can round just outside [t_lo, t_hi]; interp1d raises on that.
    t_safe = np.clip(t_safe, t_lo, t_hi)

    y_pynuc_interp = interp1d(t_p, e_p)(t_safe)
    y_gf_interp = interp1d(t_gf, y_gf)(t_safe)
    pynuc_safe = np.maximum(np.abs(y_pynuc_interp), LOG_FLOOR)
    gf_safe = np.maximum(np.abs(y_gf_interp), LOG_FLOOR)
    log_diff = np.log10(gf_safe) - np.log10(pynuc_safe)
    ax_diff.plot(t_safe, log_diff, color=color, linestyle=':', label="Δ eps")


def _plot_eps(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str,
              gf_steps: List[Dict[str, Any]], t_gf: np.ndarray,
              pynuc_steps: List[Dict[str, Any]]):
    """Plot GridFire's per-step ``eps`` against the time-integrated pynucastro eps."""
    y_gf = np.array([s["eps"] for s in gf_steps])
    line, = ax_main.plot(t_gf, y_gf, label=f"{run_name} (GF)")
    if not pynuc_steps:
        return

    s_t, s_e = get_pynuc_eps(pynuc_steps)
    if len(s_t) <= 1:
        return

    # Cumulative energy so it is comparable with the GridFire "eps" series.
    s_cum_e = cumulative_trapezoid(s_e, s_t, initial=0)
    ax_main.plot(s_t, s_cum_e, '--', color=line.get_color(), label=f"{run_name} (PyNuc)")
    _plot_eps_residual(ax_diff, t_gf, y_gf, s_t, s_cum_e, line.get_color())
    ax_main.set_xlim(EPS_COMPARISON_T_MIN)


def _plot_single_run(ax_main: plt.Axes, ax_diff: plt.Axes, run_name: str, var: PlotVariable,
                     base_dir: str, compare_pynuc: bool):
    """Draw one run's *var* history (and optional pynucastro comparison) onto the axes.

    Runs with no GridFire data, ``Metadata.Status == "Error"`` or no steps are
    skipped silently.
    """
    gf_data, pynuc_data = load_run_data(base_dir, run_name)
    if not gf_data or gf_data.get("Metadata", {}).get("Status") == "Error":
        return
    gf_steps = gf_data.get("Steps", [])
    if not gf_steps:
        return

    if compare_pynuc and not pynuc_data:
        print(f"Warning: PyNucastro comparison requested but data not found for '{run_name}'.")

    t_gf = np.array([s["t"] for s in gf_steps])
    pynuc_steps = pynuc_data.get("Steps", []) if compare_pynuc else []

    if var == PlotVariable.COMPOSITION:
        _plot_composition(ax_main, ax_diff, run_name, gf_steps, t_gf, pynuc_steps)
    elif var == PlotVariable.EPS:
        _plot_eps(ax_main, ax_diff, run_name, gf_steps, t_gf, pynuc_steps)
    elif var == PlotVariable.DT:
        y_gf = np.array([s["dt"] for s in gf_steps])
        ax_main.plot(t_gf, y_gf, label=f"{run_name} (GF)")


def _finalize_plot(fig: plt.Figure, ax_main: plt.Axes, ax_diff: plt.Axes,
                   format_opt: OutputFormat, filename_base: str, is_subfigure: bool = False):
    """Add legends and, for a standalone figure, save and close it.

    For ``is_subfigure=True`` only legends are added; the caller owns the
    parent figure. For interactive output the figure is left open so that
    ``plt.show()`` can display it.
    """
    # Skip empty axes (e.g. a failed run) to avoid matplotlib's "no artists" warning.
    if ax_main.lines:
        ax_main.legend(loc='best', fontsize='small')
    if ax_diff.lines:
        ax_diff.legend(loc='best', fontsize='x-small')

    if is_subfigure:
        return

    fig.tight_layout()
    if format_opt != OutputFormat.INTERACTIVE:
        out_name = f"{filename_base}.{format_opt.value}"
        fig.savefig(out_name, format=format_opt.value, bbox_inches='tight')
        print(f"Saved figure: {out_name}")
        plt.close(fig)


def _apply_style(fig_opts: dict):
    """Load the external or built-in style sheet (best effort) and the size/DPI rcParams."""
    if fig_opts['use_ext_style']:
        try:
            plt.style.use(EXTERNAL_STYLE_PATH)
            print(f"Using external style sheet: {EXTERNAL_STYLE_PATH}")
        except Exception as e:
            print(f"Warning: Failed to load external style sheet. Error: {e}")
    elif fig_opts['style']:
        try:
            plt.style.use(fig_opts['style'])
        except OSError:
            print(f"Warning: Style '{fig_opts['style']}' not found. Using default.")

    # Set after the style sheet so the CLI values take precedence.
    plt.rcParams["figure.figsize"] = fig_opts['figsize']
    plt.rcParams["figure.dpi"] = fig_opts['dpi']


def _plot_merged(runs: List[str], var: PlotVariable, base_dir: str, compare_pynuc: bool,
                 format_opt: OutputFormat, fig_opts: dict):
    """Overlay every run on a single main/residual figure for *var*."""
    fig, (ax_main, ax_diff) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]})
    ax_main.set_title(f"Comparison of {var.value.upper()} (Merged Runs)")
    _setup_axes(ax_main, ax_diff, var, fig_opts)
    for run_name in runs:
        _plot_single_run(ax_main, ax_diff, run_name, var, base_dir, compare_pynuc)
    _finalize_plot(fig, ax_main, ax_diff, format_opt, f"ValidationPlot_Merged_{var.value}")


def _plot_grid(runs: List[str], var: PlotVariable, base_dir: str, compare_pynuc: bool,
               format_opt: OutputFormat, fig_opts: dict):
    """Draw each run in its own subfigure on a near-square grid for *var*."""
    num_runs = len(runs)
    cols = math.ceil(math.sqrt(num_runs))
    rows = math.ceil(num_runs / cols)
    base_w, base_h = fig_opts['figsize']

    fig = plt.figure(figsize=(base_w * cols, base_h * rows), layout='constrained')
    fig.suptitle(f"{var.value.upper()} Comparison", fontsize=16, fontweight='bold')

    subfigs_raw = fig.subfigures(rows, cols)
    # subfigures() squeezes a 1x1 grid down to a single SubFigure object.
    subfigs = subfigs_raw.ravel() if isinstance(subfigs_raw, np.ndarray) else [subfigs_raw]

    for subfig, run_name in zip(subfigs, runs):
        subfig.suptitle(f"{run_name}", fontsize=12)
        ax_main, ax_diff = subfig.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]})
        _setup_axes(ax_main, ax_diff, var, fig_opts)
        _plot_single_run(ax_main, ax_diff, run_name, var, base_dir, compare_pynuc)
        _finalize_plot(fig, ax_main, ax_diff, format_opt, "", is_subfigure=True)

    # Hide the unused cells left over when runs do not fill the grid.
    for subfig in subfigs[num_runs:]:
        subfig.set_visible(False)

    if format_opt != OutputFormat.INTERACTIVE:
        out_name = f"ValidationPlot_Grid_{var.value}.{format_opt.value}"
        fig.savefig(out_name, format=format_opt.value, bbox_inches='tight')
        print(f"Saved grid figure: {out_name}")
        plt.close(fig)


def plot_data(runs: List[str], plot_vars: List[PlotVariable], base_dir: str, compare_pynuc: bool,
              format_opt: OutputFormat, fig_opts: dict):
    """Produce one figure per variable in *plot_vars* for the given runs.

    ``fig_opts`` keys read: ``use_ext_style``, ``style``, ``figsize``, ``dpi``,
    ``merge_runs``, ``x_scale``, ``y_scale``. Figures are saved to the current
    directory, or shown together at the end for interactive output.
    """
    if not runs:
        print("No valid runs to plot.")
        return

    _apply_style(fig_opts)

    for var in plot_vars:
        if fig_opts['merge_runs']:
            _plot_merged(runs, var, base_dir, compare_pynuc, format_opt, fig_opts)
        else:
            _plot_grid(runs, var, base_dir, compare_pynuc, format_opt, fig_opts)

    if format_opt == OutputFormat.INTERACTIVE:
        plt.show()


def main():
    """Parse command-line arguments, resolve the requested runs and plot them."""
    parser = argparse.ArgumentParser(description="GridFire Validation Suite Output Parser and Plotter")
    parser.add_argument("-d", "--data-dir", type=str, default="GF_Validation_Output",
                        help="Path to the directory containing the JSON output folders.")
    parser.add_argument("--runs", nargs="+", type=str, required=False,
                        help="Which validation runs to analyze. Use 'all' to process all available runs.")
    parser.add_argument("--plot", nargs="*", type=_enum_arg(PlotVariable), choices=list(PlotVariable),
                        metavar=_enum_metavar(PlotVariable), default=[],
                        help="Variables to plot. Leave empty to skip plotting.")
    parser.add_argument("--compare-pynucastro", action="store_true",
                        help="Include pynucastro data and calculate log residuals.")
    parser.add_argument("--merge-runs", action="store_true",
                        help="Merge all specified runs onto a single figure per variable. "
                             "(Default: Grid layout of subfigures)")
    parser.add_argument("--x-scale", type=str, choices=["log", "linear"], default="log",
                        help="Scale for the x-axis (time). Default is 'log'.")
    parser.add_argument("--y-scale", type=str, choices=["log", "linear"], default="log",
                        help="Scale for the y-axis (main plots). Default is 'log'.")
    parser.add_argument("--format", type=_enum_arg(OutputFormat), choices=list(OutputFormat),
                        metavar=_enum_metavar(OutputFormat), default=OutputFormat.INTERACTIVE,
                        help="Output format for the plots. Default is interactive window.")
    parser.add_argument("--use-external-style", action="store_true",
                        help="Load the custom style sheet defined in EXTERNAL_STYLE_PATH.")
    parser.add_argument("--style", type=str, default=None,
                        help="Built-in Matplotlib stylesheet name (e.g., 'seaborn-v0_8-whitegrid'). "
                             "Ignored if --use-external-style is set.")
    parser.add_argument("--figsize", nargs=2, type=float, default=[8.0, 6.0], metavar=("WIDTH", "HEIGHT"),
                        help="Base figure size in inches per subfigure (e.g., --figsize 10 8).")
    parser.add_argument("--dpi", type=int, default=150,
                        help="DPI resolution for saved figures.")
    parser.add_argument("--list", action="store_true", default=False,
                        help="list available runs")
    args = parser.parse_args()

    available_runs = discover_runs(args.data_dir)
    if not available_runs:
        print(f"Error: No successful run data found in {args.data_dir} (Checked GridFire/Ok/).")
        sys.exit(1)

    if args.list:
        for run in available_runs:
            print(f"==> {run}")
        return

    # --runs is optional at the parser level (so --list works alone), but
    # everything below needs it; args.runs is None when it was omitted.
    if not args.runs:
        parser.error("--runs is required unless --list is given.")

    if any(r.lower() == "all" for r in args.runs):
        target_runs = available_runs
    else:
        available = set(available_runs)
        target_runs = [r for r in args.runs if r in available]

    selected = set(target_runs)
    missing = [r for r in args.runs if r.lower() != "all" and r not in selected]
    if missing:
        print(f"Warning: The following runs were skipped because they failed or weren't found: {', '.join(missing)}")

    if not args.plot:
        return
    if not target_runs:
        print("Error: No valid runs matched your selection.")
        sys.exit(1)

    fig_opts = {
        "figsize": tuple(args.figsize),
        "dpi": args.dpi,
        "style": args.style,
        "use_ext_style": args.use_external_style,
        "merge_runs": args.merge_runs,
        "x_scale": args.x_scale,
        "y_scale": args.y_scale,
    }
    plot_data(target_runs, args.plot, args.data_dir, args.compare_pynucastro, args.format, fig_opts)


if __name__ == "__main__":
    main()