7b_fit_dl_parameters.py — Fit L and R Parameters

This script fits the two disorder-linewidth parameters (mirroring Notebook 6):

  • L [nm] — grain-boundary mean free path (Casimir scattering length)

  • R [cm THz nm³] — defect scattering amplitude (Rayleigh-type)

using L-BFGS with a strong-Wolfe line search, with gradients computed by PyTorch autodiff.

Run after 7a_workflow_precompute_quantities.py.

The final loss and the fitted parameters are saved to dl_workflow/model_parameters.hdf5 (datasets: final_loss, final_model_params).

Usage:

    cd tutorials/disorder_linewidth
    python 7b_fit_dl_parameters.py
  1import sys
  2
  3import h5py
  4import numpy as np
  5import matplotlib
  6import matplotlib.pyplot as plt
  7from ase.io import read
  8from scipy.interpolate import interp1d
  9
 10from tqdm import tqdm
 11
 12from smooth_disorder.structural import obtain_density, THzToCm, THz, Angstrom
 13from smooth_disorder.vis.interactive import *
 14
 15from smooth_disorder.disorder_linewidth import lorentzian_numpy, lorentzian_torch
 16from smooth_disorder.disorder_linewidth import prepare_fitting_inputs
 17from smooth_disorder.disorder_linewidth import evaluate_linewidth_and_model_prediction
 18from smooth_disorder.disorder_linewidth import PDCModel
 19
 20import torch
 21
 22
 23CRYSTAL_POSCAR    = "./1_graphite/POSCAR"
 24DISORDERED_POSCAR = "./2_irg_t2/irg_t2_14009.vasp"
 25
 26WORK_DIR = "./dl_workflow"
 27
 28CRYSTAL_VEL_SAVE     = f"{WORK_DIR}/crystal_vdos_group_vel"
 29DISORDERED_VDOS_SAVE = f"{WORK_DIR}/disordered_vdos"
 30SHIFTED_SAVE         = f"{WORK_DIR}/reduced_density_crystal_vdos_group_vel"
 31
 32MODEL_PARAMETERS_SAVE = f"{WORK_DIR}/model_parameters"
 33
 34
 35# Initial parameter values [nm], [1e-6 THz cm nm^3]
 36L0, R0 = 3.3, 5.54
 37
 38
 39
 40# read the input data + setup torch arrays
 41
 42(density_crystal,
 43density_disordered,
 44freq_disordered,
 45vdos_disordered,
 46interp_shifted_freq_crystal,
 47interp_shifted_vdos_crystal,
 48interp_shifted_speed_crystal) = prepare_fitting_inputs(
 49    CRYSTAL_POSCAR,
 50    DISORDERED_POSCAR,
 51    DISORDERED_VDOS_SAVE,
 52    SHIFTED_SAVE,
 53)
 54
 55# X and Y are the normalised crystal and disordered VDOS used in the loss
 56X = torch.from_numpy(interp_shifted_vdos_crystal / density_crystal)
 57Y = torch.from_numpy(vdos_disordered / density_disordered)
 58
 59
 60# instantiate the model for fitting the linewidths and setup LBFGS search
 61model = PDCModel(
 62    L0, R0,
 63    density_crystal, density_disordered,
 64    freq_disordered,
 65    interp_shifted_freq_crystal,
 66    interp_shifted_vdos_crystal,
 67    interp_shifted_speed_crystal,
 68)
 69
 70optim = torch.optim.LBFGS(
 71    model.parameters(),
 72    lr=1.0,
 73    max_iter=50,
 74    line_search_fn="strong_wolfe",
 75)
 76
 77loss_fn = torch.nn.MSELoss()
 78
 79losses, model_parameters_history = [], []
 80
 81
 82# MAIN ITERATION LOOP
 83def closure():
 84    optim.zero_grad()
 85    preds = model(X)
 86    loss = loss_fn(preds, Y)
 87    loss.backward()
 88    print(loss.detach().cpu().numpy().copy(), model.model_params.detach().cpu().numpy().copy())
 89    losses.append(loss.detach().cpu().numpy().copy())
 90    model_parameters_history.append(model.model_params.detach().cpu().numpy().copy())
 91    return loss
 92
 93loss = optim.step(closure)
 94
 95
 96# save the model parameters to a file
 97final_loss = losses[-1]
 98final_model_params = model_parameters_history[-1]
 99
100
101compression = "gzip"
102with h5py.File(f"{MODEL_PARAMETERS_SAVE}.hdf5", "w") as w:
103    w.create_dataset("final_loss",      data=np.array([final_loss]),      compression=compression)
104    w.create_dataset("final_model_params", data=final_model_params, compression=compression)
105
106