XTBTSScreener.jl
- Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict whether a partially-optimized initialization of a transition state, used in the study of chemical kinetics to predict rate constants, is “likely to converge” under further refinement with expensive Density Functional Theory (DFT) calculations.
It includes a step-by-step breakdown of loading the data, training the model, and evaluating the performance. This notebook draws heavily from the Lux LSTM example found here.
Developed April 2023 by Jackson Burns for the Final Project in MIT 18.337: Parallel Computing and Scientific Machine Learning
Load the Data
The input data is saved in a CSV file, which was exported from a large SQL database of simulation results. We can load it using CSV.jl and then partition the data into training and testing sets using MLUtils.jl.
using MLUtils, CSV
At present the dataset is highly unbalanced: limitations in the post-processing speed of the simulations hinder retrieval of failed examples. To reduce the impact of this phenomenon on the study, we will downsample the data to a 50:50 split of failed and converged samples.
csv_reader = CSV.File("data/roo_co2_full_data_augmented_expanded.csv")
n_samples = 20501
# keep track of number of failed and converged to balance the dataset
number_converged = 0
number_failed = 0
# iterate through the entire dataset
for row in csv_reader[1:n_samples]
    # check the label
    if parse(Bool, "$(row.converged)")
        number_converged += 1
    else
        number_failed += 1
    end
end
smaller_label = minimum([number_converged, number_failed])
println(
    "Failed: $number_failed\n" *
    "Converged: $number_converged\n" *
    "Keep: $smaller_label",
)
Failed: 3274
Converged: 17227
Keep: 3274
Each row in the CSV holds a matrix of spatial coordinates for the molecule as well as the Gibbs free energy, \(E_{0}+ZPE\) (sum of electronic total energy and zero-point energy), and number of simulation steps for each of the three substeps in the optimization. Initial modeling results using only the coordinate arrays were unsuccessful, with models never improving beyond random guessing. These augmented descriptors have been shown to improve model performance over that baseline.
Each sample represents a different chemical reaction and can therefore have a different number of atoms. No molecule in the dataset exceeds 55 atoms, so to leave room for growth and additional descriptors we zero-pad each sample to a uniform length of 60 rows.
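As a minimal, standalone sketch of the padding idea (the 3-atom matrix here is hypothetical; the real loader below also prepends the descriptor rows):
coords = rand(Float32, 3, 6)            # pretend this molecule contributes 3 rows of values
padded = zeros(Float32, 60, 6)          # every sample is stored as a uniform 60 x 6 block
padded[1:size(coords, 1), :] .= coords  # copy the real rows; the remainder stays zero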
The exact mechanisms by which this function works are described in the inline comments; suffice it to say that it involves a lot of string-to-float parsing.
function get_dataloaders(smaller_label)
    csv_reader = CSV.File("data/roo_co2_full_data_augmented_expanded.csv")
    n_samples = 20501
    x_data = Array{Float32}(undef, 60, 6, smaller_label * 2)
    labels = Float32[]
    iter = 1
    println("Progress:")
    # keep track of number of failed to balance the dataset
    number_converged = 0
    number_failed = 0

    for row in csv_reader[1:n_samples]
        # print some updates as we go
        if mod(iter, div(n_samples, 10)) == 0
            println(" - row $iter of $smaller_label")
            flush(stdout)
        end
        # get if it converged or not, add alternating samples
        if parse(Bool, "$(row.converged)")
            if number_converged < number_failed
                push!(labels, 1.0f0)
                number_converged += 1
            else
                continue
            end
        else
            push!(labels, 0.0f0)
            number_failed += 1
        end
        # array for descriptors for this transition state
        m = Array{Float32}(undef, 60, 6)

        # pull out the augmented descriptors
        split_gibbs = split("$(row.gibbs)")
        split_steps = split("$(row.steps)")
        split_e0_zpe = split("$(row.e0_zpe)")
        split_descriptors = [split_gibbs, split_steps, split_e0_zpe]
        for i in 1:3
            for j in 1:3
                temp = String(split_descriptors[i][j])
                temp = replace(temp, "]" => "")
                temp = replace(temp, "[" => "")
                temp = replace(temp, "," => "")
                m[i, j] = parse(Float32, temp)
            end
            m[i, 4] = Float32(0.0)
            m[i, 5] = Float32(0.0)
            m[i, 6] = Float32(0.0)
        end
        # get the final coordinates of the atoms
        split_array = split("$(row.std_xyz)")
        n_atoms = Int(length(split_array) / 6)
        row_counter = 4
        column_counter = 1
        for value in split_array
            temp = String(value)
            temp = replace(temp, "]" => "")
            temp = replace(temp, "[" => "")
            temp = replace(temp, "," => "")
            m[row_counter, column_counter] = parse(Float32, temp)
            column_counter += 1
            if column_counter > 6
                column_counter = 1
                row_counter += 1
            end
        end
        # zero-padding
        for i in n_atoms+1:60
            m[i, 1:6] = [0, 0, 0, 0, 0, 0]
        end
        x_data[1:60, 1:6, iter] = m
        iter += 1
    end
    remainingsamples = length(labels)
    println("downsampled to $remainingsamples samples, $number_converged converged $number_failed failed")
    println("...loading done, partitioning data.")
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=2^6, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=2^6, shuffle=false))
end
get_dataloaders (generic function with 1 method)
For ease of debugging and as a reference, the original tutorial dataloading function is included below.
function get_tutorial_dataloaders()
    dataset_size = 1000
    sequence_length = 50
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size÷2)]]
    anticlockwise_spirals = [reshape(d[1][:, (sequence_length+1):end], :, sequence_length, 1)
                             for d in data[((dataset_size÷2)+1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    return (DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
            DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_tutorial_dataloaders (generic function with 1 method)
Configure the Neural Network
Following the tutorial in the Lux documentation, we write a series of functions that will create our NN.
using Lux, Random, Optimisers, Zygote, NNlib, Statistics
We seed the random number generator for consistent results.
rng = Random.default_rng()
Random.seed!(rng, 42)
TaskLocalRNG()
Define a new struct that extends the Lux Container type and holds the LSTM cell and the classifier.
struct StateClassifier{L,C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

function StateClassifier(in_dims, hidden_dims, out_dims)
    return StateClassifier(LSTMCell(in_dims => hidden_dims),
                           Dense(hidden_dims => out_dims, sigmoid))
end
StateClassifier
function (s::StateClassifier)(x::AbstractArray{T,3}, ps::NamedTuple,
                              st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
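To make the sequence handling concrete, here is a purely illustrative check (not part of the original notebook) of what Iterators.peel(eachslice(x; dims=2)) produces for an input shaped like this dataset:
x = rand(Float32, 60, 6, 64)  # (padded rows, descriptor columns, batch)
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
size(x_init)             # (60, 64): the first slice, fed to the LSTM cell to start the carry
length(collect(x_rest))  # 5: the remaining slices consumed by the loop above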
Define the loss function, using binarycrossentropy for simplicity.
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

function compute_loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    return binarycrossentropy(y_pred, y), y_pred, st
end
matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
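As a quick sanity check of these helpers on made-up numbers (not from the dataset):
y_true = [1.0f0, 0.0f0]             # one converged, one failed (hypothetical)
y_pred = [0.9f0, 0.2f0]             # hypothetical model outputs
binarycrossentropy(y_pred, y_true)  # ≈ 0.164, the mean of -log(0.9) and -log(0.8) (up to eps)
accuracy(y_pred, y_true)            # 1.0: both thresholded predictions match the labels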
function create_optimiser(ps)
    opt = Optimisers.ADAM(0.0001f0)
    return Optimisers.setup(opt, ps)
end
create_optimiser (generic function with 1 method)
Train the NN
Now with all of the boilerplate work out of the way, the NN can actually be trained!
Extensive debugging and iteration showed that a very low learning rate was critical for successful training. The number of epochs is set to 55, the point at which the improvement in the loss slows substantially and the model begins to overfit. The ideal batch size was determined to be 2^6; substantially larger or smaller batches caused parabolic loss curves or numerical instability (which can be partially blamed on the loss function).
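For reference, these choices correspond to the following values in the code (copied from the cells of this notebook):
learning_rate = 0.0001f0  # the ADAM step size used in create_optimiser above
n_epochs      = 55        # upper bound of the training loop below
batch_size    = 2^6       # batchsize keyword of both DataLoaders in get_dataloaders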
Begin by loading the data from the file and partitioning it:
(train_loader, val_loader) = get_dataloaders(smaller_label)
Progress:
- row 2050 of 3274
- row 4100 of 3274
- row 6150 of 3274
downsampled to 6548 samples, 3274 converged 3274 failed
...loading done, partitioning data.
(DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, shuffle=true, batchsize=64), DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, batchsize=64))
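As an optional check (not in the original notebook), one batch can be pulled from the loader to confirm that the shapes match the 60 x 6 padded samples and the 2^6 batch size:
x_batch, y_batch = first(train_loader)
size(x_batch)    # (60, 6, 64)
length(y_batch)  # 64 labels, one per sample in the batch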
Initialize the model and the optimizer (using ADAM):
model = StateClassifier(60, 6, 1)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
opt_state = create_optimiser(ps)
(lstm_cell = (weight_i = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), weight_h = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), classifier = (weight = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999)))))
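Before the full training run, an optional forward pass on a single batch (again, not in the original notebook) confirms the model produces one prediction per sample:
x_batch, _ = first(train_loader)
y_hat, _ = model(x_batch, ps, st)
size(y_hat)  # (64,): one predicted probability of convergence per sample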
And finally, train the model, printing an occasional update.
loss_vector = Float64[]
accuracy_vector = Float64[]
for epoch in 1:55
    # Train the model
    epoch_loss = Float64[]
    for (x, y) in train_loader
        (loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
        gs = back((one(loss), nothing, nothing))[1]
        opt_state, ps = Optimisers.update(opt_state, ps, gs)
        push!(epoch_loss, loss)
    end
    avg_loss = mean(epoch_loss)
    push!(loss_vector, avg_loss)
    # Validate the model
    epoch_accuracy = Float64[]
    st_ = Lux.testmode(st)
    for (x, y) in val_loader
        (loss, y_pred, st_) = compute_loss(x, y, model, ps, st_)
        acc = accuracy(y_pred, y)
        push!(epoch_accuracy, acc)
    end
    avg_accuracy = mean(epoch_accuracy)
    if epoch == 1 || mod(epoch, 5) == 0
        println("Epoch # $epoch:\n - loss of $avg_loss")
        println(" - accuracy of $avg_accuracy")
        flush(stdout)
    end
    push!(accuracy_vector, avg_accuracy)
end
Epoch # 1:
- loss of 0.7282473619391279
- accuracy of 0.5146825396825397
Epoch # 5:
- loss of 0.7089660269458119
- accuracy of 0.5236111111111111
Epoch # 10:
- loss of 0.6943259777092352
- accuracy of 0.5287202380952382
Epoch # 15:
- loss of 0.686766859961719
- accuracy of 0.5387896825396825
Epoch # 20:
- loss of 0.6815697768839394
- accuracy of 0.5472718253968254
Epoch # 25:
- loss of 0.6774497729976002
- accuracy of 0.5570436507936508
Epoch # 30:
- loss of 0.6739311807039308
- accuracy of 0.5617063492063492
Epoch # 35:
- loss of 0.6705420882236667
- accuracy of 0.5556547619047618
Epoch # 40:
- loss of 0.6674438722249938
- accuracy of 0.5660714285714286
Epoch # 45:
- loss of 0.6641433820491884
- accuracy of 0.5712797619047618
Epoch # 50:
- loss of 0.661274130751447
- accuracy of 0.5734126984126984
Epoch # 55:
- loss of 0.6587260307335272
- accuracy of 0.5696924603174602
using Plots
plot(loss_vector, label="loss", legend=:bottom, color=:red, rightmargin=1.5Plots.cm, bottommargin=0.5Plots.cm, box=:on, fmt=:png)
plot!(twinx(), accuracy_vector, label="accuracy", legend=:top, xlabel="epoch", rightmargin=1.5Plots.cm, bottommargin=0.5Plots.cm, box=:on)
While this final accuracy is not outstanding, it is worthwhile for this use case since every successful prediction could save potentially weeks of compute time. Taken as a proof of concept, starting a model with only the molecular coordinates and then augmenting the data with electronic descriptors retrieved from the same calculations dramatically improved the results. In the future, more descriptors could be added to further increase performance, or alternative network architectures could be explored.
using Dates
timestamp = now()
savefig("results/result-$timestamp.png")