test(vv): Added more scripts to verify GridFire behavior
This commit is contained in:
@@ -6,6 +6,7 @@ from gridfire.policy import MainSequencePolicy
|
||||
|
||||
from gridfire.engine import GraphEngine, MultiscalePartitioningEngineView, AdaptiveEngineView
|
||||
from gridfire.engine import NetworkBuildDepth
|
||||
from fourdst.composition.utils import buildCompositionFromMassFractions
|
||||
|
||||
from scipy.signal import find_peaks
|
||||
|
||||
@@ -77,11 +78,9 @@ def rescale_composition(comp_ref : Composition, ZZs : float, Y_primordial : floa
|
||||
return newComp
|
||||
|
||||
def init_composition(ZZs : float = 0) -> Composition:
|
||||
Y_solar = [7.0262E-01, 1.7479E-06, 6.8955E-02, 2.5000E-04, 7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
|
||||
S = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
|
||||
return rescale_composition(Composition(S, Y_solar), ZZs)
|
||||
|
||||
|
||||
X_GS98 = [0.73395, 0.00005, 0.2490, 0.00281, 0.00101, 0.00883, 0.00149, 0.00064, 0.00066, 0.00035, 0.00008, 0.00006, 0.00107]
|
||||
S_GS98 = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24", "Si-28", "S-32", "Ar-36", "Ca-40", "Fe-56"]
|
||||
return buildCompositionFromMassFractions(S_GS98, X_GS98)
|
||||
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
|
||||
n : NetIn = NetIn()
|
||||
n.temperature = temp
|
||||
@@ -154,185 +153,37 @@ def quantify_engine_error(df_base, df_approx, r_base: NetOut, r_approx: NetOut,
|
||||
|
||||
def main(save_show):
|
||||
C = init_composition()
|
||||
netIn = init_netIn(1.5e7, 1.6e2, years_to_seconds(10e9), C)
|
||||
|
||||
netIn = init_netIn(10**7.1760912591, 10**2.2041199827, 1e17, C)
|
||||
|
||||
stepLogger = StepLogger()
|
||||
|
||||
engine_graph = GraphEngine(C, 4)
|
||||
blob = engine_graph.constructStateBlob()
|
||||
print(f"Gridfire Using: {len(engine_graph.getNetworkReactions(blob))} Reactions and {len(engine_graph.getNetworkSpecies(blob))} Species")
|
||||
|
||||
solver_ctx_graph = PointSolverContext(engine_graph.constructStateBlob())
|
||||
solver_ctx_graph.stdout_logging = True
|
||||
solver_ctx_graph = PointSolverContext(blob)
|
||||
solver_ctx_graph.stdout_logging = False
|
||||
solver_ctx_graph.callback = lambda ctx: stepLogger.log_step(ctx)
|
||||
|
||||
solver_single = PointSolver(engine_graph)
|
||||
|
||||
r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
|
||||
df_graph : pd.DataFrame = stepLogger.df
|
||||
df_graph.to_csv("bbq_graph.csv", index=False)
|
||||
stepLogger.reset()
|
||||
|
||||
QSE_engine = MultiscalePartitioningEngineView(engine_graph)
|
||||
solver_ctx_graph_qse = PointSolverContext(QSE_engine.constructStateBlob(engine_graph.constructStateBlob()))
|
||||
solver_ctx_graph_qse.stdout_logging = True
|
||||
solver_ctx_graph_qse.stdout_logging = False
|
||||
solver_ctx_graph_qse.callback = lambda ctx: stepLogger.log_step(ctx)
|
||||
|
||||
solver_QSE = PointSolver(QSE_engine)
|
||||
r_qse = solver_QSE.evaluate(solver_ctx_graph_qse, netIn, False, False)
|
||||
|
||||
df_qse = stepLogger.df
|
||||
df_qse : pd.DataFrame = stepLogger.df
|
||||
df_qse.to_csv("bbq_qse.csv", index=False)
|
||||
stepLogger.reset()
|
||||
|
||||
# policy = MainSequencePolicy(C)
|
||||
# construct = policy.construct()
|
||||
# solver_AE_QSE = PointSolver(construct.engine)
|
||||
# solver_ctx_graph_qse_ae = PointSolverContext(construct.scratch_blob)
|
||||
# solver_ctx_graph_qse_ae.callback = lambda ctx: stepLogger.log_step(ctx)
|
||||
# solver_ctx_graph_qse_ae.stdout_logging = False
|
||||
#
|
||||
# r_ae_qse = solver_AE_QSE.evaluate(solver_ctx_graph_qse_ae, netIn, False, False)
|
||||
#
|
||||
# df_ae_qse = stepLogger.df
|
||||
# stepLogger.reset()
|
||||
|
||||
# fig, axs = plt.subplots(2, 1, figsize=(10, 7))
|
||||
S = ["H-1", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
|
||||
t = np.logspace(7, 17.5, 5000)
|
||||
# for spID, sp in enumerate(S):
|
||||
# gf = interp1d(df_graph.t, df_graph[sp])
|
||||
# qf = interp1d(df_qse.t, df_qse[sp])
|
||||
#
|
||||
# ax = axs[0]
|
||||
# ax.loglog(t, gf(t), 'o-', color=f"C{spID}")
|
||||
# ax.loglog(t, qf(t), 'o', color=f"C{spID}", linestyle='dashed')
|
||||
#
|
||||
# ax.text(1, df_graph[sp].iloc[0]*1.1, sp, fontsize=12, color=f"C{spID}")
|
||||
#
|
||||
# ax = axs[1]
|
||||
# ax.semilogx(t, (qf(t)-gf(t))/gf(t), color=f"C{spID}")
|
||||
#
|
||||
# axs[1].set_xlabel("Time [s]", fontsize=15)
|
||||
# axs[0].set_ylabel("Molar Abundance [mol/g]", fontsize=15)
|
||||
# axs[1].set_ylabel("Relative Error", fontsize=15)
|
||||
#
|
||||
# fig, ax = plt.subplots(1, 1, figsize=(10, 7))
|
||||
# ge = interp1d(df_graph.t, df_graph.eps)
|
||||
# qe = interp1d(df_qse.t, df_qse.eps)
|
||||
# ax.loglog(t, np.abs((qe(t) - ge(t)) / ge(t)))
|
||||
|
||||
temporal_err_qse, final_err_qse = quantify_engine_error(
|
||||
df_base=df_graph,
|
||||
df_approx=df_qse,
|
||||
r_base=r_graph,
|
||||
r_approx=r_qse,
|
||||
species_list=S
|
||||
)
|
||||
|
||||
qse_rel_eps_error = (df_graph.eps.iloc[-1] - df_qse.eps.iloc[-1])/df_qse.eps.iloc[-1]
|
||||
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
|
||||
# ax.semilogx(df_graph.t, df_graph["H-2"], 'o-', color='red')
|
||||
# ax.semilogx(df_qse.t, df_qse["H-2"], 'o', color='blue', linestyle='dashed')
|
||||
|
||||
graph_h1 = interp1d(df_graph.t, df_graph["H-1"])
|
||||
qse_h1 = interp1d(df_qse.t, df_qse["H-1"])
|
||||
graph_h2 = interp1d(df_graph.t, df_graph["H-2"])
|
||||
qse_h2 = interp1d(df_qse.t, df_qse["H-2"])
|
||||
|
||||
graph_DH = graph_h2(t)/graph_h1(t)
|
||||
qse_DH = qse_h2(t)/qse_h1(t)
|
||||
|
||||
dex_diff = np.abs(np.log10(graph_h2(t)) - np.log10(qse_h2(t)))
|
||||
dex_dh_diff = np.abs(np.log10(graph_DH) - np.log10(qse_DH))
|
||||
# ax.semilogx(t, dex_diff, color='green')
|
||||
ax.loglog(t, dex_dh_diff, color='black')
|
||||
# ax.semilogx(t, qse_h2(t)/qse_h1(t), color='green')
|
||||
ax.set_xlabel("Time [s]", fontsize=17)
|
||||
ax.set_ylabel(r"$\left|\log_{10}\left(\frac{D}{H})\right)_{graph} - \log_{10}\left(\frac{D}{H}\right)_{qse}\right|$", fontsize=17)
|
||||
|
||||
if save_show == ShowSave.SAVE:
|
||||
plt.savefig("DHErr.pdf")
|
||||
plt.close()
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
sums_qse = {}
|
||||
sums_graph = {}
|
||||
symbols = {}
|
||||
|
||||
for sp, y in r_qse.composition:
|
||||
z = sp.z()
|
||||
symbols[z] = sp.el()
|
||||
|
||||
y_graph = r_graph.composition.getMolarAbundance(sp)
|
||||
|
||||
sums_qse[int(z)] = sums_qse.get(z, 0.0) + y
|
||||
sums_graph[int(z)] = sums_graph.get(z, 0.0) + y_graph
|
||||
|
||||
print(sums_qse[3])
|
||||
print(sums_graph[3])
|
||||
|
||||
z_list = sorted(sums_qse.keys())
|
||||
dex_list = []
|
||||
|
||||
symbols = [val for key, val in symbols.items()]
|
||||
|
||||
for z in z_list:
|
||||
total_qse = sums_qse[z]
|
||||
total_graph = sums_graph[z]
|
||||
|
||||
if total_graph > 1e-13 and total_qse > 1e-13:
|
||||
offset = np.log10(total_qse / total_graph)
|
||||
else:
|
||||
if z >= 14:
|
||||
offset = np.nan # Disable these for visualization, they all have abundances so small (on the order of -100 it doesnt matter)
|
||||
else:
|
||||
offset = 0.0
|
||||
|
||||
dex_list.append(offset)
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
|
||||
data = sorted(zip(z_list, symbols, dex_list), key=lambda x: x[0])
|
||||
sorted_z, sorted_symbols, sorted_dex = zip(*data)
|
||||
print(sorted_symbols)
|
||||
print(sorted_dex)
|
||||
# 2. Create the plot
|
||||
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
|
||||
print(sorted_symbols)
|
||||
bars = ax.bar(sorted_symbols, sorted_dex, color='grey', edgecolor='grey', alpha=0.8)
|
||||
|
||||
# 3. Add styling and labels
|
||||
ax.axhline(0, color='black', linewidth=0.8) # Adds a clear baseline at 0 dex
|
||||
ax.set_xlabel('Element', fontsize=25)
|
||||
ax.set_ylabel('Offset [dex]', fontsize=25)
|
||||
|
||||
if save_show == ShowSave.SAVE:
|
||||
plt.savefig("DexElementalOffset.pdf")
|
||||
plt.close()
|
||||
|
||||
|
||||
e_graph = interp1d(df_graph.t, df_graph.eps)
|
||||
e_qse = interp1d(df_qse.t, df_qse.eps)
|
||||
|
||||
dex_eps_diff = np.log10(e_graph(t)) - np.log10(e_qse(t))
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
|
||||
ax.semilogx(t, dex_eps_diff, color='black')
|
||||
ax.set_xlabel("Time [s]", fontsize=25)
|
||||
ax.set_xlabel("Offset [dex]", fontsize=25)
|
||||
|
||||
if save_show == ShowSave.SAVE:
|
||||
plt.savefig("DexEpsOffset.pdf")
|
||||
plt.close()
|
||||
|
||||
|
||||
if save_show == ShowSave.SHOW:
|
||||
plt.show()
|
||||
|
||||
print("=== QSE ===")
|
||||
print(temporal_err_qse)
|
||||
print(final_err_qse)
|
||||
print(f"Relative ε error: {qse_rel_eps_error}")
|
||||
|
||||
print(f"Neutrino Loss Difference [dex]: {np.log10(r_graph.specific_neutrino_energy_loss) - np.log10(r_qse.specific_neutrino_energy_loss)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
BIN
validation/ManuscriptFigures/ErrorBudget/DHErr.pdf
Normal file
BIN
validation/ManuscriptFigures/ErrorBudget/DHErr.pdf
Normal file
Binary file not shown.
BIN
validation/ManuscriptFigures/ErrorBudget/DexElementalOffset.pdf
Normal file
BIN
validation/ManuscriptFigures/ErrorBudget/DexElementalOffset.pdf
Normal file
Binary file not shown.
BIN
validation/ManuscriptFigures/ErrorBudget/DexEpsOffset.pdf
Normal file
BIN
validation/ManuscriptFigures/ErrorBudget/DexEpsOffset.pdf
Normal file
Binary file not shown.
@@ -168,7 +168,7 @@ def main(save_show):
|
||||
solver_single = PointSolver(engine_graph)
|
||||
|
||||
r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
|
||||
df_graph = stepLogger.df
|
||||
df_graph : pd.DataFrame = stepLogger.df
|
||||
stepLogger.reset()
|
||||
|
||||
QSE_engine = MultiscalePartitioningEngineView(engine_graph)
|
||||
|
||||
Binary file not shown.
@@ -8,8 +8,13 @@ import sys
|
||||
|
||||
from gridfire.solver import PointSolverTimestepContext
|
||||
from gridfire._gridfire.engine.scratchpads import StateBlob
|
||||
|
||||
from fourdst.composition import Composition
|
||||
import gridfire
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
class LogEntries(Enum):
|
||||
Step = "Step"
|
||||
t = "t"
|
||||
@@ -17,6 +22,7 @@ class LogEntries(Enum):
|
||||
eps = "eps"
|
||||
Composition = "Composition"
|
||||
ReactionContributions = "ReactionContributions"
|
||||
MassFractions = "MassFractions"
|
||||
|
||||
|
||||
class StepLogger:
|
||||
@@ -24,17 +30,40 @@ class StepLogger:
|
||||
self.num_steps : int = 0
|
||||
self.steps : List[Dict[LogEntries, Any]] = []
|
||||
|
||||
# def log_step(self, ctx: PointSolverTimestepContext):
|
||||
# comp_data: Dict[str, SupportsFloat] = {}
|
||||
# for species in ctx.engine.getNetworkSpecies(ctx.state_ctx):
|
||||
# sid = ctx.engine.getSpeciesIndex(ctx.state_ctx, species)
|
||||
# comp_data[species.name()] = ctx.state[sid]
|
||||
# entry : Dict[LogEntries, Any] = {
|
||||
# LogEntries.Step: ctx.num_steps,
|
||||
# LogEntries.t: ctx.t,
|
||||
# LogEntries.dt: ctx.dt,
|
||||
# LogEntries.eps: ctx.state[-1],
|
||||
# LogEntries.Composition: comp_data,
|
||||
# }
|
||||
# self.steps.append(entry)
|
||||
# self.num_steps += 1
|
||||
|
||||
def log_step(self, ctx: PointSolverTimestepContext):
|
||||
comp_data: Dict[str, SupportsFloat] = {}
|
||||
for species in ctx.engine.getNetworkSpecies(ctx.state_ctx):
|
||||
sid = ctx.engine.getSpeciesIndex(ctx.state_ctx, species)
|
||||
comp_data[species.name()] = ctx.state[sid]
|
||||
full_comp = ctx.composition
|
||||
|
||||
comp_data: Dict[str, float] = {}
|
||||
mass_frac: Dict[str, float] = {}
|
||||
for species in full_comp.getRegisteredSpecies():
|
||||
comp_data[species.name()] = full_comp.getMolarAbundance(species)
|
||||
mass_frac[species.name()] = full_comp.getMassFraction(species)
|
||||
|
||||
rhs_calc = ctx.engine.getMostRecentRHSCalculation(ctx.state_ctx)
|
||||
instantaneous_eps = rhs_calc.energy if rhs_calc else 0.0
|
||||
|
||||
entry : Dict[LogEntries, Any] = {
|
||||
LogEntries.Step: ctx.num_steps,
|
||||
LogEntries.t: ctx.t,
|
||||
LogEntries.dt: ctx.dt,
|
||||
LogEntries.eps: ctx.state[-1],
|
||||
LogEntries.eps: instantaneous_eps,
|
||||
LogEntries.Composition: comp_data,
|
||||
LogEntries.MassFractions: mass_frac,
|
||||
}
|
||||
self.steps.append(entry)
|
||||
self.num_steps += 1
|
||||
@@ -47,6 +76,7 @@ class StepLogger:
|
||||
LogEntries.dt.value: step[LogEntries.dt],
|
||||
LogEntries.eps.value: step[LogEntries.eps],
|
||||
LogEntries.Composition.value: step[LogEntries.Composition],
|
||||
LogEntries.MassFractions.value: step[LogEntries.MassFractions],
|
||||
}
|
||||
for step in self.steps
|
||||
]
|
||||
@@ -67,6 +97,32 @@ class StepLogger:
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(out_data, f, indent=4)
|
||||
|
||||
@property
|
||||
def t(self) -> np.ndarray:
|
||||
return np.array([step[LogEntries.t] for step in self.steps])
|
||||
|
||||
@property
|
||||
def df(self) -> pd.DataFrame:
|
||||
if not self.steps:
|
||||
return pd.DataFrame()
|
||||
|
||||
flat_data = []
|
||||
for step in self.steps:
|
||||
row = {
|
||||
"Step": step[LogEntries.Step],
|
||||
"t": step[LogEntries.t],
|
||||
"dt": step[LogEntries.dt],
|
||||
"eps": step[LogEntries.eps],
|
||||
}
|
||||
X_dict = {f"X_{sp}": x for sp, x in step[LogEntries.MassFractions].items()}
|
||||
row.update(step[LogEntries.Composition])
|
||||
row.update(X_dict)
|
||||
flat_data.append(row)
|
||||
|
||||
df = pd.DataFrame(flat_data)
|
||||
|
||||
df = df.ffill().fillna(0.0)
|
||||
return df
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
if not self.steps:
|
||||
@@ -78,3 +134,8 @@ class StepLogger:
|
||||
"FinalComposition": final_step[LogEntries.Composition],
|
||||
}
|
||||
return summary_data
|
||||
|
||||
def reset(self):
|
||||
self.num_steps = 0
|
||||
self.steps : List[Dict[LogEntries, Any]] = []
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
savefig.facecolor: auto
|
||||
savefig.edgecolor: auto
|
||||
savefig.format: pdf
|
||||
savefig.bbox: tight
|
||||
|
||||
xtick.minor.visible : True
|
||||
ytick.minor.visible : True
|
||||
|
||||
xtick.direction : in
|
||||
ytick.direction : in
|
||||
|
||||
xtick.top : True
|
||||
ytick.right : True
|
||||
|
||||
xtick.major.size : 8
|
||||
xtick.minor.size : 4
|
||||
|
||||
ytick.major.size : 8
|
||||
ytick.minor.size : 4
|
||||
|
||||
xtick.labelsize : 18
|
||||
ytick.labelsize : 19
|
||||
|
||||
font.size : 20
|
||||
|
||||
font.family : serif
|
||||
text.usetex : True
|
||||
lines.color: C0
|
||||
patch.edgecolor: black
|
||||
text.color: black
|
||||
axes.facecolor: white
|
||||
axes.edgecolor: black
|
||||
axes.labelcolor: black
|
||||
axes.prop_cycle: cycler('color', ['1f77b4', 'ff7f0e', '2ca02c', 'd62728', '9467bd', '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf'])
|
||||
xtick.color: black
|
||||
ytick.color: black
|
||||
grid.color: b0b0b0
|
||||
figure.facecolor: white
|
||||
figure.edgecolor: white
|
||||
savefig.dpi: figure
|
||||
|
||||
Reference in New Issue
Block a user