test(vv): Added more scripts to verify GridFire behavior

This commit is contained in:
2026-04-13 07:18:08 -04:00
parent 84ff182717
commit c311e4afbd
16 changed files with 1968 additions and 181 deletions

View File

@@ -6,6 +6,7 @@ from gridfire.policy import MainSequencePolicy
from gridfire.engine import GraphEngine, MultiscalePartitioningEngineView, AdaptiveEngineView
from gridfire.engine import NetworkBuildDepth
from fourdst.composition.utils import buildCompositionFromMassFractions
from scipy.signal import find_peaks
@@ -77,11 +78,9 @@ def rescale_composition(comp_ref : Composition, ZZs : float, Y_primordial : floa
return newComp
def init_composition(ZZs : float = 0) -> Composition:
Y_solar = [7.0262E-01, 1.7479E-06, 6.8955E-02, 2.5000E-04, 7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
S = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
return rescale_composition(Composition(S, Y_solar), ZZs)
X_GS98 = [0.73395, 0.00005, 0.2490, 0.00281, 0.00101, 0.00883, 0.00149, 0.00064, 0.00066, 0.00035, 0.00008, 0.00006, 0.00107]
S_GS98 = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24", "Si-28", "S-32", "Ar-36", "Ca-40", "Fe-56"]
return buildCompositionFromMassFractions(S_GS98, X_GS98)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
n : NetIn = NetIn()
n.temperature = temp
@@ -154,185 +153,37 @@ def quantify_engine_error(df_base, df_approx, r_base: NetOut, r_approx: NetOut,
def main(save_show):
C = init_composition()
netIn = init_netIn(1.5e7, 1.6e2, years_to_seconds(10e9), C)
netIn = init_netIn(10**7.1760912591, 10**2.2041199827, 1e17, C)
stepLogger = StepLogger()
engine_graph = GraphEngine(C, 4)
blob = engine_graph.constructStateBlob()
print(f"Gridfire Using: {len(engine_graph.getNetworkReactions(blob))} Reactions and {len(engine_graph.getNetworkSpecies(blob))} Species")
solver_ctx_graph = PointSolverContext(engine_graph.constructStateBlob())
solver_ctx_graph.stdout_logging = True
solver_ctx_graph = PointSolverContext(blob)
solver_ctx_graph.stdout_logging = False
solver_ctx_graph.callback = lambda ctx: stepLogger.log_step(ctx)
solver_single = PointSolver(engine_graph)
r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
df_graph : pd.DataFrame = stepLogger.df
df_graph.to_csv("bbq_graph.csv", index=False)
stepLogger.reset()
QSE_engine = MultiscalePartitioningEngineView(engine_graph)
solver_ctx_graph_qse = PointSolverContext(QSE_engine.constructStateBlob(engine_graph.constructStateBlob()))
solver_ctx_graph_qse.stdout_logging = True
solver_ctx_graph_qse.stdout_logging = False
solver_ctx_graph_qse.callback = lambda ctx: stepLogger.log_step(ctx)
solver_QSE = PointSolver(QSE_engine)
r_qse = solver_QSE.evaluate(solver_ctx_graph_qse, netIn, False, False)
df_qse = stepLogger.df
df_qse : pd.DataFrame = stepLogger.df
df_qse.to_csv("bbq_qse.csv", index=False)
stepLogger.reset()
# policy = MainSequencePolicy(C)
# construct = policy.construct()
# solver_AE_QSE = PointSolver(construct.engine)
# solver_ctx_graph_qse_ae = PointSolverContext(construct.scratch_blob)
# solver_ctx_graph_qse_ae.callback = lambda ctx: stepLogger.log_step(ctx)
# solver_ctx_graph_qse_ae.stdout_logging = False
#
# r_ae_qse = solver_AE_QSE.evaluate(solver_ctx_graph_qse_ae, netIn, False, False)
#
# df_ae_qse = stepLogger.df
# stepLogger.reset()
# fig, axs = plt.subplots(2, 1, figsize=(10, 7))
S = ["H-1", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
t = np.logspace(7, 17.5, 5000)
# for spID, sp in enumerate(S):
# gf = interp1d(df_graph.t, df_graph[sp])
# qf = interp1d(df_qse.t, df_qse[sp])
#
# ax = axs[0]
# ax.loglog(t, gf(t), 'o-', color=f"C{spID}")
# ax.loglog(t, qf(t), 'o', color=f"C{spID}", linestyle='dashed')
#
# ax.text(1, df_graph[sp].iloc[0]*1.1, sp, fontsize=12, color=f"C{spID}")
#
# ax = axs[1]
# ax.semilogx(t, (qf(t)-gf(t))/gf(t), color=f"C{spID}")
#
# axs[1].set_xlabel("Time [s]", fontsize=15)
# axs[0].set_ylabel("Molar Abundance [mol/g]", fontsize=15)
# axs[1].set_ylabel("Relative Error", fontsize=15)
#
# fig, ax = plt.subplots(1, 1, figsize=(10, 7))
# ge = interp1d(df_graph.t, df_graph.eps)
# qe = interp1d(df_qse.t, df_qse.eps)
# ax.loglog(t, np.abs((qe(t) - ge(t)) / ge(t)))
temporal_err_qse, final_err_qse = quantify_engine_error(
df_base=df_graph,
df_approx=df_qse,
r_base=r_graph,
r_approx=r_qse,
species_list=S
)
qse_rel_eps_error = (df_graph.eps.iloc[-1] - df_qse.eps.iloc[-1])/df_qse.eps.iloc[-1]
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
# ax.semilogx(df_graph.t, df_graph["H-2"], 'o-', color='red')
# ax.semilogx(df_qse.t, df_qse["H-2"], 'o', color='blue', linestyle='dashed')
graph_h1 = interp1d(df_graph.t, df_graph["H-1"])
qse_h1 = interp1d(df_qse.t, df_qse["H-1"])
graph_h2 = interp1d(df_graph.t, df_graph["H-2"])
qse_h2 = interp1d(df_qse.t, df_qse["H-2"])
graph_DH = graph_h2(t)/graph_h1(t)
qse_DH = qse_h2(t)/qse_h1(t)
dex_diff = np.abs(np.log10(graph_h2(t)) - np.log10(qse_h2(t)))
dex_dh_diff = np.abs(np.log10(graph_DH) - np.log10(qse_DH))
# ax.semilogx(t, dex_diff, color='green')
ax.loglog(t, dex_dh_diff, color='black')
# ax.semilogx(t, qse_h2(t)/qse_h1(t), color='green')
ax.set_xlabel("Time [s]", fontsize=17)
ax.set_ylabel(r"$\left|\log_{10}\left(\frac{D}{H})\right)_{graph} - \log_{10}\left(\frac{D}{H}\right)_{qse}\right|$", fontsize=17)
if save_show == ShowSave.SAVE:
plt.savefig("DHErr.pdf")
plt.close()
else:
plt.show()
sums_qse = {}
sums_graph = {}
symbols = {}
for sp, y in r_qse.composition:
z = sp.z()
symbols[z] = sp.el()
y_graph = r_graph.composition.getMolarAbundance(sp)
sums_qse[int(z)] = sums_qse.get(z, 0.0) + y
sums_graph[int(z)] = sums_graph.get(z, 0.0) + y_graph
print(sums_qse[3])
print(sums_graph[3])
z_list = sorted(sums_qse.keys())
dex_list = []
symbols = [val for key, val in symbols.items()]
for z in z_list:
total_qse = sums_qse[z]
total_graph = sums_graph[z]
if total_graph > 1e-13 and total_qse > 1e-13:
offset = np.log10(total_qse / total_graph)
else:
if z >= 14:
offset = np.nan # Disable these for visualization, they all have abundances so small (on the order of -100 it doesnt matter)
else:
offset = 0.0
dex_list.append(offset)
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
data = sorted(zip(z_list, symbols, dex_list), key=lambda x: x[0])
sorted_z, sorted_symbols, sorted_dex = zip(*data)
print(sorted_symbols)
print(sorted_dex)
# 2. Create the plot
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
print(sorted_symbols)
bars = ax.bar(sorted_symbols, sorted_dex, color='grey', edgecolor='grey', alpha=0.8)
# 3. Add styling and labels
ax.axhline(0, color='black', linewidth=0.8) # Adds a clear baseline at 0 dex
ax.set_xlabel('Element', fontsize=25)
ax.set_ylabel('Offset [dex]', fontsize=25)
if save_show == ShowSave.SAVE:
plt.savefig("DexElementalOffset.pdf")
plt.close()
e_graph = interp1d(df_graph.t, df_graph.eps)
e_qse = interp1d(df_qse.t, df_qse.eps)
dex_eps_diff = np.log10(e_graph(t)) - np.log10(e_qse(t))
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
ax.semilogx(t, dex_eps_diff, color='black')
ax.set_xlabel("Time [s]", fontsize=25)
ax.set_xlabel("Offset [dex]", fontsize=25)
if save_show == ShowSave.SAVE:
plt.savefig("DexEpsOffset.pdf")
plt.close()
if save_show == ShowSave.SHOW:
plt.show()
print("=== QSE ===")
print(temporal_err_qse)
print(final_err_qse)
print(f"Relative ε error: {qse_rel_eps_error}")
print(f"Neutrino Loss Difference [dex]: {np.log10(r_graph.specific_neutrino_energy_loss) - np.log10(r_qse.specific_neutrino_energy_loss)}")
if __name__ == "__main__":
import argparse