test(vv): Added more scripts to verify GridFire behavior

This commit is contained in:
2026-04-13 07:18:08 -04:00
parent 84ff182717
commit c311e4afbd
16 changed files with 1968 additions and 181 deletions

View File

@@ -6,6 +6,7 @@ from gridfire.policy import MainSequencePolicy
from gridfire.engine import GraphEngine, MultiscalePartitioningEngineView, AdaptiveEngineView
from gridfire.engine import NetworkBuildDepth
from fourdst.composition.utils import buildCompositionFromMassFractions
from scipy.signal import find_peaks
@@ -77,11 +78,9 @@ def rescale_composition(comp_ref : Composition, ZZs : float, Y_primordial : floa
return newComp
def init_composition(ZZs : float = 0) -> Composition:
Y_solar = [7.0262E-01, 1.7479E-06, 6.8955E-02, 2.5000E-04, 7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
S = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
return rescale_composition(Composition(S, Y_solar), ZZs)
X_GS98 = [0.73395, 0.00005, 0.2490, 0.00281, 0.00101, 0.00883, 0.00149, 0.00064, 0.00066, 0.00035, 0.00008, 0.00006, 0.00107]
S_GS98 = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24", "Si-28", "S-32", "Ar-36", "Ca-40", "Fe-56"]
return buildCompositionFromMassFractions(S_GS98, X_GS98)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
n : NetIn = NetIn()
n.temperature = temp
@@ -154,185 +153,37 @@ def quantify_engine_error(df_base, df_approx, r_base: NetOut, r_approx: NetOut,
def main(save_show):
C = init_composition()
netIn = init_netIn(1.5e7, 1.6e2, years_to_seconds(10e9), C)
netIn = init_netIn(10**7.1760912591, 10**2.2041199827, 1e17, C)
stepLogger = StepLogger()
engine_graph = GraphEngine(C, 4)
blob = engine_graph.constructStateBlob()
print(f"Gridfire Using: {len(engine_graph.getNetworkReactions(blob))} Reactions and {len(engine_graph.getNetworkSpecies(blob))} Species")
solver_ctx_graph = PointSolverContext(engine_graph.constructStateBlob())
solver_ctx_graph.stdout_logging = True
solver_ctx_graph = PointSolverContext(blob)
solver_ctx_graph.stdout_logging = False
solver_ctx_graph.callback = lambda ctx: stepLogger.log_step(ctx)
solver_single = PointSolver(engine_graph)
r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
df_graph : pd.DataFrame = stepLogger.df
df_graph.to_csv("bbq_graph.csv", index=False)
stepLogger.reset()
QSE_engine = MultiscalePartitioningEngineView(engine_graph)
solver_ctx_graph_qse = PointSolverContext(QSE_engine.constructStateBlob(engine_graph.constructStateBlob()))
solver_ctx_graph_qse.stdout_logging = True
solver_ctx_graph_qse.stdout_logging = False
solver_ctx_graph_qse.callback = lambda ctx: stepLogger.log_step(ctx)
solver_QSE = PointSolver(QSE_engine)
r_qse = solver_QSE.evaluate(solver_ctx_graph_qse, netIn, False, False)
df_qse = stepLogger.df
df_qse : pd.DataFrame = stepLogger.df
df_qse.to_csv("bbq_qse.csv", index=False)
stepLogger.reset()
# policy = MainSequencePolicy(C)
# construct = policy.construct()
# solver_AE_QSE = PointSolver(construct.engine)
# solver_ctx_graph_qse_ae = PointSolverContext(construct.scratch_blob)
# solver_ctx_graph_qse_ae.callback = lambda ctx: stepLogger.log_step(ctx)
# solver_ctx_graph_qse_ae.stdout_logging = False
#
# r_ae_qse = solver_AE_QSE.evaluate(solver_ctx_graph_qse_ae, netIn, False, False)
#
# df_ae_qse = stepLogger.df
# stepLogger.reset()
# fig, axs = plt.subplots(2, 1, figsize=(10, 7))
S = ["H-1", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
t = np.logspace(7, 17.5, 5000)
# for spID, sp in enumerate(S):
# gf = interp1d(df_graph.t, df_graph[sp])
# qf = interp1d(df_qse.t, df_qse[sp])
#
# ax = axs[0]
# ax.loglog(t, gf(t), 'o-', color=f"C{spID}")
# ax.loglog(t, qf(t), 'o', color=f"C{spID}", linestyle='dashed')
#
# ax.text(1, df_graph[sp].iloc[0]*1.1, sp, fontsize=12, color=f"C{spID}")
#
# ax = axs[1]
# ax.semilogx(t, (qf(t)-gf(t))/gf(t), color=f"C{spID}")
#
# axs[1].set_xlabel("Time [s]", fontsize=15)
# axs[0].set_ylabel("Molar Abundance [mol/g]", fontsize=15)
# axs[1].set_ylabel("Relative Error", fontsize=15)
#
# fig, ax = plt.subplots(1, 1, figsize=(10, 7))
# ge = interp1d(df_graph.t, df_graph.eps)
# qe = interp1d(df_qse.t, df_qse.eps)
# ax.loglog(t, np.abs((qe(t) - ge(t)) / ge(t)))
temporal_err_qse, final_err_qse = quantify_engine_error(
df_base=df_graph,
df_approx=df_qse,
r_base=r_graph,
r_approx=r_qse,
species_list=S
)
qse_rel_eps_error = (df_graph.eps.iloc[-1] - df_qse.eps.iloc[-1])/df_qse.eps.iloc[-1]
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
# ax.semilogx(df_graph.t, df_graph["H-2"], 'o-', color='red')
# ax.semilogx(df_qse.t, df_qse["H-2"], 'o', color='blue', linestyle='dashed')
graph_h1 = interp1d(df_graph.t, df_graph["H-1"])
qse_h1 = interp1d(df_qse.t, df_qse["H-1"])
graph_h2 = interp1d(df_graph.t, df_graph["H-2"])
qse_h2 = interp1d(df_qse.t, df_qse["H-2"])
graph_DH = graph_h2(t)/graph_h1(t)
qse_DH = qse_h2(t)/qse_h1(t)
dex_diff = np.abs(np.log10(graph_h2(t)) - np.log10(qse_h2(t)))
dex_dh_diff = np.abs(np.log10(graph_DH) - np.log10(qse_DH))
# ax.semilogx(t, dex_diff, color='green')
ax.loglog(t, dex_dh_diff, color='black')
# ax.semilogx(t, qse_h2(t)/qse_h1(t), color='green')
ax.set_xlabel("Time [s]", fontsize=17)
ax.set_ylabel(r"$\left|\log_{10}\left(\frac{D}{H})\right)_{graph} - \log_{10}\left(\frac{D}{H}\right)_{qse}\right|$", fontsize=17)
if save_show == ShowSave.SAVE:
plt.savefig("DHErr.pdf")
plt.close()
else:
plt.show()
sums_qse = {}
sums_graph = {}
symbols = {}
for sp, y in r_qse.composition:
z = sp.z()
symbols[z] = sp.el()
y_graph = r_graph.composition.getMolarAbundance(sp)
sums_qse[int(z)] = sums_qse.get(z, 0.0) + y
sums_graph[int(z)] = sums_graph.get(z, 0.0) + y_graph
print(sums_qse[3])
print(sums_graph[3])
z_list = sorted(sums_qse.keys())
dex_list = []
symbols = [val for key, val in symbols.items()]
for z in z_list:
total_qse = sums_qse[z]
total_graph = sums_graph[z]
if total_graph > 1e-13 and total_qse > 1e-13:
offset = np.log10(total_qse / total_graph)
else:
if z >= 14:
offset = np.nan # Disable these for visualization, they all have abundances so small (on the order of -100 it doesnt matter)
else:
offset = 0.0
dex_list.append(offset)
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
data = sorted(zip(z_list, symbols, dex_list), key=lambda x: x[0])
sorted_z, sorted_symbols, sorted_dex = zip(*data)
print(sorted_symbols)
print(sorted_dex)
# 2. Create the plot
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
print(sorted_symbols)
bars = ax.bar(sorted_symbols, sorted_dex, color='grey', edgecolor='grey', alpha=0.8)
# 3. Add styling and labels
ax.axhline(0, color='black', linewidth=0.8) # Adds a clear baseline at 0 dex
ax.set_xlabel('Element', fontsize=25)
ax.set_ylabel('Offset [dex]', fontsize=25)
if save_show == ShowSave.SAVE:
plt.savefig("DexElementalOffset.pdf")
plt.close()
e_graph = interp1d(df_graph.t, df_graph.eps)
e_qse = interp1d(df_qse.t, df_qse.eps)
dex_eps_diff = np.log10(e_graph(t)) - np.log10(e_qse(t))
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
ax.semilogx(t, dex_eps_diff, color='black')
ax.set_xlabel("Time [s]", fontsize=25)
ax.set_xlabel("Offset [dex]", fontsize=25)
if save_show == ShowSave.SAVE:
plt.savefig("DexEpsOffset.pdf")
plt.close()
if save_show == ShowSave.SHOW:
plt.show()
print("=== QSE ===")
print(temporal_err_qse)
print(final_err_qse)
print(f"Relative ε error: {qse_rel_eps_error}")
print(f"Neutrino Loss Difference [dex]: {np.log10(r_graph.specific_neutrino_energy_loss) - np.log10(r_qse.specific_neutrino_energy_loss)}")
if __name__ == "__main__":
import argparse

Binary file not shown.

View File

@@ -168,7 +168,7 @@ def main(save_show):
solver_single = PointSolver(engine_graph)
r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
df_graph = stepLogger.df
df_graph : pd.DataFrame = stepLogger.df
stepLogger.reset()
QSE_engine = MultiscalePartitioningEngineView(engine_graph)

View File

@@ -8,8 +8,13 @@ import sys
from gridfire.solver import PointSolverTimestepContext
from gridfire._gridfire.engine.scratchpads import StateBlob
from fourdst.composition import Composition
import gridfire
import numpy as np
import pandas as pd
class LogEntries(Enum):
Step = "Step"
t = "t"
@@ -17,6 +22,7 @@ class LogEntries(Enum):
eps = "eps"
Composition = "Composition"
ReactionContributions = "ReactionContributions"
MassFractions = "MassFractions"
class StepLogger:
@@ -24,17 +30,40 @@ class StepLogger:
self.num_steps : int = 0
self.steps : List[Dict[LogEntries, Any]] = []
# def log_step(self, ctx: PointSolverTimestepContext):
# comp_data: Dict[str, SupportsFloat] = {}
# for species in ctx.engine.getNetworkSpecies(ctx.state_ctx):
# sid = ctx.engine.getSpeciesIndex(ctx.state_ctx, species)
# comp_data[species.name()] = ctx.state[sid]
# entry : Dict[LogEntries, Any] = {
# LogEntries.Step: ctx.num_steps,
# LogEntries.t: ctx.t,
# LogEntries.dt: ctx.dt,
# LogEntries.eps: ctx.state[-1],
# LogEntries.Composition: comp_data,
# }
# self.steps.append(entry)
# self.num_steps += 1
def log_step(self, ctx: PointSolverTimestepContext):
comp_data: Dict[str, SupportsFloat] = {}
for species in ctx.engine.getNetworkSpecies(ctx.state_ctx):
sid = ctx.engine.getSpeciesIndex(ctx.state_ctx, species)
comp_data[species.name()] = ctx.state[sid]
full_comp = ctx.composition
comp_data: Dict[str, float] = {}
mass_frac: Dict[str, float] = {}
for species in full_comp.getRegisteredSpecies():
comp_data[species.name()] = full_comp.getMolarAbundance(species)
mass_frac[species.name()] = full_comp.getMassFraction(species)
rhs_calc = ctx.engine.getMostRecentRHSCalculation(ctx.state_ctx)
instantaneous_eps = rhs_calc.energy if rhs_calc else 0.0
entry : Dict[LogEntries, Any] = {
LogEntries.Step: ctx.num_steps,
LogEntries.t: ctx.t,
LogEntries.dt: ctx.dt,
LogEntries.eps: ctx.state[-1],
LogEntries.eps: instantaneous_eps,
LogEntries.Composition: comp_data,
LogEntries.MassFractions: mass_frac,
}
self.steps.append(entry)
self.num_steps += 1
@@ -47,6 +76,7 @@ class StepLogger:
LogEntries.dt.value: step[LogEntries.dt],
LogEntries.eps.value: step[LogEntries.eps],
LogEntries.Composition.value: step[LogEntries.Composition],
LogEntries.MassFractions.value: step[LogEntries.MassFractions],
}
for step in self.steps
]
@@ -67,6 +97,32 @@ class StepLogger:
with open(filename, 'w') as f:
json.dump(out_data, f, indent=4)
@property
def t(self) -> np.ndarray:
return np.array([step[LogEntries.t] for step in self.steps])
@property
def df(self) -> pd.DataFrame:
if not self.steps:
return pd.DataFrame()
flat_data = []
for step in self.steps:
row = {
"Step": step[LogEntries.Step],
"t": step[LogEntries.t],
"dt": step[LogEntries.dt],
"eps": step[LogEntries.eps],
}
X_dict = {f"X_{sp}": x for sp, x in step[LogEntries.MassFractions].items()}
row.update(step[LogEntries.Composition])
row.update(X_dict)
flat_data.append(row)
df = pd.DataFrame(flat_data)
df = df.ffill().fillna(0.0)
return df
def summary(self) -> Dict[str, Any]:
if not self.steps:
@@ -78,3 +134,8 @@ class StepLogger:
"FinalComposition": final_step[LogEntries.Composition],
}
return summary_data
def reset(self):
self.num_steps = 0
self.steps : List[Dict[LogEntries, Any]] = []

View File

@@ -0,0 +1,40 @@
savefig.facecolor: auto
savefig.edgecolor: auto
savefig.format: pdf
savefig.bbox: tight
xtick.minor.visible : True
ytick.minor.visible : True
xtick.direction : in
ytick.direction : in
xtick.top : True
ytick.right : True
xtick.major.size : 8
xtick.minor.size : 4
ytick.major.size : 8
ytick.minor.size : 4
xtick.labelsize : 18
ytick.labelsize : 19
font.size : 20
font.family : serif
text.usetex : True
lines.color: C0
patch.edgecolor: black
text.color: black
axes.facecolor: white
axes.edgecolor: black
axes.labelcolor: black
axes.prop_cycle: cycler('color', ['1f77b4', 'ff7f0e', '2ca02c', 'd62728', '9467bd', '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf'])
xtick.color: black
ytick.color: black
grid.color: b0b0b0
figure.facecolor: white
figure.edgecolor: white
savefig.dpi: figure