feat(validation): added more of the scripts to make paper figures
This commit is contained in:
@@ -0,0 +1,472 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.integrate
|
||||
import pynucastro as pyna
|
||||
import os
|
||||
import sys
|
||||
import importlib.util
|
||||
import time
|
||||
import matplotlib.lines as mlines
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
|
||||
from fourdst.composition import Composition
|
||||
from gridfire.type import NetIn
|
||||
from gridfire.engine import GraphEngine
|
||||
from gridfire.solver import PointSolver, PointSolverContext
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from fourdst.composition.utils import buildCompositionFromMassFractions
|
||||
|
||||
|
||||
|
||||
def T9(age):
|
||||
return 10.0 / np.sqrt(age)
|
||||
|
||||
def get_density(age):
|
||||
return 4e-5 * (T9(age) ** 3)
|
||||
|
||||
def get_pyna_rate(my_rate_str, library):
|
||||
match = re.match(r"([a-zA-Z0-9]+)\(([^,]*),([^)]*)\)(.*)", my_rate_str)
|
||||
|
||||
if not match:
|
||||
print(f"Could not parse string format: {my_rate_str}")
|
||||
return []
|
||||
|
||||
target = match.group(1)
|
||||
projectile = match.group(2)
|
||||
ejectiles = match.group(3)
|
||||
product = match.group(4)
|
||||
|
||||
def expand_species(s_str):
|
||||
if not s_str or s_str.strip() == "":
|
||||
return []
|
||||
|
||||
parts = s_str.split()
|
||||
expanded = []
|
||||
|
||||
for p in parts:
|
||||
if p == 'g':
|
||||
continue
|
||||
|
||||
mult_match = re.match(r"^(\d+)([a-zA-Z0-9]+)$", p)
|
||||
if mult_match:
|
||||
count = int(mult_match.group(1))
|
||||
spec = mult_match.group(2)
|
||||
else:
|
||||
count = 1
|
||||
spec = p
|
||||
|
||||
if spec == 'g':
|
||||
continue
|
||||
if spec == 'a': spec = 'he4'
|
||||
|
||||
expanded.extend([spec] * count)
|
||||
return expanded
|
||||
|
||||
reactants_str = [target] + expand_species(projectile)
|
||||
products_str = expand_species(ejectiles) + [product]
|
||||
|
||||
try:
|
||||
r_nuc = [pyna.Nucleus(r) for r in reactants_str]
|
||||
p_nuc = [pyna.Nucleus(p) for p in products_str]
|
||||
except Exception as e:
|
||||
print(f"Error converting nuclei for {my_rate_str}: {e}")
|
||||
return []
|
||||
|
||||
rates = library.get_rate_by_nuclei(r_nuc, p_nuc)
|
||||
|
||||
if rates:
|
||||
if not isinstance(rates, list):
|
||||
return [rates]
|
||||
return rates
|
||||
|
||||
r_nuc_names = sorted([str(n) for n in r_nuc])
|
||||
p_nuc_names = sorted([str(n) for n in p_nuc])
|
||||
|
||||
ignore_list = ['e-', 'e+', 'g', 'nu', 'anu']
|
||||
|
||||
matched_rates = []
|
||||
for rate in library.get_rates():
|
||||
lib_r_names = sorted([str(n) for n in rate.reactants if str(n) not in ignore_list])
|
||||
lib_p_names = sorted([str(n) for n in rate.products if str(n) not in ignore_list])
|
||||
|
||||
if r_nuc_names == lib_r_names and p_nuc_names == lib_p_names:
|
||||
matched_rates.append(rate)
|
||||
|
||||
return matched_rates
|
||||
|
||||
|
||||
def load_network_module(filepath):
|
||||
module_name = os.path.basename(filepath).replace(".py", "")
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, filepath)
|
||||
if spec is None:
|
||||
raise FileNotFoundError(f"Error: could not find module at {filepath}")
|
||||
|
||||
network_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = network_module
|
||||
spec.loader.exec_module(network_module)
|
||||
return network_module
|
||||
|
||||
def main(args):
|
||||
tMax = 3600.0
|
||||
h = 0.01
|
||||
current_time = 180.0
|
||||
|
||||
XpXn = 7.17
|
||||
Xn = 1.0 / (1.0 + XpXn)
|
||||
Xp = 1.0 - Xn
|
||||
comp: Composition = buildCompositionFromMassFractions(["H-1", "n-1"], [Xp, Xn])
|
||||
|
||||
netIn = NetIn()
|
||||
netIn.composition = comp
|
||||
netIn.dt0 = 1e-12
|
||||
|
||||
if args.depth is not None:
|
||||
print(f"Initializing GridFire GraphEngine with restricted depth = {args.depth}")
|
||||
engine = GraphEngine(comp, args.depth)
|
||||
else:
|
||||
print("Initializing full-depth GridFire GraphEngine (Note: pynucastro may take a long time to run JIT, set NUMBA_DISABLE_JIT=1 as an eviromental variable to disable JIT, this makes per timestep time increase but may still be faster for large networks due to the lack of upfront compilation time)")
|
||||
engine = GraphEngine(comp)
|
||||
|
||||
blob = engine.constructStateBlob()
|
||||
solver_ctx = PointSolverContext(blob)
|
||||
solver_ctx.stdout_logging = False
|
||||
solver = PointSolver(engine)
|
||||
|
||||
gf_initial_Y = {}
|
||||
for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx):
|
||||
if comp.contains(sp):
|
||||
gf_initial_Y[sp.name()] = comp.getMolarAbundance(sp)
|
||||
else:
|
||||
gf_initial_Y[sp.name()] = 0.0
|
||||
|
||||
gf_time = []
|
||||
gf_results = {}
|
||||
step_conditions = []
|
||||
|
||||
gf_start_time = time.time()
|
||||
|
||||
gf_current_time = current_time
|
||||
total_steps = int(np.ceil(np.log(tMax / current_time) / np.log(1 + h)))
|
||||
with tqdm(total=total_steps, desc="GridFire BBN", unit="step") as pbar:
|
||||
while gf_current_time < tMax:
|
||||
current_dt = h * gf_current_time
|
||||
next_time = gf_current_time + current_dt
|
||||
|
||||
burn_temp = (T9(gf_current_time) + T9(next_time)) / 2.0 * 1e9
|
||||
burn_density = (get_density(gf_current_time) + get_density(next_time)) / 2.0
|
||||
|
||||
netIn.temperature = burn_temp
|
||||
netIn.density = burn_density
|
||||
netIn.tMax = current_dt
|
||||
|
||||
netOut = solver.evaluate(solver_ctx, netIn)
|
||||
netIn.composition = netOut.composition
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_postfix(t=f"{gf_current_time:.2e}", T=f"{burn_temp:.2e}", rho=f"{burn_density:.2e}")
|
||||
|
||||
|
||||
step_conditions.append({
|
||||
"dt": current_dt,
|
||||
"T": burn_temp,
|
||||
"rho": burn_density,
|
||||
"t": gf_current_time
|
||||
})
|
||||
|
||||
gf_time.append(gf_current_time)
|
||||
|
||||
for sp in engine.getNetworkSpecies(solver_ctx.engine_ctx):
|
||||
name = sp.name()
|
||||
if name not in gf_results:
|
||||
gf_results[name] = []
|
||||
gf_results[name].append(netOut.composition.getMolarAbundance(sp))
|
||||
|
||||
gf_current_time += current_dt
|
||||
|
||||
gf_end_time = time.time()
|
||||
print(f"GridFire integration finished in {gf_end_time - gf_start_time:.4f} seconds.")
|
||||
|
||||
|
||||
print("Building Pynucastro BBN Network...")
|
||||
reaclib_library = pyna.ReacLibLibrary()
|
||||
|
||||
rate_names = [r.id().replace("e+","").replace("e-","").replace(", ", ",") for r in engine.getNetworkReactions(solver_ctx.engine_ctx)]
|
||||
|
||||
goodRates = []
|
||||
missingRates = []
|
||||
skipped_photo_rates = 0
|
||||
pyna_rate_mapping = {}
|
||||
import io
|
||||
import contextlib
|
||||
|
||||
for r_str in rate_names:
|
||||
pyna_rates_for_reaction = []
|
||||
|
||||
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
||||
try:
|
||||
res = reaclib_library.get_rate_by_name(r_str)
|
||||
if res is not None:
|
||||
if isinstance(res, list):
|
||||
pyna_rates_for_reaction.extend(res)
|
||||
else:
|
||||
pyna_rates_for_reaction.append(res)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not pyna_rates_for_reaction:
|
||||
res_nuc = get_pyna_rate(r_str, reaclib_library)
|
||||
if res_nuc:
|
||||
if isinstance(res_nuc, list):
|
||||
pyna_rates_for_reaction.extend(res_nuc)
|
||||
else:
|
||||
pyna_rates_for_reaction.append(res_nuc)
|
||||
|
||||
if pyna_rates_for_reaction:
|
||||
pyna_rate_mapping[r_str] = pyna_rates_for_reaction
|
||||
for rate in pyna_rates_for_reaction:
|
||||
if args.filter_photo:
|
||||
is_photo_rate = any(str(r).lower() in ['g', 'gamma'] for r in rate.reactants)
|
||||
if is_photo_rate:
|
||||
skipped_photo_rates += 1
|
||||
continue
|
||||
|
||||
goodRates.append(rate)
|
||||
else:
|
||||
missingRates.append(r_str)
|
||||
|
||||
if missingRates:
|
||||
print(f"Warning: Could not map {len(missingRates)} rates to Pynucastro (likely absent from default ReacLib).")
|
||||
print(f"Missing sample: {missingRates[:10]}...")
|
||||
|
||||
if args.filter_photo:
|
||||
print(f"Info: Skipped {skipped_photo_rates} photodisintegration rates due to --filter-photo flag.")
|
||||
|
||||
print("--- Evaluating reaction rates over all temperatures ---")
|
||||
gf_rates_history = {}
|
||||
py_rates_history = {}
|
||||
gf_rate_labels = {}
|
||||
py_rate_labels = {}
|
||||
|
||||
for reaction in engine.getNetworkReactions(solver_ctx.engine_ctx):
|
||||
r_str = reaction.id().replace("e+","").replace("e-","").replace(", ", ",")
|
||||
gf_rates_history[r_str] = []
|
||||
py_rates_history[r_str] = []
|
||||
|
||||
try:
|
||||
gf_rate_labels[r_str] = reaction.sources()
|
||||
except AttributeError:
|
||||
try:
|
||||
gf_rate_labels[r_str] = reaction.sourceLabel()
|
||||
except AttributeError:
|
||||
gf_rate_labels[r_str] = "Unknown"
|
||||
|
||||
if r_str in pyna_rate_mapping:
|
||||
py_rate_labels[r_str] = [getattr(pr, 'label', 'Unknown') for pr in pyna_rate_mapping[r_str]]
|
||||
else:
|
||||
py_rate_labels[r_str] = []
|
||||
|
||||
for step in tqdm(step_conditions, desc="Calculating Rates", unit="step"):
|
||||
T9_val = step["T"] / 1e9
|
||||
T_K = step["T"]
|
||||
|
||||
for reaction in engine.getNetworkReactions(solver_ctx.engine_ctx):
|
||||
r_str = reaction.id().replace("e+","").replace("e-","").replace(", ", ",")
|
||||
|
||||
gf_rate_val = 0.0
|
||||
try:
|
||||
gf_rate_val = reaction.calculate_rate(T9_val, 0, [])
|
||||
except:
|
||||
try:
|
||||
gf_rate_val = reaction.calculate_rate(T9_val, 0, 0, 0, [], dict())
|
||||
except Exception as e:
|
||||
pass
|
||||
gf_rates_history[r_str].append(gf_rate_val)
|
||||
|
||||
py_rate_val = 0.0
|
||||
if r_str in pyna_rate_mapping:
|
||||
for pr in pyna_rate_mapping[r_str]:
|
||||
py_rate_val += pr.eval(T_K)
|
||||
py_rates_history[r_str].append(py_rate_val)
|
||||
|
||||
print("--- Rate Comparison Summary ---")
|
||||
threshold = 1e-4
|
||||
mismatches = {}
|
||||
for r_str in gf_rates_history:
|
||||
gf_arr = np.array(gf_rates_history[r_str])
|
||||
py_arr = np.array(py_rates_history[r_str])
|
||||
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
denom = np.where(py_arr != 0, py_arr, gf_arr)
|
||||
denom = np.where(denom == 0, 1e-30, denom)
|
||||
rel_diffs = np.abs(gf_arr - py_arr) / denom
|
||||
|
||||
max_diff = np.max(rel_diffs)
|
||||
if max_diff > threshold:
|
||||
max_idx = np.argmax(rel_diffs)
|
||||
mismatches[r_str] = {
|
||||
"max_diff": max_diff,
|
||||
"temp": step_conditions[max_idx]["T"],
|
||||
"gf_val": gf_arr[max_idx],
|
||||
"py_val": py_arr[max_idx]
|
||||
}
|
||||
|
||||
if mismatches:
|
||||
print(f"Found {len(mismatches)} rates with differences > {threshold:.2%}")
|
||||
for r_str, info in mismatches.items():
|
||||
gf_lbl = gf_rate_labels.get(r_str, 'Unknown')
|
||||
py_lbl = py_rate_labels.get(r_str, [])
|
||||
print(f"{r_str:20}: Max Diff = {info['max_diff']:.2%}, at T = {info['temp']:.2e} K")
|
||||
print(f" GF = {info['gf_val']:.4e} (Source: {gf_lbl})")
|
||||
print(f" Py = {info['py_val']:.4e} (Sources: {py_lbl})")
|
||||
else:
|
||||
print(f"All rates match within the {threshold:.2%} threshold across all temperatures.")
|
||||
print("-------------------------------")
|
||||
|
||||
pynet = pyna.PythonNetwork(rates=goodRates)
|
||||
network_file = "pynuc_bbn_network.py"
|
||||
pynet.write_network(network_file)
|
||||
|
||||
net = load_network_module(network_file)
|
||||
|
||||
mapping = {
|
||||
"H-1": ("p", "tab:blue"),
|
||||
"n-1": ("n", "tab:orange"),
|
||||
"He-4": ("he4", "tab:green"),
|
||||
"H-2": ("d", "tab:red"),
|
||||
"H-3": ("t", "tab:purple"),
|
||||
"He-3": ("he3", "tab:brown"),
|
||||
"Li-7": ("li7", "tab:pink"),
|
||||
"Be-7": ("be7", "tab:gray")
|
||||
}
|
||||
|
||||
Y0 = np.zeros(net.nnuc)
|
||||
|
||||
for i, nuc in enumerate(pynet.get_nuclei()):
|
||||
nuc_name = str(nuc)
|
||||
gf_name = None
|
||||
for gf, (py, _) in mapping.items():
|
||||
if py == nuc_name:
|
||||
gf_name = gf
|
||||
break
|
||||
|
||||
if not gf_name:
|
||||
match = re.match(r"([a-zA-Z]+)(\d+)", nuc_name)
|
||||
if match:
|
||||
gf_name = f"{match.group(1).capitalize()}-{match.group(2)}"
|
||||
|
||||
if gf_name and gf_name in gf_initial_Y:
|
||||
Y0[i] = gf_initial_Y[gf_name]
|
||||
|
||||
pyna_time = []
|
||||
|
||||
pyna_nuc_names = [str(n) for n in pynet.get_nuclei()]
|
||||
pyna_results = {nuc: [] for nuc in pyna_nuc_names}
|
||||
|
||||
pyna_start_time = time.time()
|
||||
|
||||
for step in tqdm(step_conditions, unit="step", desc="pynucastro Integration"):
|
||||
sol = scipy.integrate.solve_ivp(
|
||||
net.rhs,
|
||||
[0, step["dt"]],
|
||||
Y0,
|
||||
args=(step["rho"], step["T"]),
|
||||
method="Radau",
|
||||
jac=net.jacobian,
|
||||
rtol=1e-8,
|
||||
atol=1e-20
|
||||
)
|
||||
|
||||
Y0 = sol.y[:, -1]
|
||||
pyna_time.append(step["t"])
|
||||
|
||||
for j in range(net.nnuc):
|
||||
nuc_name = str(pynet.get_nuclei()[j])
|
||||
if nuc_name in pyna_results:
|
||||
pyna_results[nuc_name].append(Y0[j])
|
||||
|
||||
pyna_end_time = time.time()
|
||||
print(f"Pynucastro integration finished in {pyna_end_time - pyna_start_time:.4f} seconds.")
|
||||
|
||||
|
||||
export_data = {
|
||||
"metadata": {
|
||||
"tMax": tMax,
|
||||
"h": h,
|
||||
"initial_time": current_time,
|
||||
"initial_XpXn_ratio": XpXn,
|
||||
"initial_mass_fractions": {
|
||||
"Xp": Xp,
|
||||
"Xn": Xn
|
||||
},
|
||||
"execution_times_seconds": {
|
||||
"gridfire": gf_end_time - gf_start_time,
|
||||
"pynucastro": pyna_end_time - pyna_start_time
|
||||
},
|
||||
"missing_pynucastro_rates": missingRates,
|
||||
"skipped_photodisintegration_rates": skipped_photo_rates if args.filter_photo else 0,
|
||||
"rate_labels": {
|
||||
"gridfire": gf_rate_labels,
|
||||
"pynucastro": py_rate_labels
|
||||
}
|
||||
},
|
||||
"thermodynamic_conditions": step_conditions,
|
||||
"data": {
|
||||
"gridfire": {
|
||||
"time": gf_time,
|
||||
"molar_abundances": gf_results,
|
||||
"reaction_rates": gf_rates_history
|
||||
},
|
||||
"pynucastro": {
|
||||
"time": pyna_time,
|
||||
"molar_abundances": pyna_results,
|
||||
"reaction_rates": py_rates_history
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
json_out_file = "bbn_simulation_data.json"
|
||||
with open(json_out_file, "w") as f:
|
||||
json.dump(export_data, f, indent=4)
|
||||
plt.style.use("default")
|
||||
fig, ax = plt.subplots(figsize=(10, 7))
|
||||
|
||||
for gf_name, (pyna_name, color) in mapping.items():
|
||||
if gf_name in gf_results:
|
||||
ax.plot(gf_time, gf_results[gf_name], color=color, linestyle="-", linewidth=2.5, label=f"GF {gf_name}")
|
||||
if pyna_name in pyna_results:
|
||||
ax.plot(pyna_time, pyna_results[pyna_name], color=color, linestyle="--", linewidth=1.5, label=f"Pyna {pyna_name}")
|
||||
|
||||
ax.set_xscale("log")
|
||||
ax.set_yscale("log")
|
||||
ax.set_ylim(1e-12, 2)
|
||||
ax.set_xlabel("Time (s)", fontsize=14)
|
||||
ax.set_ylabel("Molar Abundance (Y)", fontsize=14)
|
||||
|
||||
line_gf = mlines.Line2D([], [], color='black', linestyle='-', linewidth=2.5, label='GridFire')
|
||||
line_py = mlines.Line2D([], [], color='black', linestyle='--', linewidth=1.5, label='Pynucastro')
|
||||
|
||||
sp_handles = []
|
||||
for gf_name, (pyna_name, color) in mapping.items():
|
||||
sp_handles.append(mlines.Line2D([], [], color=color, linestyle='-', linewidth=2, label=gf_name))
|
||||
|
||||
ax.legend(handles=[line_gf, line_py] + sp_handles, loc='center left', bbox_to_anchor=(1.02, 0.5), fontsize=12)
|
||||
|
||||
out_file = "bbn_comparison.pdf"
|
||||
plt.savefig(out_file)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GridFire vs Pynucastro BBN Comparison")
|
||||
parser.add_argument("--filter-photo", action="store_true",
|
||||
help="Filter out photodisintegration (reverse) rates to mimic GridFire's forward-only mechanics.")
|
||||
parser.add_argument("--depth", type=int, default=None,
|
||||
help="Limit the assembly depth of GridFire's GraphEngine. E.g., setting '--depth 3' shrinks the network size from 5000+ reactions to ~100, which reduces Pynucastro's Numba JIT compile time from hours to seconds.")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user