2b_DL_fit_params.py — Fit L and R Parameters

Fits the two disorder-linewidth model parameters (mirrors Notebook 6):

  • L [nm] — grain-boundary size (Casimir scattering length)

  • R [1e-6 THz cm nm³] — defect scattering amplitude (Rayleigh-type)

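using PyTorch's L-BFGS optimizer with a strong-Wolfe line search, with gradients from autodiff. Starting values are taken from config.L0 and config.R0. The learning rate, line-search method and iteration limit come from config.LR, config.LINE_SEARCH_FN and config.MAX_ITER; the optimizer runs for at most config.MAX_ITER iterations.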

The final loss and the fitted parameters are written to model_parameters.hdf5 in config.DL_WORK_DIR (dl_workflow/model_parameters.hdf5), under the keys final_loss and final_model_params.

Run after 2a_DL_workflow_precompute.py.

cd workflows
python 2b_DL_fit_params.py
  1import sys
  2import pathlib
  3sys.path.insert(0, str(pathlib.Path(__file__).parent))
  4
  5import h5py
  6import numpy as np
  7import matplotlib
  8import matplotlib.pyplot as plt
  9from ase.io import read
 10from scipy.interpolate import interp1d
 11
 12from tqdm import tqdm
 13
 14from smooth_disorder.structural import obtain_density, THzToCm, THz, Angstrom
 15from smooth_disorder.vis.interactive import *
 16
 17from smooth_disorder.disorder_linewidth import lorentzian_numpy, lorentzian_torch
 18from smooth_disorder.disorder_linewidth import prepare_fitting_inputs
 19from smooth_disorder.disorder_linewidth import evaluate_linewidth_and_model_prediction
 20from smooth_disorder.disorder_linewidth import PDCModel
 21
 22import torch
 23
 24from config import (
 25    CRYSTAL_POSCAR, DISORDERED_POSCAR,
 26    DL_WORK_DIR,
 27    L0, R0, LR, MAX_ITER, LINE_SEARCH_FN,
 28)
 29
 30WORK_DIR = DL_WORK_DIR
 31
 32CRYSTAL_VEL_SAVE     = f"{WORK_DIR}/crystal_vdos_group_vel"
 33DISORDERED_VDOS_SAVE = f"{WORK_DIR}/disordered_vdos"
 34SHIFTED_SAVE         = f"{WORK_DIR}/reduced_density_crystal_vdos_group_vel"
 35
 36MODEL_PARAMETERS_SAVE = f"{WORK_DIR}/model_parameters"
 37
 38
 39
 40# read the input data + setup torch arrays
 41
 42(density_crystal,
 43density_disordered,
 44freq_disordered,
 45vdos_disordered,
 46interp_shifted_freq_crystal,
 47interp_shifted_vdos_crystal,
 48interp_shifted_speed_crystal) = prepare_fitting_inputs(
 49    CRYSTAL_POSCAR,
 50    DISORDERED_POSCAR,
 51    DISORDERED_VDOS_SAVE,
 52    SHIFTED_SAVE,
 53)
 54
 55# X and Y are the normalised crystal and disordered VDOS used in the loss
 56X = torch.from_numpy(interp_shifted_vdos_crystal / density_crystal)
 57Y = torch.from_numpy(vdos_disordered / density_disordered)
 58
 59
 60# instantiate the model for fitting the linewidths and setup LBFGS search
 61model = PDCModel(
 62    L0, R0,
 63    density_crystal, density_disordered,
 64    freq_disordered,
 65    interp_shifted_freq_crystal,
 66    interp_shifted_vdos_crystal,
 67    interp_shifted_speed_crystal,
 68)
 69
 70optim = torch.optim.LBFGS(
 71    model.parameters(),
 72    lr=LR,
 73    max_iter=MAX_ITER,
 74    line_search_fn=LINE_SEARCH_FN,
 75)
 76
 77loss_fn = torch.nn.MSELoss()
 78
 79losses, model_parameters_history = [], []
 80
 81
 82# MAIN ITERATION LOOP
 83def closure():
 84    optim.zero_grad()
 85    preds = model(X)
 86    loss = loss_fn(preds, Y)
 87    loss.backward()
 88    print(loss.detach().cpu().numpy().copy(), model.model_params.detach().cpu().numpy().copy())
 89    losses.append(loss.detach().cpu().numpy().copy())
 90    model_parameters_history.append(model.model_params.detach().cpu().numpy().copy())
 91    return loss
 92
 93loss = optim.step(closure)
 94
 95
 96# save the model parameters to a file
 97final_loss = losses[-1]
 98final_model_params = model_parameters_history[-1]
 99
100
101compression = "gzip"
102with h5py.File(f"{MODEL_PARAMETERS_SAVE}.hdf5", "w") as w:
103    w.create_dataset("final_loss",      data=np.array([final_loss]),      compression=compression)
104    w.create_dataset("final_model_params", data=final_model_params, compression=compression)
105
106