test(validation): Added validation suite

Added the framework and some basic tests for a validation suite which
automatically tests against pynucastro results
This commit is contained in:
2025-11-27 10:05:51 -05:00
parent b7f8724e13
commit ef53575c0d
4 changed files with 476 additions and 0 deletions

View File

@@ -0,0 +1,102 @@
from gridfire.policy import MainSequencePolicy, NetworkPolicy
from gridfire.engine import DynamicEngine, GraphEngine
from gridfire.type import NetIn
from fourdst.composition import Composition
from testsuite import TestSuite
from utils import init_netIn, init_composition, years_to_seconds
from enum import Enum
class SolarLikeStar_QSE_Suite(TestSuite):
def __init__(self):
initialComposition : Composition = init_composition()
super().__init__(
name="SolarLikeStar_QSE",
description="GridFire simulation of a roughly solar like star over 10Gyr with QSE enabled.",
temp=1.5e7,
density=1.5e2,
tMax=years_to_seconds(1e10),
composition=initialComposition,
notes="Thermodynamically Static, MultiscalePartitioning Engine View"
)
def __call__(self):
policy : MainSequencePolicy = MainSequencePolicy(self.composition)
engine : DynamicEngine = policy.construct()
netIn : NetIn = init_netIn(self.temperature, self.density, self.tMax, self.composition)
self.evolve(engine, netIn)
class MetalEnhancedSolarLikeStar_QSE_Suite(TestSuite):
def __init__(self):
initialComposition : Composition = init_composition(ZZs=1)
super().__init__(
name="MetalEnhancedSolarLikeStar_QSE",
description="GridFire simulation of a star with solar core temp and density but enhanced by 1 dex in Z.",
temp=0.8 * 1.5e7,
density=1.5e2,
tMax=years_to_seconds(1e10),
composition=initialComposition,
notes="Thermodynamically Static, MultiscalePartitioning Engine View, Z enhanced by 1 dex, temperature reduced to 80% of solar core"
)
def __call__(self):
policy : MainSequencePolicy = MainSequencePolicy(self.composition)
engine : GraphEngine = policy.construct()
netIn : NetIn = init_netIn(self.temperature, self.density, self.tMax, self.composition)
self.evolve(engine, netIn)
class MetalDepletedSolarLikeStar_QSE_Suite(TestSuite):
def __init__(self):
initialComposition : Composition = init_composition(ZZs=-1)
super().__init__(
name="MetalDepletedSolarLikeStar_QSE",
description="GridFire simulation of a star with solar core temp and density but depleted by 1 dex in Z.",
temp=1.2 * 1.5e7,
density=1.5e2,
tMax=years_to_seconds(1e10),
composition=initialComposition,
notes="Thermodynamically Static, MultiscalePartitioning Engine View, Z depleted by 1 dex, temperature increased to 120% of solar core"
)
def __call__(self):
policy : MainSequencePolicy = MainSequencePolicy(self.composition)
engine : GraphEngine = policy.construct()
netIn : NetIn = init_netIn(self.temperature, self.density, self.tMax, self.composition)
self.evolve(engine, netIn)
class SolarLikeStar_No_QSE_Suite(TestSuite):
def __init__(self):
initialComposition : Composition = init_composition()
super().__init__(
name="SolarLikeStar_No_QSE",
description="GridFire simulation of a roughly solar like star over 10Gyr with QSE disabled.",
temp=1.5e7,
density=1.5e2,
tMax=years_to_seconds(1e10),
composition=initialComposition,
notes="Thermodynamically Static, No MultiscalePartitioning Engine View"
)
def __call__(self):
engine : GraphEngine = GraphEngine(self.composition, 3)
netIn : NetIn = init_netIn(self.temperature, self.density, self.tMax, self.composition)
self.evolve(engine, netIn)
class ValidationSuites(Enum):
SolarLikeStar_QSE = SolarLikeStar_QSE_Suite
SolarLikeStar_No_QSE = SolarLikeStar_No_QSE_Suite
MetalDepletedSolarLikeStar_QSE = MetalDepletedSolarLikeStar_QSE_Suite
MetalEnhancedSolarLikeStar_QSE = MetalEnhancedSolarLikeStar_QSE_Suite
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run some subset of the GridFire validation suite.")
parser.add_argument('--suite', type=str, choices=[suite.name for suite in ValidationSuites], nargs="+", help="The validation suite to run.")
args = parser.parse_args()
for suite_name in args.suite:
suite = ValidationSuites[suite_name]
instance : TestSuite = suite.value()
instance()

77
validation/vv/logger.py Normal file
View File

@@ -0,0 +1,77 @@
from enum import Enum
from typing import Dict, List, Any, SupportsFloat
import json
from datetime import datetime
import os
import sys
from gridfire.solver import CVODETimestepContext
import gridfire
class LogEntries(Enum):
Step = "Step"
t = "t"
dt = "dt"
eps = "eps"
Composition = "Composition"
ReactionContributions = "ReactionContributions"
class StepLogger:
def __init__(self):
self.num_steps : int = 0
self.steps : List[Dict[LogEntries, Any]] = []
def log_step(self, ctx : CVODETimestepContext):
comp_data: Dict[str, SupportsFloat] = {}
for species in ctx.engine.getNetworkSpecies():
sid = ctx.engine.getSpeciesIndex(species)
comp_data[species.name()] = ctx.state[sid]
entry : Dict[LogEntries, Any] = {
LogEntries.Step: ctx.num_steps,
LogEntries.t: ctx.t,
LogEntries.dt: ctx.dt,
LogEntries.Composition: comp_data,
}
self.steps.append(entry)
self.num_steps += 1
def to_json(self, filename: str, **kwargs):
serializable_steps : List[Dict[str, Any]] = [
{
LogEntries.Step.value: step[LogEntries.Step],
LogEntries.t.value: step[LogEntries.t],
LogEntries.dt.value: step[LogEntries.dt],
LogEntries.Composition.value: step[LogEntries.Composition],
}
for step in self.steps
]
out_data : Dict[str, Any] = {
"Metadata": {
"NumSteps": self.num_steps,
**kwargs,
"DateCreated": datetime.now().isoformat(),
"GridFireVersion": gridfire.__version__,
"Author": "Emily M. Boudreaux",
"OS": os.uname().sysname,
"ClangVersion": os.popen("clang --version").read().strip(),
"GccVersion": os.popen("gcc --version").read().strip(),
"PythonVersion": sys.version,
},
"Steps": serializable_steps
}
with open(filename, 'w') as f:
json.dump(out_data, f, indent=4)
def summary(self) -> Dict[str, Any]:
if not self.steps:
return {}
final_step = self.steps[-1]
summary_data : Dict[str, Any] = {
"TotalSteps": self.num_steps,
"FinalTime": final_step[LogEntries.t],
"FinalComposition": final_step[LogEntries.Composition],
}
return summary_data

241
validation/vv/testsuite.py Normal file
View File

@@ -0,0 +1,241 @@
from abc import ABC, abstractmethod
import fourdst.atomic
import scipy.integrate
from fourdst.composition import Composition
from gridfire.engine import DynamicEngine, GraphEngine
from gridfire.type import NetIn, NetOut
from gridfire.exceptions import GridFireError
from gridfire.solver import CVODESolverStrategy
from logger import StepLogger
from typing import List
import re
from typing import Dict, Tuple, Any, Union
from datetime import datetime
import pynucastro as pyna
import os
import importlib.util
import sys
import numpy as np
import json
import time
def load_network_module(filepath):
module_name = os.path.basename(filepath).replace(".py", "")
if module_name in sys.modules: # clear any existing module with the same name
del sys.modules[module_name]
spec = importlib.util.spec_from_file_location(module_name, filepath)
if spec is None:
raise FileNotFoundError(f"Could not find module at {filepath}")
network_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = network_module
spec.loader.exec_module(network_module)
return network_module
def get_pyna_rate(my_rate_str, library):
match = re.match(r"([a-zA-Z0-9]+)\(([^,]+),([^)]*)\)(.*)", my_rate_str)
if not match:
print(f"Could not parse string format: {my_rate_str}")
return None
target = match.group(1)
projectile = match.group(2)
ejectiles = match.group(3)
product = match.group(4)
def expand_species(s_str):
if not s_str or s_str.strip() == "":
return []
# Split by space (handling "p a" or "2p a")
parts = s_str.split()
expanded = []
for p in parts:
# Check for multipliers like 2p, 3a
mult_match = re.match(r"(\d+)([a-zA-Z0-9]+)", p)
if mult_match:
count = int(mult_match.group(1))
spec = mult_match.group(2)
# Map common aliases if necessary (though pyna handles most)
if spec == 'a': spec = 'he4'
expanded.extend([spec] * count)
else:
spec = p
if spec == 'a': spec = 'he4'
expanded.append(spec)
return expanded
reactants_str = [target] + expand_species(projectile)
products_str = expand_species(ejectiles) + [product]
# Convert strings to pyna.Nucleus objects
try:
r_nuc = [pyna.Nucleus(r) for r in reactants_str]
p_nuc = [pyna.Nucleus(p) for p in products_str]
except Exception as e:
print(f"Error converting nuclei for {my_rate_str}: {e}")
return None
rates = library.get_rate_by_nuclei(r_nuc, p_nuc)
if rates:
if isinstance(rates, list):
return rates[0] # Return the first match
return rates
else:
return None
class TestSuite(ABC):
def __init__(self, name: str, description: str, temp: float, density: float, tMax: float, composition: Composition, notes: str = ""):
self.name : str = name
self.description : str = description
self.temperature : float = temp
self.density : float = density
self.tMax : float = tMax
self.composition : Composition = composition
self.notes : str = notes
def evolve_pynucastro(self, engine: GraphEngine):
print("Evolution complete. Now building equivalent pynucastro network...")
# Build equivalent pynucastro network for comparison
reaclib_library : pyna.ReacLibLibrary = pyna.ReacLibLibrary()
rate_names = [r.id().replace("e+","").replace("e-","").replace(", ", ",") for r in engine.getNetworkReactions()]
goodRates : List[pyna.rates.reaclib_rate.ReacLibRate] = []
missingRates = []
for r_str in rate_names:
# Try the exact name match first (fastest)
try:
pyna_rate = reaclib_library.get_rate_by_name(r_str)
if isinstance(pyna_rate, list):
goodRates.append(pyna_rate[0])
else:
goodRates.append(pyna_rate)
except:
# Fallback to the smart parser
pyna_rate = get_pyna_rate(r_str, reaclib_library)
if pyna_rate:
goodRates.append(pyna_rate)
else:
missingRates.append(r_str)
pynet : pyna.PythonNetwork = pyna.PythonNetwork(rates=goodRates)
pynet.write_network(f"{self.name}_pynucastro_network.py")
net = load_network_module(f"{self.name}_pynucastro_network.py")
Y0 = np.zeros(net.nnuc)
Y0[net.jp] = self.composition.getMolarAbundance("H-1")
Y0[net.jhe3] = self.composition.getMolarAbundance("He-3")
Y0[net.jhe4] = self.composition.getMolarAbundance("He-4")
Y0[net.jc12] = self.composition.getMolarAbundance("C-12")
Y0[net.jn14] = self.composition.getMolarAbundance("N-14")
Y0[net.jo16] = self.composition.getMolarAbundance("O-16")
Y0[net.jne20] = self.composition.getMolarAbundance("Ne-20")
Y0[net.jmg24] = self.composition.getMolarAbundance("Mg-24")
print("Starting pynucastro integration...")
startTime = time.time()
sol = scipy.integrate.solve_ivp(
net.rhs,
[0, self.tMax],
Y0,
args=(self.density, self.temperature),
method="BDF",
jac=net.jacobian,
rtol=1e-5,
atol=1e-8
)
endTime = time.time()
print("Pynucastro integration complete. Writing results to JSON...")
data: List[Dict[str, Union[float, Dict[str, float]]]] = []
for time_step, t in enumerate(sol.t):
data.append({"t": t, "Composition": {}})
for j in range(net.nnuc):
A = net.A[j]
Z = net.Z[j]
species: str
try:
species = fourdst.atomic.az_to_species(A, Z).name()
except:
species = f"SP-A_{A}_Z_{Z}"
data[-1]["Composition"][species] = sol.y[j, time_step]
pynucastro_json : Dict[str, Any] = {
"Metadata": {
"Name": f"{self.name}_pynucastro",
"Description": f"pynucastro simulation equivalent to GridFire validation suite: {self.description}",
"Status": "Success",
"Notes": self.notes,
"Temperature": self.temperature,
"Density": self.density,
"tMax": self.tMax,
"ElapsedTime": endTime - startTime,
"DateCreated": datetime.now().isoformat()
},
"Steps": data
}
with open(f"GridFireValidationSuite_{self.name}_pynucastro.json", "w") as f:
json.dump(pynucastro_json, f, indent=4)
def evolve(self, engine: GraphEngine, netIn: NetIn, pynucastro_compare: bool = True):
solver : CVODESolverStrategy = CVODESolverStrategy(engine)
stepLogger : StepLogger = StepLogger()
solver.set_callback(lambda ctx: stepLogger.log_step(ctx))
startTime = time.time()
try:
startTime = time.time()
netOut : NetOut = solver.evaluate(netIn)
endTime = time.time()
stepLogger.to_json(
f"GridFireValidationSuite_{self.name}_OKAY.json",
Name = f"{self.name}_Success",
Description=self.description,
Status="Success",
Notes=self.notes,
Temperature=netIn.temperature,
Density=netIn.density,
tMax=netIn.tMax,
FinalEps = netOut.energy,
FinaldEpsdT = netOut.dEps_dT,
FinaldEpsdRho = netOut.dEps_dRho,
ElapsedTime = endTime - startTime
)
except GridFireError as e:
endTime = time.time()
stepLogger.to_json(
f"GridFireValidationSuite_{self.name}_FAIL.json",
Name = f"{self.name}_Failure",
Description=self.description,
Status=f"Error",
ErrorMessage=str(e),
Notes=self.notes,
Temperature=netIn.temperature,
Density=netIn.density,
tMax=netIn.tMax,
ElapsedTime = endTime - startTime
)
if pynucastro_compare:
self.evolve_pynucastro(engine)
@abstractmethod
def __call__(self):
pass

56
validation/vv/utils.py Normal file
View File

@@ -0,0 +1,56 @@
from fourdst.composition import Composition
from fourdst.composition import CanonicalComposition
from fourdst.atomic import Species
from gridfire.type import NetIn
def rescale_composition(comp_ref : Composition, ZZs : float, Y_primordial : float = 0.248) -> Composition:
CC : CanonicalComposition = comp_ref.getCanonicalComposition()
dY_dZ = (CC.Y - Y_primordial) / CC.Z
Z_new = CC.Z * (10**ZZs)
Y_bulk_new = Y_primordial + (dY_dZ * Z_new)
X_new = 1.0 - Z_new - Y_bulk_new
if X_new < 0: raise ValueError(f"ZZs={ZZs} yields unphysical composition (X < 0)")
ratio_H = X_new / CC.X if CC.X > 0 else 0
ratio_He = Y_bulk_new / CC.Y if CC.Y > 0 else 0
ratio_Z = Z_new / CC.Z if CC.Z > 0 else 0
Y_new_list = []
newComp : Composition = Composition()
s: Species
for s in comp_ref.getRegisteredSpecies():
Xi_ref = comp_ref.getMassFraction(s)
if s.el() == "H":
Xi_new = Xi_ref * ratio_H
elif s.el() == "He":
Xi_new = Xi_ref * ratio_He
else:
Xi_new = Xi_ref * ratio_Z
Y = Xi_new / s.mass()
newComp.registerSpecies(s)
newComp.setMolarAbundance(s, Y)
return newComp
def init_composition(ZZs : float = 0) -> Composition:
Y_solar = [7.0262E-01, 9.7479E-06, 6.8955E-02, 2.5000E-04, 7.8554E-05, 6.0144E-04, 8.1031E-05, 2.1513E-05]
S = ["H-1", "He-3", "He-4", "C-12", "N-14", "O-16", "Ne-20", "Mg-24"]
return rescale_composition(Composition(S, Y_solar), ZZs)
def init_netIn(temp: float, rho: float, time: float, comp: Composition) -> NetIn:
n : NetIn = NetIn()
n.temperature = temp
n.density = rho
n.tMax = time
n.dt0 = 1e-12
n.composition = comp
return n
def years_to_seconds(years: float) -> float:
return years * 3.1536e7