feat(GridFire): Add a number of Python hooks

Python hooks to make retrieving the base composition more reliable. Additionally, a number of small changes were made to aid in my analysis in response to referee report 1.
This commit is contained in:
2026-04-13 07:17:14 -04:00
parent 65297852e5
commit 84ff182717
44 changed files with 1676 additions and 2964 deletions

View File

@@ -0,0 +1,343 @@
import numpy as np
import pandas as pd
from IPython.core.pylabtools import figsize
from gridfire.solver import PointSolver, PointSolverContext
from gridfire.policy import MainSequencePolicy
from gridfire.engine import GraphEngine, MultiscalePartitioningEngineView, AdaptiveEngineView
from gridfire.engine import NetworkBuildDepth
from scipy.signal import find_peaks
from gridfire.config import GridFireConfig
from fourdst.composition import Composition
from scipy.integrate import trapezoid
from fourdst.composition import CanonicalComposition
from fourdst.atomic import Species
from gridfire.type import NetIn, NetOut
import matplotlib.pyplot as plt
## Note that my default style uses tex rendering. If you do not have tex installed
## simply comment out this line
plt.style.use("../utils/pub.mplstyle")
from scipy.interpolate import interp1d, CubicSpline
from enum import Enum
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
from logger import StepLogger
class ShowSave(Enum):
    """Output mode for generated figures: display interactively or save to disk."""

    SHOW = "SHOW"
    SAVE = "SAVE"

    def __str__(self) -> str:
        # argparse renders choices via str(); show the bare value, not "ShowSave.SHOW".
        return self.value
def rescale_composition(comp_ref: Composition, ZZs: float, Y_primordial: float = 0.248) -> Composition:
    """Rescale a reference composition to a new metallicity.

    The target metal mass fraction is Z_new = Z_ref * 10**ZZs.  Helium is scaled
    along a linear enrichment law dY/dZ anchored at the primordial helium mass
    fraction, and hydrogen absorbs the remainder so the bulk mass fractions
    still sum to one.

    Args:
        comp_ref: Reference composition to rescale.
        ZZs: Metallicity offset in dex relative to the reference.
        Y_primordial: Primordial helium mass fraction anchoring the dY/dZ law.

    Returns:
        A new Composition with every species' mass fraction scaled by its
        element class (H, He, or metal) and re-registered by molar abundance.

    Raises:
        ValueError: If the requested ZZs drives the hydrogen mass fraction negative.
    """
    CC: CanonicalComposition = comp_ref.getCanonicalComposition()
    dY_dZ = (CC.Y - Y_primordial) / CC.Z
    Z_new = CC.Z * (10**ZZs)
    Y_bulk_new = Y_primordial + (dY_dZ * Z_new)
    X_new = 1.0 - Z_new - Y_bulk_new
    if X_new < 0:
        raise ValueError(f"ZZs={ZZs} yields unphysical composition (X < 0)")
    # Per-class scale factors; guard against a zero reference bulk fraction.
    ratio_H = X_new / CC.X if CC.X > 0 else 0
    ratio_He = Y_bulk_new / CC.Y if CC.Y > 0 else 0
    ratio_Z = Z_new / CC.Z if CC.Z > 0 else 0
    # (removed unused local Y_new_list from the original)
    newComp: Composition = Composition()
    s: Species
    for s in comp_ref.getRegisteredSpecies():
        Xi_ref = comp_ref.getMassFraction(s)
        if s.el() == "H":
            Xi_new = Xi_ref * ratio_H
        elif s.el() == "He":
            Xi_new = Xi_ref * ratio_He
        else:
            Xi_new = Xi_ref * ratio_Z
        # Convert mass fraction back to molar abundance: Y_i = X_i / A_i.
        Y = Xi_new / s.mass()
        newComp.registerSpecies(s)
        newComp.setMolarAbundance(s, Y)
    return newComp
def init_composition(ZZs: float = 0) -> Composition:
    """Build the fiducial solar-scaled composition, optionally shifted by ZZs dex.

    NOTE(review): the value vector is named like molar abundances (Y_solar) but
    is passed straight to Composition(S, ...) — confirm which convention that
    constructor expects.
    """
    species_names = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
    solar_values = [7.0262E-01, 1.7479E-06, 6.8955E-02, 2.5000E-04,
                    7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
    return rescale_composition(Composition(species_names, solar_values), ZZs)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
    """Assemble a NetIn record for a single-zone burn.

    Args:
        temp: Temperature (presumably K — confirm against the solver's units).
        rho:  Density.
        time: Total integration time [s].
        comp: Initial composition.
    """
    net_input: NetIn = NetIn()
    net_input.temperature = temp
    net_input.density = rho
    net_input.tMax = time
    # Tiny initial step lets the stiff integrator ramp its own step size up.
    net_input.dt0 = 1e-12
    net_input.composition = comp
    return net_input
def years_to_seconds(years: float) -> float:
    """Convert years to seconds using the 365-day year (3.1536e7 s/yr)."""
    seconds_per_year = 3.1536e7
    return seconds_per_year * years
def quantify_engine_error(df_base, df_approx, r_base: NetOut, r_approx: NetOut, species_list, floor_val=1e-30):
    """Compare an approximate engine run against a baseline run.

    For each tracked column ('eps' plus every species in species_list) the
    approximate run is resampled onto the baseline time grid and both the
    maximum pointwise relative error and an L2-norm relative error are
    computed.  Final-state relative errors for the energy, neutrino loss, and
    final abundances are computed from the NetOut results.

    Returns:
        (temporal_df, final_df): two DataFrames indexed by tracked quantity.
    """
    def _rel(approx_val, base_val):
        # Relative error with a floor so a near-zero baseline cannot blow up.
        return abs(approx_val - base_val) / max(abs(base_val), floor_val)

    temporal = {}
    finals = {}
    base_times = df_base['t'].values
    for column in ['eps'] + species_list:
        if column not in df_base.columns or column not in df_approx.columns:
            continue  # quantity not logged by both runs
        baseline = df_base[column].values
        # Resample the approximate run onto the baseline grid; clamp to the
        # first/last logged value outside the approximate run's time span.
        resample = interp1d(
            df_approx['t'], df_approx[column], kind='linear', bounds_error=False,
            fill_value=(df_approx[column].iloc[0], df_approx[column].iloc[-1]),
        )
        resampled = resample(base_times)
        absolute = np.abs(resampled - baseline)
        pointwise_rel = absolute / np.maximum(np.abs(baseline), floor_val)
        l2_err = np.sqrt(trapezoid(absolute**2, x=base_times))
        l2_ref = np.sqrt(trapezoid(baseline**2, x=base_times))
        temporal[column] = {
            'Max Rel Error (Temporal)': np.max(pointwise_rel),
            'L2 Rel Error (Temporal)': l2_err / max(l2_ref, floor_val),
        }
    finals['Energy'] = {'Final Rel Error': _rel(r_approx.energy, r_base.energy)}
    finals['Neutrino Loss'] = {
        'Final Rel Error': _rel(r_approx.specific_neutrino_energy_loss,
                                r_base.specific_neutrino_energy_loss)
    }
    for species in species_list:
        try:
            finals[f"Final {species}"] = {
                'Final Rel Error': _rel(r_approx.composition[species],
                                        r_base.composition[species])
            }
        except (KeyError, TypeError, AttributeError):
            pass  # species missing from a final composition; skip silently
    return pd.DataFrame(temporal).T, pd.DataFrame(finals).T
def main(save_show):
    """Run the baseline GraphEngine and the QSE-partitioned engine view on the
    same conditions, then quantify and plot their disagreement: the D/H-ratio
    offset, per-element abundance offsets in dex, and the energy-generation
    offset in dex.

    Args:
        save_show: ShowSave.SAVE writes PDFs to the working directory;
            ShowSave.SHOW displays the figures interactively.
    """
    C = init_composition()
    # T = 1.5e7, rho = 1.6e2 (solar-core-like conditions), evolved for 10 Gyr.
    netIn = init_netIn(1.5e7, 1.6e2, years_to_seconds(10e9), C)
    stepLogger = StepLogger()

    # --- Baseline: full graph network, build depth 4 ---
    engine_graph = GraphEngine(C, 4)
    solver_ctx_graph = PointSolverContext(engine_graph.constructStateBlob())
    solver_ctx_graph.stdout_logging = True
    solver_ctx_graph.callback = lambda ctx: stepLogger.log_step(ctx)
    solver_single = PointSolver(engine_graph)
    r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
    df_graph: pd.DataFrame = stepLogger.df
    stepLogger.reset()

    # --- Approximation: multiscale (QSE) partitioned view of the same engine ---
    QSE_engine = MultiscalePartitioningEngineView(engine_graph)
    solver_ctx_graph_qse = PointSolverContext(QSE_engine.constructStateBlob(engine_graph.constructStateBlob()))
    solver_ctx_graph_qse.stdout_logging = True
    solver_ctx_graph_qse.callback = lambda ctx: stepLogger.log_step(ctx)
    solver_QSE = PointSolver(QSE_engine)
    r_qse = solver_QSE.evaluate(solver_ctx_graph_qse, netIn, False, False)
    df_qse = stepLogger.df
    stepLogger.reset()

    S = ["H-1", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
    # Common log-spaced time grid used to resample both runs.
    t = np.logspace(7, 17.5, 5000)

    temporal_err_qse, final_err_qse = quantify_engine_error(
        df_base=df_graph,
        df_approx=df_qse,
        r_base=r_graph,
        r_approx=r_qse,
        species_list=S
    )
    # NOTE(review): normalized by the QSE (approximate) value, not the
    # baseline — confirm that is the intended convention.
    qse_rel_eps_error = (df_graph.eps.iloc[-1] - df_qse.eps.iloc[-1]) / df_qse.eps.iloc[-1]

    # --- Figure 1: |log10(D/H)| offset between the two engines ---
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    graph_h1 = interp1d(df_graph.t, df_graph["H-1"])
    qse_h1 = interp1d(df_qse.t, df_qse["H-1"])
    graph_h2 = interp1d(df_graph.t, df_graph["H-2"])
    qse_h2 = interp1d(df_qse.t, df_qse["H-2"])
    graph_DH = graph_h2(t) / graph_h1(t)
    qse_DH = qse_h2(t) / qse_h1(t)
    dex_dh_diff = np.abs(np.log10(graph_DH) - np.log10(qse_DH))
    ax.loglog(t, dex_dh_diff, color='black')
    ax.set_xlabel("Time [s]", fontsize=17)
    ax.set_ylabel(r"$\left|\log_{10}\left(\frac{D}{H})\right)_{graph} - \log_{10}\left(\frac{D}{H}\right)_{qse}\right|$", fontsize=17)
    if save_show == ShowSave.SAVE:
        plt.savefig("DHErr.pdf")
        plt.close()
    else:
        plt.show()

    # --- Figure 2: per-element abundance offset [dex], QSE vs graph ---
    sums_qse = {}
    sums_graph = {}
    symbols = {}
    for sp, y in r_qse.composition:
        # Key every dict on int(Z); the original mixed raw and int keys.
        z = int(sp.z())
        symbols[z] = sp.el()
        y_graph = r_graph.composition.getMolarAbundance(sp)
        sums_qse[z] = sums_qse.get(z, 0.0) + y
        sums_graph[z] = sums_graph.get(z, 0.0) + y_graph
    # Debug: lithium (Z=3) totals; .get() avoids a KeyError if Li is absent.
    print(sums_qse.get(3))
    print(sums_graph.get(3))
    z_list = sorted(sums_qse.keys())
    dex_list = []
    for z in z_list:
        total_qse = sums_qse[z]
        total_graph = sums_graph[z]
        if total_graph > 1e-13 and total_qse > 1e-13:
            offset = np.log10(total_qse / total_graph)
        elif z >= 14:
            # Disable these for visualization; their abundances are so small
            # (on the order of 1e-100) the offset is meaningless.
            offset = np.nan
        else:
            offset = 0.0
        dex_list.append(offset)
    # Build the symbol list aligned with the sorted z_list.  (The original
    # zipped a dict-insertion-ordered symbol list against the sorted z list,
    # which could mispair element labels with their offsets.)
    sorted_symbols = [symbols[z] for z in z_list]
    sorted_dex = dex_list
    print(sorted_symbols)
    print(sorted_dex)
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.bar(sorted_symbols, sorted_dex, color='grey', edgecolor='grey', alpha=0.8)
    ax.axhline(0, color='black', linewidth=0.8)  # clear baseline at 0 dex
    ax.set_xlabel('Element', fontsize=25)
    ax.set_ylabel('Offset [dex]', fontsize=25)
    if save_show == ShowSave.SAVE:
        plt.savefig("DexElementalOffset.pdf")
        plt.close()

    # --- Figure 3: energy-generation-rate offset [dex] over time ---
    e_graph = interp1d(df_graph.t, df_graph.eps)
    e_qse = interp1d(df_qse.t, df_qse.eps)
    dex_eps_diff = np.log10(e_graph(t)) - np.log10(e_qse(t))
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ax.semilogx(t, dex_eps_diff, color='black')
    ax.set_xlabel("Time [s]", fontsize=25)
    # Fixed: the original called set_xlabel twice, so the y-axis was unlabeled
    # and the x-label was overwritten with "Offset [dex]".
    ax.set_ylabel("Offset [dex]", fontsize=25)
    if save_show == ShowSave.SAVE:
        plt.savefig("DexEpsOffset.pdf")
        plt.close()
    if save_show == ShowSave.SHOW:
        plt.show()

    print("=== QSE ===")
    print(temporal_err_qse)
    print(final_err_qse)
    print(f"Relative ε error: {qse_rel_eps_error}")
    print(f"Neutrino Loss Difference [dex]: {np.log10(r_graph.specific_neutrino_energy_loss) - np.log10(r_qse.specific_neutrino_energy_loss)}")
if __name__ == "__main__":
    import argparse

    # CLI: `-s SHOW` (default) displays figures; `-s SAVE` writes PDFs.
    parser = argparse.ArgumentParser(
        prog="Derivative Smoothness",
        description="Generate of view plots of derivative smoothness",
    )
    parser.add_argument(
        "-s",
        type=ShowSave,
        default=ShowSave.SHOW,
        choices=list(ShowSave),
        help="Whether to show or save the generated plot",
    )
    main(parser.parse_args().s)

View File

@@ -0,0 +1,343 @@
import numpy as np
import pandas as pd
from IPython.core.pylabtools import figsize
from gridfire.solver import PointSolver, PointSolverContext
from gridfire.policy import MainSequencePolicy
from gridfire.engine import GraphEngine, MultiscalePartitioningEngineView, AdaptiveEngineView
from gridfire.engine import NetworkBuildDepth
from scipy.signal import find_peaks
from gridfire.config import GridFireConfig
from fourdst.composition import Composition
from scipy.integrate import trapezoid
from fourdst.composition import CanonicalComposition
from fourdst.atomic import Species
from gridfire.type import NetIn, NetOut
import matplotlib.pyplot as plt
## Note that my default style uses tex rendering. If you do not have tex installed
## simply comment out this line
plt.style.use("../utils/pub.mplstyle")
from scipy.interpolate import interp1d, CubicSpline
from enum import Enum
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
from logger import StepLogger
class ShowSave(Enum):
    """Output mode for generated figures: display interactively or save to disk."""

    SHOW = "SHOW"
    SAVE = "SAVE"

    def __str__(self) -> str:
        # argparse renders choices via str(); show the bare value, not "ShowSave.SHOW".
        return self.value
def rescale_composition(comp_ref: Composition, ZZs: float, Y_primordial: float = 0.248) -> Composition:
    """Rescale a reference composition to a new metallicity.

    The target metal mass fraction is Z_new = Z_ref * 10**ZZs.  Helium is scaled
    along a linear enrichment law dY/dZ anchored at the primordial helium mass
    fraction, and hydrogen absorbs the remainder so the bulk mass fractions
    still sum to one.

    Args:
        comp_ref: Reference composition to rescale.
        ZZs: Metallicity offset in dex relative to the reference.
        Y_primordial: Primordial helium mass fraction anchoring the dY/dZ law.

    Returns:
        A new Composition with every species' mass fraction scaled by its
        element class (H, He, or metal) and re-registered by molar abundance.

    Raises:
        ValueError: If the requested ZZs drives the hydrogen mass fraction negative.
    """
    CC: CanonicalComposition = comp_ref.getCanonicalComposition()
    dY_dZ = (CC.Y - Y_primordial) / CC.Z
    Z_new = CC.Z * (10**ZZs)
    Y_bulk_new = Y_primordial + (dY_dZ * Z_new)
    X_new = 1.0 - Z_new - Y_bulk_new
    if X_new < 0:
        raise ValueError(f"ZZs={ZZs} yields unphysical composition (X < 0)")
    # Per-class scale factors; guard against a zero reference bulk fraction.
    ratio_H = X_new / CC.X if CC.X > 0 else 0
    ratio_He = Y_bulk_new / CC.Y if CC.Y > 0 else 0
    ratio_Z = Z_new / CC.Z if CC.Z > 0 else 0
    # (removed unused local Y_new_list from the original)
    newComp: Composition = Composition()
    s: Species
    for s in comp_ref.getRegisteredSpecies():
        Xi_ref = comp_ref.getMassFraction(s)
        if s.el() == "H":
            Xi_new = Xi_ref * ratio_H
        elif s.el() == "He":
            Xi_new = Xi_ref * ratio_He
        else:
            Xi_new = Xi_ref * ratio_Z
        # Convert mass fraction back to molar abundance: Y_i = X_i / A_i.
        Y = Xi_new / s.mass()
        newComp.registerSpecies(s)
        newComp.setMolarAbundance(s, Y)
    return newComp
def init_composition(ZZs: float = 0) -> Composition:
    """Build the fiducial solar-scaled composition, optionally shifted by ZZs dex.

    NOTE(review): the value vector is named like molar abundances (Y_solar) but
    is passed straight to Composition(S, ...) — confirm which convention that
    constructor expects.
    """
    species_names = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
    solar_values = [7.0262E-01, 1.7479E-06, 6.8955E-02, 2.5000E-04,
                    7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
    return rescale_composition(Composition(species_names, solar_values), ZZs)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
    """Assemble a NetIn record for a single-zone burn.

    Args:
        temp: Temperature (presumably K — confirm against the solver's units).
        rho:  Density.
        time: Total integration time [s].
        comp: Initial composition.
    """
    net_input: NetIn = NetIn()
    net_input.temperature = temp
    net_input.density = rho
    net_input.tMax = time
    # Tiny initial step lets the stiff integrator ramp its own step size up.
    net_input.dt0 = 1e-12
    net_input.composition = comp
    return net_input
def years_to_seconds(years: float) -> float:
    """Convert years to seconds using the 365-day year (3.1536e7 s/yr)."""
    seconds_per_year = 3.1536e7
    return seconds_per_year * years
def quantify_engine_error(df_base, df_approx, r_base: NetOut, r_approx: NetOut, species_list, floor_val=1e-30):
    """Compare an approximate engine run against a baseline run.

    For each tracked column ('eps' plus every species in species_list) the
    approximate run is resampled onto the baseline time grid and both the
    maximum pointwise relative error and an L2-norm relative error are
    computed.  Final-state relative errors for the energy, neutrino loss, and
    final abundances are computed from the NetOut results.

    Returns:
        (temporal_df, final_df): two DataFrames indexed by tracked quantity.
    """
    def _rel(approx_val, base_val):
        # Relative error with a floor so a near-zero baseline cannot blow up.
        return abs(approx_val - base_val) / max(abs(base_val), floor_val)

    temporal = {}
    finals = {}
    base_times = df_base['t'].values
    for column in ['eps'] + species_list:
        if column not in df_base.columns or column not in df_approx.columns:
            continue  # quantity not logged by both runs
        baseline = df_base[column].values
        # Resample the approximate run onto the baseline grid; clamp to the
        # first/last logged value outside the approximate run's time span.
        resample = interp1d(
            df_approx['t'], df_approx[column], kind='linear', bounds_error=False,
            fill_value=(df_approx[column].iloc[0], df_approx[column].iloc[-1]),
        )
        resampled = resample(base_times)
        absolute = np.abs(resampled - baseline)
        pointwise_rel = absolute / np.maximum(np.abs(baseline), floor_val)
        l2_err = np.sqrt(trapezoid(absolute**2, x=base_times))
        l2_ref = np.sqrt(trapezoid(baseline**2, x=base_times))
        temporal[column] = {
            'Max Rel Error (Temporal)': np.max(pointwise_rel),
            'L2 Rel Error (Temporal)': l2_err / max(l2_ref, floor_val),
        }
    finals['Energy'] = {'Final Rel Error': _rel(r_approx.energy, r_base.energy)}
    finals['Neutrino Loss'] = {
        'Final Rel Error': _rel(r_approx.specific_neutrino_energy_loss,
                                r_base.specific_neutrino_energy_loss)
    }
    for species in species_list:
        try:
            finals[f"Final {species}"] = {
                'Final Rel Error': _rel(r_approx.composition[species],
                                        r_base.composition[species])
            }
        except (KeyError, TypeError, AttributeError):
            pass  # species missing from a final composition; skip silently
    return pd.DataFrame(temporal).T, pd.DataFrame(finals).T
def main(save_show):
    """Run the baseline GraphEngine and the QSE-partitioned engine view on the
    same conditions, then quantify and plot their disagreement: the D/H-ratio
    offset, per-element abundance offsets in dex, and the energy-generation
    offset in dex.

    Args:
        save_show: ShowSave.SAVE writes PDFs to the working directory;
            ShowSave.SHOW displays the figures interactively.
    """
    C = init_composition()
    # T = 1.5e7, rho = 1.6e2 (solar-core-like conditions), evolved for 10 Gyr.
    netIn = init_netIn(1.5e7, 1.6e2, years_to_seconds(10e9), C)
    stepLogger = StepLogger()

    # --- Baseline: full graph network, build depth 4 ---
    engine_graph = GraphEngine(C, 4)
    solver_ctx_graph = PointSolverContext(engine_graph.constructStateBlob())
    solver_ctx_graph.stdout_logging = True
    solver_ctx_graph.callback = lambda ctx: stepLogger.log_step(ctx)
    solver_single = PointSolver(engine_graph)
    r_graph = solver_single.evaluate(solver_ctx_graph, netIn, False, False)
    df_graph = stepLogger.df
    stepLogger.reset()

    # --- Approximation: multiscale (QSE) partitioned view of the same engine ---
    QSE_engine = MultiscalePartitioningEngineView(engine_graph)
    solver_ctx_graph_qse = PointSolverContext(QSE_engine.constructStateBlob(engine_graph.constructStateBlob()))
    solver_ctx_graph_qse.stdout_logging = True
    solver_ctx_graph_qse.callback = lambda ctx: stepLogger.log_step(ctx)
    solver_QSE = PointSolver(QSE_engine)
    r_qse = solver_QSE.evaluate(solver_ctx_graph_qse, netIn, False, False)
    df_qse = stepLogger.df
    stepLogger.reset()

    S = ["H-1", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
    # Common log-spaced time grid used to resample both runs.
    t = np.logspace(7, 17.5, 5000)

    temporal_err_qse, final_err_qse = quantify_engine_error(
        df_base=df_graph,
        df_approx=df_qse,
        r_base=r_graph,
        r_approx=r_qse,
        species_list=S
    )
    # NOTE(review): normalized by the QSE (approximate) value, not the
    # baseline — confirm that is the intended convention.
    qse_rel_eps_error = (df_graph.eps.iloc[-1] - df_qse.eps.iloc[-1]) / df_qse.eps.iloc[-1]

    # --- Figure 1: |log10(D/H)| offset between the two engines ---
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    graph_h1 = interp1d(df_graph.t, df_graph["H-1"])
    qse_h1 = interp1d(df_qse.t, df_qse["H-1"])
    graph_h2 = interp1d(df_graph.t, df_graph["H-2"])
    qse_h2 = interp1d(df_qse.t, df_qse["H-2"])
    graph_DH = graph_h2(t) / graph_h1(t)
    qse_DH = qse_h2(t) / qse_h1(t)
    dex_dh_diff = np.abs(np.log10(graph_DH) - np.log10(qse_DH))
    ax.loglog(t, dex_dh_diff, color='black')
    ax.set_xlabel("Time [s]", fontsize=17)
    ax.set_ylabel(r"$\left|\log_{10}\left(\frac{D}{H})\right)_{graph} - \log_{10}\left(\frac{D}{H}\right)_{qse}\right|$", fontsize=17)
    if save_show == ShowSave.SAVE:
        plt.savefig("DHErr.pdf")
        plt.close()
    else:
        plt.show()

    # --- Figure 2: per-element abundance offset [dex], QSE vs graph ---
    sums_qse = {}
    sums_graph = {}
    symbols = {}
    for sp, y in r_qse.composition:
        # Key every dict on int(Z); the original mixed raw and int keys.
        z = int(sp.z())
        symbols[z] = sp.el()
        y_graph = r_graph.composition.getMolarAbundance(sp)
        sums_qse[z] = sums_qse.get(z, 0.0) + y
        sums_graph[z] = sums_graph.get(z, 0.0) + y_graph
    # Debug: lithium (Z=3) totals; .get() avoids a KeyError if Li is absent.
    print(sums_qse.get(3))
    print(sums_graph.get(3))
    z_list = sorted(sums_qse.keys())
    dex_list = []
    for z in z_list:
        total_qse = sums_qse[z]
        total_graph = sums_graph[z]
        if total_graph > 1e-13 and total_qse > 1e-13:
            offset = np.log10(total_qse / total_graph)
        elif z >= 14:
            # Disable these for visualization; their abundances are so small
            # (on the order of 1e-100) the offset is meaningless.
            offset = np.nan
        else:
            offset = 0.0
        dex_list.append(offset)
    # Build the symbol list aligned with the sorted z_list.  (The original
    # zipped a dict-insertion-ordered symbol list against the sorted z list,
    # which could mispair element labels with their offsets.)
    sorted_symbols = [symbols[z] for z in z_list]
    sorted_dex = dex_list
    print(sorted_symbols)
    print(sorted_dex)
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.bar(sorted_symbols, sorted_dex, color='grey', edgecolor='grey', alpha=0.8)
    ax.axhline(0, color='black', linewidth=0.8)  # clear baseline at 0 dex
    ax.set_xlabel('Element', fontsize=25)
    ax.set_ylabel('Offset [dex]', fontsize=25)
    if save_show == ShowSave.SAVE:
        plt.savefig("DexElementalOffset.pdf")
        plt.close()

    # --- Figure 3: energy-generation-rate offset [dex] over time ---
    e_graph = interp1d(df_graph.t, df_graph.eps)
    e_qse = interp1d(df_qse.t, df_qse.eps)
    dex_eps_diff = np.log10(e_graph(t)) - np.log10(e_qse(t))
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    ax.semilogx(t, dex_eps_diff, color='black')
    ax.set_xlabel("Time [s]", fontsize=25)
    # Fixed: the original called set_xlabel twice, so the y-axis was unlabeled
    # and the x-label was overwritten with "Offset [dex]".
    ax.set_ylabel("Offset [dex]", fontsize=25)
    if save_show == ShowSave.SAVE:
        plt.savefig("DexEpsOffset.pdf")
        plt.close()
    if save_show == ShowSave.SHOW:
        plt.show()

    print("=== QSE ===")
    print(temporal_err_qse)
    print(final_err_qse)
    print(f"Relative ε error: {qse_rel_eps_error}")
    print(f"Neutrino Loss Difference [dex]: {np.log10(r_graph.specific_neutrino_energy_loss) - np.log10(r_qse.specific_neutrino_energy_loss)}")
if __name__ == "__main__":
    import argparse

    # CLI: `-s SHOW` (default) displays figures; `-s SAVE` writes PDFs.
    parser = argparse.ArgumentParser(
        prog="Derivative Smoothness",
        description="Generate of view plots of derivative smoothness",
    )
    parser.add_argument(
        "-s",
        type=ShowSave,
        default=ShowSave.SHOW,
        choices=list(ShowSave),
        help="Whether to show or save the generated plot",
    )
    main(parser.parse_args().s)

View File

@@ -0,0 +1,227 @@
import numpy as np
from IPython.core.pylabtools import figsize
from gridfire.solver import PointSolver, PointSolverContext
from gridfire.policy import MainSequencePolicy
from scipy.signal import find_peaks
from gridfire.config import GridFireConfig
from fourdst.composition import Composition
from scipy.integrate import trapezoid
from fourdst.composition import CanonicalComposition
from fourdst.atomic import Species
from gridfire.type import NetIn
import matplotlib.pyplot as plt
## Note that my default style uses tex rendering. If you do not have tex installed
## simply comment out this line
plt.style.use("../utils/pub.mplstyle")
from scipy.interpolate import interp1d, CubicSpline
from enum import Enum
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
from logger import StepLogger
class ShowSave(Enum):
    """Output mode for generated figures: display interactively or save to disk."""

    SHOW = "SHOW"
    SAVE = "SAVE"

    def __str__(self) -> str:
        # argparse renders choices via str(); show the bare value, not "ShowSave.SHOW".
        return self.value
def rescale_composition(comp_ref: Composition, ZZs: float, Y_primordial: float = 0.248) -> Composition:
    """Rescale a reference composition to a new metallicity.

    The target metal mass fraction is Z_new = Z_ref * 10**ZZs.  Helium is scaled
    along a linear enrichment law dY/dZ anchored at the primordial helium mass
    fraction, and hydrogen absorbs the remainder so the bulk mass fractions
    still sum to one.

    Args:
        comp_ref: Reference composition to rescale.
        ZZs: Metallicity offset in dex relative to the reference.
        Y_primordial: Primordial helium mass fraction anchoring the dY/dZ law.

    Returns:
        A new Composition with every species' mass fraction scaled by its
        element class (H, He, or metal) and re-registered by molar abundance.

    Raises:
        ValueError: If the requested ZZs drives the hydrogen mass fraction negative.
    """
    CC: CanonicalComposition = comp_ref.getCanonicalComposition()
    dY_dZ = (CC.Y - Y_primordial) / CC.Z
    Z_new = CC.Z * (10**ZZs)
    Y_bulk_new = Y_primordial + (dY_dZ * Z_new)
    X_new = 1.0 - Z_new - Y_bulk_new
    if X_new < 0:
        raise ValueError(f"ZZs={ZZs} yields unphysical composition (X < 0)")
    # Per-class scale factors; guard against a zero reference bulk fraction.
    ratio_H = X_new / CC.X if CC.X > 0 else 0
    ratio_He = Y_bulk_new / CC.Y if CC.Y > 0 else 0
    ratio_Z = Z_new / CC.Z if CC.Z > 0 else 0
    # (removed unused local Y_new_list from the original)
    newComp: Composition = Composition()
    s: Species
    for s in comp_ref.getRegisteredSpecies():
        Xi_ref = comp_ref.getMassFraction(s)
        if s.el() == "H":
            Xi_new = Xi_ref * ratio_H
        elif s.el() == "He":
            Xi_new = Xi_ref * ratio_He
        else:
            Xi_new = Xi_ref * ratio_Z
        # Convert mass fraction back to molar abundance: Y_i = X_i / A_i.
        Y = Xi_new / s.mass()
        newComp.registerSpecies(s)
        newComp.setMolarAbundance(s, Y)
    return newComp
def init_composition(ZZs: float = 0) -> Composition:
    """Build the fiducial solar-scaled composition, optionally shifted by ZZs dex.

    NOTE(review): the value vector is named like molar abundances (Y_solar) but
    is passed straight to Composition(S, ...) — confirm which convention that
    constructor expects.  Also note this file's He-3 value (9.7479E-06) differs
    from the sibling analysis script's (1.7479E-06) — confirm which is intended.
    """
    species_names = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
    solar_values = [7.0262E-01, 9.7479E-06, 6.8955E-02, 2.5000E-04,
                    7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
    return rescale_composition(Composition(species_names, solar_values), ZZs)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
    """Assemble a NetIn record for a single-zone burn.

    Args:
        temp: Temperature (presumably K — confirm against the solver's units).
        rho:  Density.
        time: Total integration time [s].
        comp: Initial composition.
    """
    net_input: NetIn = NetIn()
    net_input.temperature = temp
    net_input.density = rho
    net_input.tMax = time
    # Tiny initial step lets the stiff integrator ramp its own step size up.
    net_input.dt0 = 1e-12
    net_input.composition = comp
    return net_input
def years_to_seconds(years: float) -> float:
    """Convert years to seconds using the 365-day year (3.1536e7 s/yr)."""
    seconds_per_year = 3.1536e7
    return seconds_per_year * years
def main(save_show):
    """Evolve a main-sequence policy network, plot abundance and energy
    smoothness, then quantify how much a single discontinuous eps sample
    (index `idx`) affects the integrated energy.

    Args:
        save_show: ShowSave.SHOW displays the figure; ShowSave.SAVE writes
            smoothness_plot.pdf.
    """
    C = init_composition()
    netIn = init_netIn(1.5e7, 160, years_to_seconds(10e9), C)
    policy = MainSequencePolicy(C)
    construct = policy.construct()
    # 3e-8 and 1e-24 are the default tolerances we adopt as testing indicates
    # they work well for main sequence evolution.  We encourage researchers to
    # trial various relative and absolute thresholds:
    # config = GridFireConfig()
    # config.solver.pointSolver.trigger.boundaryFlux.relativeThreshold = 3e-8
    # config.solver.pointSolver.trigger.boundaryFlux.absoluteThreshold = 1e-24
    # solver = PointSolver(construct.engine, config)
    solver = PointSolver(construct.engine)
    solver_ctx = PointSolverContext(construct.scratch_blob)
    stepLogger = StepLogger()
    solver_ctx.callback = lambda ctx: stepLogger.log_step(ctx)  # dropped stray ';'
    solver.evaluate(solver_ctx, netIn, False, False)
    df = stepLogger.df

    # --- Smoothness figure: Y(t) and dY/dt, with eps overplotted ---
    fig, axs = plt.subplots(2, 1, figsize=(17, 10))
    t = np.linspace(df.t.min(), df.t.max(), 1000)
    # Ne-20 is not plotted: its molar abundance sits so close to N-14 that it
    # makes that species hard to distinguish.
    PlottingSpecies = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Mg-24"]
    stable_index = 10  # skip the first few (transient) logged steps
    for sp in PlottingSpecies:
        x = df.t[stable_index:]
        y = df[sp][stable_index:]
        axs[0].loglog(x, y)
        axs[1].semilogx(x, np.gradient(y, x))
        axs[0].text(x.iloc[0], y.iloc[0]*1.1, sp, fontsize=12)
    axs[0].set_ylabel("$Y$ [mol/g]", fontsize=23)
    axs[1].set_ylabel(r"$\frac{dY}{dt}$ [mol/g/s]", fontsize=23)
    axs[1].set_xlabel("Time [s]")
    ax_eps = axs[0].twinx()
    ax_deps = axs[1].twinx()
    ax_eps.set_ylabel(r"$\epsilon$ [erg/g/s]", rotation=270, labelpad=25, fontsize=23)
    ax_deps.set_ylabel(r"$\frac{d\epsilon}{dt}$ [erg/g/s$^2$]", rotation=270, labelpad=25, fontsize=23)
    ax_eps.axvline(1.008e+15, color='grey', linestyle='dashed')
    ax_deps.axvline(1.008e+15, color='grey', linestyle='dashed')
    ax_eps.loglog(df.t[stable_index:], df.eps[stable_index:], color='red', linestyle='dashed')
    ax_eps.text(df.t[stable_index:].iloc[0]*1.05, df.eps[stable_index:].iloc[0]*3, r"$\epsilon$", rotation=25, fontsize=20)
    ax_deps.semilogx(df.t[stable_index:], np.gradient(df.eps[stable_index:], df.t[stable_index:]), color='red', linestyle='dashed')
    if save_show == ShowSave.SHOW:
        plt.show()
    else:
        plt.savefig("smoothness_plot.pdf")
        plt.close()

    # --- Quantify the impact of one discontinuous eps sample ---
    t = df.t.values
    eps = df.eps.values
    # Use this plot to determine the index to test removal of:
    # fig, ax = plt.subplots(1, 1, figsize=(10, 7))
    # ax.plot(np.gradient(eps, t))
    # ax.grid()
    # plt.show()
    idx = 156  # sample identified (from the plot above) as discontinuous

    # Relative error in the integrated d(eps)/dt with vs. without sample idx.
    t2 = np.delete(t, idx)
    eps2 = np.delete(eps, idx)
    f_deps_1 = interp1d(t, np.gradient(eps, t))
    f_deps_2 = interp1d(t2, np.gradient(eps2, t2))
    int_deps_1 = trapezoid(f_deps_1(t), t)
    int_deps_2 = trapezoid(f_deps_2(t), t)
    rel_err = (int_deps_1 - int_deps_2) / int_deps_2
    print(f"Rel Error: {rel_err:+0.3E}")

    # Local jump size: fit a spline through a window around idx (excluding idx
    # itself) and compare its prediction with the actual sample.
    window = 10
    indices = np.arange(idx - window, idx + window + 1)
    indices_no_gap = np.delete(indices, window)  # drop idx from the window
    spline = CubicSpline(t[indices_no_gap], eps[indices_no_gap])
    eps_predicted = spline(t[idx])
    eps_actual = eps[idx]
    relative_jump = np.abs(eps_actual - eps_predicted) / eps_actual
    print(f"Local Discontinuity at index {idx}: {relative_jump:.3%}")

    # Total energy error if the discontinuous sample is replaced by a smooth
    # spline prediction built from all other samples.
    E_actual = trapezoid(eps, t)
    spline = CubicSpline(np.delete(t, idx), np.delete(eps, idx))
    eps_smooth = np.copy(eps)
    eps_smooth[idx] = spline(t[idx])
    E_smooth = trapezoid(eps_smooth, t)
    total_rel_error = (E_actual - E_smooth) / E_smooth
    print(f"Total Relative Energy Error: {total_rel_error:+0.12E}")
if __name__ == "__main__":
    import argparse

    # CLI: `-s SHOW` (default) displays figures; `-s SAVE` writes PDFs.
    parser = argparse.ArgumentParser(
        prog="Derivative Smoothness",
        description="Generate of view plots of derivative smoothness",
    )
    parser.add_argument(
        "-s",
        type=ShowSave,
        default=ShowSave.SHOW,
        choices=list(ShowSave),
        help="Whether to show or save the generated plot",
    )
    main(parser.parse_args().s)

View File

@@ -0,0 +1,80 @@
from enum import Enum
from typing import Dict, List, Any, SupportsFloat
import json
from datetime import datetime
import os
import sys
from gridfire.solver import PointSolverTimestepContext
from gridfire._gridfire.engine.scratchpads import StateBlob
import gridfire
class LogEntries(Enum):
    """Keys used for the per-step records assembled by StepLogger."""

    Step = "Step"
    t = "t"
    dt = "dt"
    eps = "eps"
    Composition = "Composition"
    ReactionContributions = "ReactionContributions"
class StepLogger:
def __init__(self):
    """Start with an empty log."""
    # Running count of logged timesteps; mirrors len(self.steps).
    self.num_steps: int = 0
    # One record per timestep, keyed by LogEntries.
    self.steps: List[Dict[LogEntries, Any]] = []
def log_step(self, ctx: PointSolverTimestepContext):
    """Record one solver timestep: step count, time, dt, eps, and a snapshot
    of every network species' state value keyed by species name."""
    abundances: Dict[str, SupportsFloat] = {}
    for species in ctx.engine.getNetworkSpecies(ctx.state_ctx):
        idx = ctx.engine.getSpeciesIndex(ctx.state_ctx, species)
        abundances[species.name()] = ctx.state[idx]
    # NOTE(review): the last state entry is logged as eps — confirm that the
    # solver state vector always places energy generation in the final slot.
    self.steps.append({
        LogEntries.Step: ctx.num_steps,
        LogEntries.t: ctx.t,
        LogEntries.dt: ctx.dt,
        LogEntries.eps: ctx.state[-1],
        LogEntries.Composition: abundances,
    })
    self.num_steps += 1
def to_json(self, filename: str, **kwargs):
    """Dump all logged steps, plus provenance metadata, to a JSON file.

    Any extra keyword arguments are merged into the Metadata section (and may
    override the default entries listed after them).
    """
    # Re-key each step with the enum *values* so the JSON keys are plain strings.
    steps_as_json: List[Dict[str, Any]] = []
    for step in self.steps:
        steps_as_json.append({
            LogEntries.Step.value: step[LogEntries.Step],
            LogEntries.t.value: step[LogEntries.t],
            LogEntries.dt.value: step[LogEntries.dt],
            LogEntries.eps.value: step[LogEntries.eps],
            LogEntries.Composition.value: step[LogEntries.Composition],
        })
    metadata: Dict[str, Any] = {
        "NumSteps": self.num_steps,
        **kwargs,
        "DateCreated": datetime.now().isoformat(),
        "GridFireVersion": gridfire.__version__,
        "Author": "Emily M. Boudreaux",
        "OS": os.uname().sysname,
        # Shell out for toolchain provenance; empty if the compiler is absent.
        "ClangVersion": os.popen("clang --version").read().strip(),
        "GccVersion": os.popen("gcc --version").read().strip(),
        "PythonVersion": sys.version,
    }
    with open(filename, 'w') as f:
        json.dump({"Metadata": metadata, "Steps": steps_as_json}, f, indent=4)
def summary(self) -> Dict[str, Any]:
    """Return total step count plus the final logged time and composition,
    or an empty dict if nothing has been logged."""
    if not self.steps:
        return {}
    last = self.steps[-1]
    return {
        "TotalSteps": self.num_steps,
        "FinalTime": last[LogEntries.t],
        "FinalComposition": last[LogEntries.Composition],
    }