7b_fit_dl_parameters.py — Fit L and R Parameters
This script fits the two disorder-linewidth parameters (mirrors Notebook 6):
L [nm] — grain-boundary mean free path (Casimir scattering length)
R [cm THz nm³] — defect scattering amplitude (Rayleigh-type)
using PyTorch's L-BFGS optimiser with a strong-Wolfe line search (gradients via autodiff).
Run after 7a_workflow_precompute_quantities.py.
The fitted parameters are saved to dl_workflow/model_parameters.hdf5
(keys: final_loss, final_model_params).
Usage:

    cd tutorials/disorder_linewidth
    python 7b_fit_dl_parameters.py
1import sys
2
3import h5py
4import numpy as np
5import matplotlib
6import matplotlib.pyplot as plt
7from ase.io import read
8from scipy.interpolate import interp1d
9
10from tqdm import tqdm
11
12from smooth_disorder.structural import obtain_density, THzToCm, THz, Angstrom
13from smooth_disorder.vis.interactive import *
14
15from smooth_disorder.disorder_linewidth import lorentzian_numpy, lorentzian_torch
16from smooth_disorder.disorder_linewidth import prepare_fitting_inputs
17from smooth_disorder.disorder_linewidth import evaluate_linewidth_and_model_prediction
18from smooth_disorder.disorder_linewidth import PDCModel
19
20import torch
# Input structures: the crystalline reference and the disordered sample.
# All paths are relative to the current working directory.
CRYSTAL_POSCAR = "./1_graphite/POSCAR"
DISORDERED_POSCAR = "./2_irg_t2/irg_t2_14009.vasp"

# Working directory holding the precomputed inputs and the fit output.
WORK_DIR = "./dl_workflow"

# Precomputed data locations. DISORDERED_VDOS_SAVE and SHIFTED_SAVE are passed
# to prepare_fitting_inputs below; CRYSTAL_VEL_SAVE is not referenced by this
# script and is kept only for parity with the rest of the workflow.
CRYSTAL_VEL_SAVE = f"{WORK_DIR}/crystal_vdos_group_vel"
DISORDERED_VDOS_SAVE = f"{WORK_DIR}/disordered_vdos"
SHIFTED_SAVE = f"{WORK_DIR}/reduced_density_crystal_vdos_group_vel"

# Output location for the fitted parameters, without extension
# (".hdf5" is appended when the file is written).
MODEL_PARAMETERS_SAVE = f"{WORK_DIR}/model_parameters"

# Initial guesses for the fit: L0 in nm, R0 in units of 1e-6 THz cm nm^3.
# NOTE(review): the module header quotes R in plain cm THz nm^3; confirm that
# PDCModel applies the 1e-6 scale factor internally.
L0, R0 = 3.3, 5.54


# Load the densities, the disordered VDOS, and the density-shifted crystal
# frequency / VDOS / speed arrays needed by the linewidth model.
(
    density_crystal,
    density_disordered,
    freq_disordered,
    vdos_disordered,
    interp_shifted_freq_crystal,
    interp_shifted_vdos_crystal,
    interp_shifted_speed_crystal,
) = prepare_fitting_inputs(
    CRYSTAL_POSCAR,
    DISORDERED_POSCAR,
    DISORDERED_VDOS_SAVE,
    SHIFTED_SAVE,
)

# Loss tensors: each VDOS is divided by the density of its own structure.
# X is the model input (crystal), Y the fitting target (disordered).
X = torch.from_numpy(interp_shifted_vdos_crystal / density_crystal)
Y = torch.from_numpy(vdos_disordered / density_disordered)


# Linewidth model, initialised at the (L0, R0) starting guess and built from
# the density-shifted crystal quantities loaded above.
model = PDCModel(
    L0,
    R0,
    density_crystal,
    density_disordered,
    freq_disordered,
    interp_shifted_freq_crystal,
    interp_shifted_vdos_crystal,
    interp_shifted_speed_crystal,
)

# L-BFGS with a strong-Wolfe line search. One optim.step(closure) call runs up
# to max_iter iterations; the closure may be evaluated more often than that
# because of the line search (max_eval defaults to max_iter * 5 // 4).
optim = torch.optim.LBFGS(
    model.parameters(),
    line_search_fn="strong_wolfe",
    max_iter=50,
    lr=1.0,
)

# Mean-squared error between the model prediction and the target Y.
loss_fn = torch.nn.MSELoss()

# One entry per closure evaluation: the loss value and the parameter vector.
losses = []
model_parameters_history = []


# Objective closure for L-BFGS; the optimisation itself runs inside optim.step.
def closure():
    """Recompute the fitting loss and its gradient for ``torch.optim.LBFGS``.

    L-BFGS calls this several times per ``optim.step`` -- including at
    line-search trial points it may later reject -- so every call prints and
    records the loss and parameters it was evaluated at.

    Returns:
        The scalar loss tensor; gradients have been accumulated on the model
        parameters by ``backward()``.
    """
    optim.zero_grad()
    prediction = model(X)
    current_loss = loss_fn(prediction, Y)
    current_loss.backward()

    # Take NumPy copies: .numpy() shares storage with the tensor, so without
    # .copy() later in-place parameter updates would rewrite the history.
    loss_snapshot = current_loss.detach().cpu().numpy().copy()
    params_snapshot = model.model_params.detach().cpu().numpy().copy()

    print(loss_snapshot, params_snapshot)
    losses.append(loss_snapshot)
    model_parameters_history.append(params_snapshot)
    return current_loss


# Run the whole L-BFGS optimisation in a single step() call.
# NOTE: step() returns the loss of its FIRST closure evaluation, i.e. the loss
# at the initial parameters -- not the fitted loss.
loss = optim.step(closure)

# With the strong-Wolfe line search, the closure is also evaluated at trial
# points that can be rejected, and LBFGS restores the parameters after each
# trial. The last recorded history entry is therefore not guaranteed to match
# the accepted parameters now held by the model. Evaluate once more at the
# final parameters so losses[-1] / model_parameters_history[-1] describe the
# fitted model.
closure()
final_loss = losses[-1]
final_model_params = model_parameters_history[-1]

# Persist the fit result (keys: final_loss, final_model_params).
compression = "gzip"
with h5py.File(f"{MODEL_PARAMETERS_SAVE}.hdf5", "w") as h5_out:
    h5_out.create_dataset("final_loss", data=np.array([final_loss]), compression=compression)
    h5_out.create_dataset("final_model_params", data=final_model_params, compression=compression)