2b_DL_fit_params.py — Fit L and R Parameters
Fits the two parameters of the disorder-linewidth model (mirrors Notebook 6):
- L [nm] — grain-boundary size (Casimir scattering length)
- R [1e-6 THz cm nm³] — defect scattering amplitude (Rayleigh-type)
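The fit uses L-BFGS with gradients from PyTorch autodiff. The line search is set by
config.LINE_SEARCH_FN (e.g. strong Wolfe) and the learning rate by config.LR.
Starting values are taken from config.L0 and config.R0, and the optimizer runs for
up to config.MAX_ITER iterations.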
The fitted parameters are saved to dl_workflow/model_parameters.hdf5, i.e. to
model_parameters.hdf5 inside config.DL_WORK_DIR. The HDF5 datasets are
final_loss and final_model_params.
Run after 2a_DL_workflow_precompute.py:

    cd workflows
    python 2b_DL_fit_params.py
1import sys
2import pathlib
3sys.path.insert(0, str(pathlib.Path(__file__).parent))
4
5import h5py
6import numpy as np
7import matplotlib
8import matplotlib.pyplot as plt
9from ase.io import read
10from scipy.interpolate import interp1d
11
12from tqdm import tqdm
13
14from smooth_disorder.structural import obtain_density, THzToCm, THz, Angstrom
15from smooth_disorder.vis.interactive import *
16
17from smooth_disorder.disorder_linewidth import lorentzian_numpy, lorentzian_torch
18from smooth_disorder.disorder_linewidth import prepare_fitting_inputs
19from smooth_disorder.disorder_linewidth import evaluate_linewidth_and_model_prediction
20from smooth_disorder.disorder_linewidth import PDCModel
21
22import torch
23
24from config import (
25 CRYSTAL_POSCAR, DISORDERED_POSCAR,
26 DL_WORK_DIR,
27 L0, R0, LR, MAX_ITER, LINE_SEARCH_FN,
28)
29
# Every file read or written by this script lives under config.DL_WORK_DIR.
WORK_DIR = DL_WORK_DIR

# Precomputed inputs (presumably written by 2a_DL_workflow_precompute.py).
# NOTE(review): CRYSTAL_VEL_SAVE is not read anywhere in this script.
CRYSTAL_VEL_SAVE = "{}/crystal_vdos_group_vel".format(WORK_DIR)
DISORDERED_VDOS_SAVE = "{}/disordered_vdos".format(WORK_DIR)
SHIFTED_SAVE = "{}/reduced_density_crystal_vdos_group_vel".format(WORK_DIR)

# Output path prefix; ".hdf5" is appended when the results file is written.
MODEL_PARAMETERS_SAVE = "{}/model_parameters".format(WORK_DIR)

# Load the crystal / disordered inputs needed for the fit and build the
# torch tensors used by the loss.
[
    density_crystal,
    density_disordered,
    freq_disordered,
    vdos_disordered,
    interp_shifted_freq_crystal,
    interp_shifted_vdos_crystal,
    interp_shifted_speed_crystal,
] = prepare_fitting_inputs(
    CRYSTAL_POSCAR,
    DISORDERED_POSCAR,
    DISORDERED_VDOS_SAVE,
    SHIFTED_SAVE,
)

# Each VDOS is divided by the density of its own structure (as returned
# above). X is the model input (shifted crystal VDOS); Y is the target
# (disordered VDOS) that the model prediction is compared against.
X = torch.from_numpy(interp_shifted_vdos_crystal / density_crystal)
Y = torch.from_numpy(vdos_disordered / density_disordered)

# Build the disorder-linewidth model, seeded with the starting guesses
# config.L0 / config.R0, and the L-BFGS optimizer that fits its parameters.
model = PDCModel(
    L0,
    R0,
    density_crystal,
    density_disordered,
    freq_disordered,
    interp_shifted_freq_crystal,
    interp_shifted_vdos_crystal,
    interp_shifted_speed_crystal,
)

optim = torch.optim.LBFGS(
    params=model.parameters(),
    lr=LR,
    max_iter=MAX_ITER,
    line_search_fn=LINE_SEARCH_FN,
)

# Mean-squared error between the model prediction and the target Y.
loss_fn = torch.nn.MSELoss()

# Filled by `closure`: one entry per closure evaluation.
losses = []
model_parameters_history = []

# MAIN ITERATION LOOP
def closure():
    """Re-evaluate the fitting loss for L-BFGS and log the current state.

    ``torch.optim.LBFGS.step`` calls this closure, possibly several times per
    iteration (e.g. during line search). Each call does the following:

    - zeroes the gradients;
    - computes ``loss_fn(model(X), Y)`` and back-propagates it;
    - prints the loss and ``model.model_params``;
    - appends detached NumPy copies of both to the module-level ``losses``
      and ``model_parameters_history`` lists.

    Returns:
        The loss tensor with gradients populated, as L-BFGS requires.
    """
    optim.zero_grad()
    objective = loss_fn(model(X), Y)
    objective.backward()

    # Take each snapshot once and reuse it. ``.copy()`` makes the recorded
    # arrays independent of the live tensors, so later in-place parameter
    # updates cannot change what was logged.
    loss_value = objective.detach().cpu().numpy().copy()
    param_values = model.model_params.detach().cpu().numpy().copy()

    print(loss_value, param_values)
    losses.append(loss_value)
    model_parameters_history.append(param_values)
    return objective

# Run the L-BFGS fit. NOTE: `optim.step` returns the loss from its *first*
# closure evaluation, not the loss at the optimised parameters.
loss = optim.step(closure)


# Take the final result from the model itself rather than from the last
# closure record. The last entries of `losses` / `model_parameters_history`
# can be stale:
#   - with line_search_fn=None, L-BFGS does not re-run the closure after its
#     final step;
#   - with strong-Wolfe line search, the parameters are set to the best
#     bracket point, which need not be the last trial that was evaluated.
with torch.no_grad():
    final_loss = loss_fn(model(X), Y).detach().cpu().numpy().copy()
final_model_params = model.model_params.detach().cpu().numpy().copy()


# Save the fitted parameters and their loss to
# <MODEL_PARAMETERS_SAVE>.hdf5 (gzip-compressed datasets).
compression = "gzip"
with h5py.File(f"{MODEL_PARAMETERS_SAVE}.hdf5", "w") as w:
    w.create_dataset("final_loss", data=np.array([final_loss]), compression=compression)
    w.create_dataset("final_model_params", data=final_model_params, compression=compression)
