"""GridFire vs pynucastro comparison for Big Bang nucleosynthesis (BBN).

Runs the same temperature/density history through GridFire's
``GraphEngine``/``PointSolver`` and through a pynucastro ``PythonNetwork``
assembled from the ReacLib rates that correspond to GridFire's reactions.
Along the way it compares the individual reaction rates, dumps all results to
``bbn_simulation_data.json`` and plots the abundance evolution of the light
species to ``bbn_comparison.pdf``.
"""

import argparse
import contextlib
import importlib.util
import io
import json
import os
import re
import sys
import time

# Third-party imports are kept in the same relative order as the original script.
# NOTE(review): fourdst is imported before gridfire here; the gridfire bindings may
# rely on that order — confirm before reordering.
import numpy as np
import matplotlib.pyplot as plt
import scipy.integrate
import pynucastro as pyna
import matplotlib.lines as mlines
from fourdst.composition import Composition
from gridfire.type import NetIn
from gridfire.engine import GraphEngine
from gridfire.solver import PointSolver, PointSolverContext
from tqdm import tqdm
from fourdst.composition.utils import buildCompositionFromMassFractions


# "target(projectiles,ejectiles)product", e.g. "c12(p,g)n13".
_RATE_ID_PATTERN = re.compile(r"([a-zA-Z0-9]+)\(([^,]*),([^)]*)\)(.*)")
# A species token with a leading multiplicity, e.g. "2a" or "3he4".
_MULTIPLICITY_PATTERN = re.compile(r"^(\d+)([a-zA-Z0-9]+)$")
# Particles ignored when matching rates by their participating nuclei.
_IGNORED_PARTICLES = frozenset({"e-", "e+", "g", "nu", "anu"})


def T9(age):
    """Return the temperature in units of 1e9 K at time ``age`` (seconds).

    Uses the scaling ``T9 = 10 / sqrt(age)``; works element-wise on arrays.
    """
    return 10.0 / np.sqrt(age)


def get_density(age):
    """Return the density at time ``age`` (seconds) as ``4e-5 * T9(age)**3``.

    NOTE(review): the value is passed to GridFire/pynucastro as the density;
    presumably in g/cm^3 — confirm.
    """
    return 4e-5 * (T9(age) ** 3)


def get_pyna_rate(my_rate_str, library):
    """Find the pynucastro rates that match a reaction string like ``"c12(p,g)n13"``.

    The string is split into target, projectile(s), ejectile(s) and product.
    Multiplicities ("2a") are expanded, "a" is treated as "he4" and photons
    ("g") are dropped. The function first asks ``library.get_rate_by_nuclei``.
    If that finds nothing, it scans every rate in the library and compares the
    sorted reactant/product names, ignoring leptons and photons.

    Parameters
    ----------
    my_rate_str : str
        Reaction string in ``target(projectiles,ejectiles)product`` form.
    library : pynucastro library
        Library to search (``get_rate_by_nuclei`` and ``get_rates`` are used).

    Returns
    -------
    list
        The matching rate objects. The list is empty if the string cannot be
        parsed, a name cannot be converted to a ``pyna.Nucleus``, or nothing
        matches.
    """
    match = _RATE_ID_PATTERN.match(my_rate_str)
    if not match:
        print(f"Could not parse string format: {my_rate_str}")
        return []
    target, projectile, ejectiles, product = match.groups()

    def expand_species(s_str):
        """Split a whitespace-separated species list, expanding multiplicities."""
        if not s_str or s_str.strip() == "":
            return []
        expanded = []
        for token in s_str.split():
            if token == "g":
                continue
            mult_match = _MULTIPLICITY_PATTERN.match(token)
            if mult_match:
                count = int(mult_match.group(1))
                spec = mult_match.group(2)
            else:
                count = 1
                spec = token
            if spec == "g":
                continue
            if spec == "a":
                spec = "he4"
            expanded.extend([spec] * count)
        return expanded

    reactants_str = [target] + expand_species(projectile)
    # Strip stray whitespace so e.g. " n13" still converts to a Nucleus.
    products_str = expand_species(ejectiles) + [product.strip()]

    try:
        r_nuc = [pyna.Nucleus(r) for r in reactants_str]
        p_nuc = [pyna.Nucleus(p) for p in products_str]
    except Exception as e:
        print(f"Error converting nuclei for {my_rate_str}: {e}")
        return []

    rates = library.get_rate_by_nuclei(r_nuc, p_nuc)
    if rates:
        return rates if isinstance(rates, list) else [rates]

    # Fallback: compare the participating nuclei, ignoring leptons and photons.
    r_nuc_names = sorted(str(n) for n in r_nuc)
    p_nuc_names = sorted(str(n) for n in p_nuc)
    matched_rates = []
    for rate in library.get_rates():
        lib_r_names = sorted(str(n) for n in rate.reactants if str(n) not in _IGNORED_PARTICLES)
        lib_p_names = sorted(str(n) for n in rate.products if str(n) not in _IGNORED_PARTICLES)
        if r_nuc_names == lib_r_names and p_nuc_names == lib_p_names:
            matched_rates.append(rate)
    return matched_rates


def load_network_module(filepath):
    """Import (or re-import) the Python file at ``filepath`` and return the module.

    Any module already registered under the same name is dropped first, so a
    freshly written network file is executed again.

    Raises
    ------
    FileNotFoundError
        If no import spec/loader can be created for ``filepath``.
    """
    # splitext removes only the final extension; str.replace would strip every ".py".
    module_name = os.path.splitext(os.path.basename(filepath))[0]
    sys.modules.pop(module_name, None)

    spec = importlib.util.spec_from_file_location(module_name, filepath)
    if spec is None or spec.loader is None:
        raise FileNotFoundError(f"Error: could not find module at {filepath}")

    network_module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = network_module
    try:
        spec.loader.exec_module(network_module)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop(module_name, None)
        raise
    return network_module


def _normalize_rate_id(reaction):
    """Return GridFire's reaction id with "e+"/"e-" removed and ", " collapsed to ","."""
    return reaction.id().replace("e+", "").replace("e-", "").replace(", ", ",")


def _collect_reactions(engine, engine_ctx):
    """Return GridFire's network reactions keyed by their normalized id.

    If several reactions normalize to the same id, only the first is kept and a
    warning is printed. All rate histories are keyed by that id; keeping both
    would interleave their values in one list, making it longer than the step
    history, and would add the matching pynucastro rate twice.
    """
    reactions = {}
    n_total = 0
    for reaction in engine.getNetworkReactions(engine_ctx):
        n_total += 1
        reactions.setdefault(_normalize_rate_id(reaction), reaction)
    n_duplicates = n_total - len(reactions)
    if n_duplicates:
        print(f"Warning: {n_duplicates} GridFire reactions share a normalized id with an earlier reaction and were ignored in the comparison.")
    return reactions


def _run_gridfire(engine, solver, solver_ctx, netIn, start_time, tMax, h):
    """Integrate the network with GridFire from ``start_time`` to (just past) ``tMax``.

    Each step has ``dt = h * t``. Temperature and density are the averages of
    their values at the start and end of the step. ``netIn`` is mutated in
    place: temperature, density and tMax are set for each step, and the
    composition is replaced by the solver's output.

    Returns
    -------
    tuple
        ``(gf_time, gf_results, step_conditions)``:
        - ``gf_time``: step end times.
        - ``gf_results``: species name -> molar abundance after each step.
        - ``step_conditions``: per-step dicts with keys "dt", "T", "rho", "t",
          where "t" is the step start time.
    """
    gf_time = []
    gf_results = {}
    step_conditions = []
    t = start_time
    total_steps = max(0, int(np.ceil(np.log(tMax / start_time) / np.log(1 + h))))

    with tqdm(total=total_steps, desc="GridFire BBN", unit="step") as pbar:
        while t < tMax:
            dt = h * t
            t_next = t + dt
            burn_temp = (T9(t) + T9(t_next)) / 2.0 * 1e9
            burn_density = (get_density(t) + get_density(t_next)) / 2.0

            netIn.temperature = burn_temp
            netIn.density = burn_density
            netIn.tMax = dt
            netOut = solver.evaluate(solver_ctx, netIn)
            netIn.composition = netOut.composition

            pbar.update(1)
            pbar.set_postfix(t=f"{t:.2e}", T=f"{burn_temp:.2e}", rho=f"{burn_density:.2e}")

            step_conditions.append({
                "dt": dt,
                "T": burn_temp,
                "rho": burn_density,
                "t": t
            })
            # The abundances below are the state after this step, so label them with its end time.
            gf_time.append(t_next)
            for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx):
                gf_results.setdefault(sp.name(), []).append(netOut.composition.getMolarAbundance(sp))

            t = t_next
    return gf_time, gf_results, step_conditions


def _map_rates_to_pynucastro(rate_names, library, filter_photo):
    """Look up the pynucastro rate(s) for each GridFire rate id.

    The library's ``get_rate_by_name`` is tried first, with its console output
    suppressed. If that finds nothing, ``get_pyna_rate`` is used.

    Returns
    -------
    tuple
        ``(good_rates, missing_rates, skipped_photo_rates, rate_mapping)``:
        - ``good_rates``: unique rate objects to build the pynucastro network from.
        - ``missing_rates``: ids with no pynucastro counterpart.
        - ``skipped_photo_rates``: how many rates ``filter_photo`` removed.
        - ``rate_mapping``: id -> all matched rates. This is used for the
          rate comparison and is not affected by ``filter_photo``.
    """
    good_rates = []
    missing_rates = []
    skipped_photo_rates = 0
    rate_mapping = {}
    seen_rate_ids = set()  # id() of rate objects already processed, so none is added twice

    for r_str in rate_names:
        found = []
        with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
            try:
                res = library.get_rate_by_name(r_str)
            except Exception:
                # Unknown names are expected; fall back to nucleus matching below.
                res = None
        if res is not None:
            found.extend(res if isinstance(res, list) else [res])

        if not found:
            found.extend(get_pyna_rate(r_str, library))

        if not found:
            missing_rates.append(r_str)
            continue

        rate_mapping[r_str] = found
        for rate in found:
            if id(rate) in seen_rate_ids:
                continue
            seen_rate_ids.add(id(rate))
            if filter_photo:
                # NOTE(review): this only fires if pynucastro lists a photon among the
                # reactants; if it does not, no rate is ever filtered — confirm.
                is_photo_rate = any(str(r).lower() in ["g", "gamma"] for r in rate.reactants)
                if is_photo_rate:
                    skipped_photo_rates += 1
                    continue
            good_rates.append(rate)

    return good_rates, missing_rates, skipped_photo_rates, rate_mapping


def _gridfire_rate_label(reaction):
    """Return the reaction's ``sources()`` or ``sourceLabel()``, whichever exists, else "Unknown"."""
    try:
        return reaction.sources()
    except AttributeError:
        pass
    try:
        return reaction.sourceLabel()
    except AttributeError:
        return "Unknown"


def _evaluate_gridfire_rate(reaction, T9_val):
    """Evaluate a GridFire reaction rate at ``T9_val``, or return 0.0 if it cannot be evaluated.

    Two ``calculate_rate`` call signatures are tried in turn.
    """
    try:
        return reaction.calculate_rate(T9_val, 0, [])
    except Exception:
        pass
    try:
        return reaction.calculate_rate(T9_val, 0, 0, 0, [], dict())
    except Exception:
        return 0.0


def _evaluate_rate_histories(reactions_by_id, pyna_rate_mapping, step_conditions):
    """Evaluate every rate at each step's temperature for both codes.

    The pynucastro value for an id is the sum of ``eval(T)`` over all its mapped
    rates. It is 0.0 when the id has no mapping.

    Returns
    -------
    tuple
        ``(gf_rates_history, py_rates_history)``, each mapping id -> list with
        one value per step.
    """
    gf_rates_history = {r_str: [] for r_str in reactions_by_id}
    py_rates_history = {r_str: [] for r_str in reactions_by_id}

    for step in tqdm(step_conditions, desc="Calculating Rates", unit="step"):
        T_K = step["T"]
        T9_val = T_K / 1e9
        for r_str, reaction in reactions_by_id.items():
            gf_rates_history[r_str].append(_evaluate_gridfire_rate(reaction, T9_val))
            py_rates_history[r_str].append(
                sum((pr.eval(T_K) for pr in pyna_rate_mapping.get(r_str, ())), 0.0)
            )
    return gf_rates_history, py_rates_history


def _report_rate_mismatches(gf_rates_history, py_rates_history, step_conditions,
                            gf_rate_labels, py_rate_labels, threshold=1e-4):
    """Print every rate whose maximum relative GridFire/pynucastro difference exceeds ``threshold``."""
    print("--- Rate Comparison Summary ---")
    mismatches = {}
    for r_str, gf_values in gf_rates_history.items():
        if not gf_values:
            continue  # no steps were taken, so there is nothing to compare
        gf_arr = np.array(gf_values)
        py_arr = np.array(py_rates_history[r_str])
        with np.errstate(divide="ignore", invalid="ignore"):
            # Relative to pynucastro where it is non-zero, otherwise to GridFire. When
            # both are zero the difference is zero and 1e-30 only avoids 0/0.
            denom = np.where(py_arr != 0, py_arr, gf_arr)
            denom = np.where(denom == 0, 1e-30, denom)
            rel_diffs = np.abs(gf_arr - py_arr) / np.abs(denom)
        max_idx = int(np.argmax(rel_diffs))
        max_diff = rel_diffs[max_idx]
        # "not <=" also catches NaN, which would otherwise silently count as a match.
        if not max_diff <= threshold:
            mismatches[r_str] = {
                "max_diff": max_diff,
                "temp": step_conditions[max_idx]["T"],
                "gf_val": gf_arr[max_idx],
                "py_val": py_arr[max_idx]
            }

    if mismatches:
        print(f"Found {len(mismatches)} rates with differences > {threshold:.2%}")
        for r_str, info in mismatches.items():
            gf_lbl = gf_rate_labels.get(r_str, 'Unknown')
            py_lbl = py_rate_labels.get(r_str, [])
            print(f"{r_str:20}: Max Diff = {info['max_diff']:.2%}, at T = {info['temp']:.2e} K")
            print(f"    GF = {info['gf_val']:.4e} (Source: {gf_lbl})")
            print(f"    Py = {info['py_val']:.4e} (Sources: {py_lbl})")
    else:
        print(f"All rates match within the {threshold:.2%} threshold across all temperatures.")
    print("-------------------------------")


def _initial_pyna_abundances(nuclei_names, nnuc, mapping, gf_initial_Y):
    """Build the pynucastro initial molar-abundance vector from GridFire's initial state.

    A pynucastro name is translated to a GridFire name through ``mapping``
    (GridFire name -> (pynucastro name, color)). Otherwise a name like "li7"
    becomes "Li-7". Nuclei with no GridFire counterpart start at 0.
    """
    py_to_gf = {py_name: gf_name for gf_name, (py_name, _) in mapping.items()}
    Y0 = np.zeros(nnuc)
    for i, nuc_name in enumerate(nuclei_names):
        gf_name = py_to_gf.get(nuc_name)
        if not gf_name:
            match = re.match(r"([a-zA-Z]+)(\d+)", nuc_name)
            if match:
                gf_name = f"{match.group(1).capitalize()}-{match.group(2)}"
        if gf_name and gf_name in gf_initial_Y:
            Y0[i] = gf_initial_Y[gf_name]
    return Y0


def _run_pynucastro(net, nuclei_names, Y0, step_conditions):
    """Integrate the generated pynucastro network over the same steps GridFire took.

    Each step is one Radau ``solve_ivp`` solve over ``[0, dt]`` at that step's
    (rho, T). Its final state is the initial state of the next step.

    Returns
    -------
    tuple
        ``(pyna_time, pyna_results)``: step end times, and nucleus name ->
        molar abundance after each step.
    """
    pyna_time = []
    pyna_results = {nuc: [] for nuc in nuclei_names}
    for step in tqdm(step_conditions, unit="step", desc="pynucastro Integration"):
        sol = scipy.integrate.solve_ivp(
            net.rhs,
            [0, step["dt"]],
            Y0,
            args=(step["rho"], step["T"]),
            method="Radau",
            jac=net.jacobian,
            rtol=1e-8,
            atol=1e-20
        )
        if not sol.success:
            tqdm.write(f"Warning: solve_ivp failed for the step starting at t = {step['t']:.3e} s ({sol.message})")
        Y0 = sol.y[:, -1]
        # Abundances are the state after the step, so label them with its end time.
        pyna_time.append(step["t"] + step["dt"])
        for nuc_name, abundance in zip(nuclei_names, Y0):
            pyna_results[nuc_name].append(abundance)
    return pyna_time, pyna_results


def _plot_abundances(mapping, gf_time, gf_results, pyna_time, pyna_results, out_file):
    """Plot GridFire (solid) and pynucastro (dashed) abundances of the mapped species to ``out_file``."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(10, 7))

    for gf_name, (pyna_name, color) in mapping.items():
        if gf_name in gf_results:
            ax.plot(gf_time, gf_results[gf_name], color=color, linestyle="-", linewidth=2.5, label=f"GF {gf_name}")
        if pyna_name in pyna_results:
            ax.plot(pyna_time, pyna_results[pyna_name], color=color, linestyle="--", linewidth=1.5, label=f"Pyna {pyna_name}")

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_ylim(1e-12, 2)
    ax.set_xlabel("Time (s)", fontsize=14)
    ax.set_ylabel("Molar Abundance (Y)", fontsize=14)

    # Legend: one entry per code (line style) plus one per species (color).
    line_gf = mlines.Line2D([], [], color='black', linestyle='-', linewidth=2.5, label='GridFire')
    line_py = mlines.Line2D([], [], color='black', linestyle='--', linewidth=1.5, label='Pynucastro')
    sp_handles = [
        mlines.Line2D([], [], color=color, linestyle='-', linewidth=2, label=gf_name)
        for gf_name, (_, color) in mapping.items()
    ]
    ax.legend(handles=[line_gf, line_py] + sp_handles, loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize=12)

    # The legend sits outside the axes; bbox_inches="tight" grows the saved area to
    # include it, otherwise it is clipped from the file.
    fig.savefig(out_file, bbox_inches="tight")
    plt.close(fig)


def main(args):
    """Run the GridFire vs pynucastro BBN comparison.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``filter_photo`` (bool) and ``depth`` (int or None), as produced
        by the command-line parser at the bottom of this file.

    Writes ``pynuc_bbn_network.py`` (and imports it), ``bbn_simulation_data.json``
    and ``bbn_comparison.pdf`` to the current working directory.
    """
    tMax = 3600.0         # end time (s)
    h = 0.01              # step size relative to the current time: dt = h * t
    current_time = 180.0  # start time (s)

    # Initial composition from the proton/neutron mass-fraction ratio.
    XpXn = 7.17
    Xn = 1.0 / (1.0 + XpXn)
    Xp = 1.0 - Xn
    comp: Composition = buildCompositionFromMassFractions(["H-1", "n-1"], [Xp, Xn])

    netIn = NetIn()
    netIn.composition = comp
    netIn.dt0 = 1e-12

    if args.depth is not None:
        print(f"Initializing GridFire GraphEngine with restricted depth = {args.depth}")
        engine = GraphEngine(comp, args.depth)
    else:
        print("Initializing full-depth GridFire GraphEngine (Note: pynucastro may take a long time to run JIT, set NUMBA_DISABLE_JIT=1 as an environment variable to disable JIT, this makes per timestep time increase but may still be faster for large networks due to the lack of upfront compilation time)")
        engine = GraphEngine(comp)

    blob = engine.constructStateBlob()
    solver_ctx = PointSolverContext(blob)
    solver_ctx.stdout_logging = False
    solver = PointSolver(engine)

    # Initial molar abundance of every network species (0 for species not in comp).
    gf_initial_Y = {}
    for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx):
        gf_initial_Y[sp.name()] = comp.getMolarAbundance(sp) if comp.contains(sp) else 0.0

    # --- GridFire integration ---
    gf_start_time = time.time()
    gf_time, gf_results, step_conditions = _run_gridfire(
        engine, solver, solver_ctx, netIn, current_time, tMax, h
    )
    gf_end_time = time.time()
    print(f"GridFire integration finished in {gf_end_time - gf_start_time:.4f} seconds.")

    # --- Map GridFire reactions to pynucastro rates ---
    print("Building Pynucastro BBN Network...")
    reaclib_library = pyna.ReacLibLibrary()
    reactions_by_id = _collect_reactions(engine, solver_ctx.engine_ctx)
    goodRates, missingRates, skipped_photo_rates, pyna_rate_mapping = _map_rates_to_pynucastro(
        list(reactions_by_id), reaclib_library, args.filter_photo
    )

    if missingRates:
        print(f"Warning: Could not map {len(missingRates)} rates to Pynucastro (likely absent from default ReacLib).")
        print(f"Missing sample: {missingRates[:10]}...")
    if args.filter_photo:
        print(f"Info: Skipped {skipped_photo_rates} photodisintegration rates due to --filter-photo flag.")

    # --- Rate-by-rate comparison along the trajectory ---
    print("--- Evaluating reaction rates over all temperatures ---")
    gf_rate_labels = {
        r_str: _gridfire_rate_label(reaction) for r_str, reaction in reactions_by_id.items()
    }
    py_rate_labels = {
        r_str: [getattr(pr, 'label', 'Unknown') for pr in pyna_rate_mapping.get(r_str, [])]
        for r_str in reactions_by_id
    }
    gf_rates_history, py_rates_history = _evaluate_rate_histories(
        reactions_by_id, pyna_rate_mapping, step_conditions
    )
    _report_rate_mismatches(
        gf_rates_history, py_rates_history, step_conditions, gf_rate_labels, py_rate_labels
    )

    # --- pynucastro integration ---
    pynet = pyna.PythonNetwork(rates=goodRates)
    network_file = "pynuc_bbn_network.py"
    pynet.write_network(network_file)
    net = load_network_module(network_file)

    # GridFire species name -> (pynucastro nucleus name, plot color).
    mapping = {
        "H-1": ("p", "tab:blue"),
        "n-1": ("n", "tab:orange"),
        "He-4": ("he4", "tab:green"),
        "H-2": ("d", "tab:red"),
        "H-3": ("t", "tab:purple"),
        "He-3": ("he3", "tab:brown"),
        "Li-7": ("li7", "tab:pink"),
        "Be-7": ("be7", "tab:gray")
    }

    # Computed once: the generated module indexes nuclei in this same order.
    pyna_nuc_names = [str(n) for n in pynet.get_nuclei()]
    Y0 = _initial_pyna_abundances(pyna_nuc_names, net.nnuc, mapping, gf_initial_Y)

    pyna_start_time = time.time()
    pyna_time, pyna_results = _run_pynucastro(net, pyna_nuc_names, Y0, step_conditions)
    pyna_end_time = time.time()
    print(f"Pynucastro integration finished in {pyna_end_time - pyna_start_time:.4f} seconds.")

    # --- Export ---
    export_data = {
        "metadata": {
            "tMax": tMax,
            "h": h,
            "initial_time": current_time,
            "initial_XpXn_ratio": XpXn,
            "initial_mass_fractions": {
                "Xp": Xp,
                "Xn": Xn
            },
            "execution_times_seconds": {
                "gridfire": gf_end_time - gf_start_time,
                "pynucastro": pyna_end_time - pyna_start_time
            },
            "missing_pynucastro_rates": missingRates,
            "skipped_photodisintegration_rates": skipped_photo_rates if args.filter_photo else 0,
            "rate_labels": {
                "gridfire": gf_rate_labels,
                "pynucastro": py_rate_labels
            }
        },
        "thermodynamic_conditions": step_conditions,
        "data": {
            "gridfire": {
                "time": gf_time,
                "molar_abundances": gf_results,
                "reaction_rates": gf_rates_history
            },
            "pynucastro": {
                "time": pyna_time,
                "molar_abundances": pyna_results,
                "reaction_rates": py_rates_history
            }
        }
    }

    json_out_file = "bbn_simulation_data.json"
    with open(json_out_file, "w") as f:
        json.dump(export_data, f, indent=4)

    # --- Plot ---
    _plot_abundances(mapping, gf_time, gf_results, pyna_time, pyna_results, "bbn_comparison.pdf")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GridFire vs Pynucastro BBN Comparison")
    parser.add_argument("--filter-photo", action="store_true", help="Filter out photodisintegration (reverse) rates to mimic GridFire's forward-only mechanics.")
    parser.add_argument("--depth", type=int, default=None, help="Limit the assembly depth of GridFire's GraphEngine. E.g., setting '--depth 3' shrinks the network size from 5000+ reactions to ~100, which reduces Pynucastro's Numba JIT compile time from hours to seconds.")
    args = parser.parse_args()
    main(args)