using MLUtils, CSV
XTBTSScreener.jl - Screening Likely Transition States with Julia and Machine Learning
This Jupyter notebook demonstrates the use of machine learning to predict whether a partially optimized initial guess for a transition state is “likely to converge” after further, expensive Density Functional Theory (DFT) simulation. Transition states are used in the study of chemical kinetics to predict rate constants.
It includes a step-by-step breakdown of loading the data, training the model, and evaluating the performance. This notebook draws heavily from the Lux LSTM example found here.
Developed April 2023 by Jackson Burns for the Final Project in MIT 18.337: Parallel Computing and Scientific Machine Learning
Load the Data
The input data is saved in a CSV file, which was exported from a large SQL database of simulation results. We can load it using CSV.jl and then partition the data into training and testing sets using MLUtils.jl.
The dataset is currently highly unbalanced because slow post-processing of simulations hinders the retrieval of failed examples. To reduce the impact of this phenomenon on the study, we will downsample the data to a 50:50 split of failed and converged samples.
csv_reader = CSV.File("data/roo_co2_full_data_augmented_expanded.csv")
n_samples = 20501
# keep track of number of failed and converged to balance the dataset
number_converged = 0
number_failed = 0
# iterate through the entire dataset
for row in csv_reader[1:n_samples]
# check the label
if parse(Bool, "$(row.converged)")
number_converged += 1
else
number_failed += 1
end
end
smaller_label = minimum([number_converged, number_failed])
println(
"Failed: $number_failed\n" *
"Converged: $number_converged\n" *
"Keep: $smaller_label",
)
Failed: 3274
Converged: 17227
Keep: 3274
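The same 50:50 downsampling can be illustrated without the CSV file. The following is a minimal sketch and not the notebook's method: `balanced_indices` is a hypothetical helper that keeps a random, equally sized subset of each class, whereas the notebook keeps samples in file order. The label vector is made up for the example.

```julia
using Random

# Hypothetical helper (not part of the notebook): return the indices of a
# class-balanced subset of `labels`, where `true` means "converged".
function balanced_indices(rng::AbstractRNG, labels::AbstractVector{Bool})
    converged = findall(labels)
    failed = findall(!, labels)
    n_keep = min(length(converged), length(failed))
    # shuffle each class and keep the first n_keep indices of each
    keep = vcat(shuffle(rng, converged)[1:n_keep], shuffle(rng, failed)[1:n_keep])
    return sort(keep)
end

# made-up labels: six converged samples, two failed
labels = [true, true, true, false, true, false, true, true]
idx = balanced_indices(MersenneTwister(0), labels)
println("kept $(length(idx)) of $(length(labels)) samples")
```

Random selection avoids any dependence on the order of rows in the file, at the cost of needing all labels in memory before any sample is kept.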
Each row in the CSV contains a matrix of spatial coordinates for the molecule, plus the Gibbs free energy, \(E_{0}+ZPE\) (the sum of the electronic total energy and the zero-point energy), and the number of simulation steps for each of the three substeps in the optimization. Initial modeling using only the coordinate arrays was unsuccessful: model accuracy never improved beyond random guessing. The augmented descriptors have been shown to improve model performance over that baseline.
Each sample represents a different chemical reaction and can therefore have a different number of atoms. No molecule has more than 55 atoms, so to leave room for growth and for the other descriptors we zero-pad each sample to a uniform length of 60 rows.
The exact mechanisms by which this function works are described in the inline comments; in short, it involves a lot of string-to-float casting.
function get_dataloaders(smaller_label)
csv_reader = CSV.File("data/roo_co2_full_data_augmented_expanded.csv")
n_samples = 20501
x_data = Array{Float32}(undef, 60, 6, smaller_label * 2)
labels = Float32[]
iter = 1
println("Progress:")
# keep track of number of failed to balance the dataset
number_converged = 0
number_failed = 0
for row in csv_reader[1:n_samples]
# print some updates as we go
if mod(iter, div(n_samples, 10)) == 0
println(" - row $iter of $smaller_label")
flush(stdout)
end
# get if it converged or not, add alternating samples
if parse(Bool, "$(row.converged)")
if number_converged < number_failed
push!(labels, 1.0f0)
number_converged += 1
else
continue
end
else
push!(labels, 0.0f0)
number_failed += 1
end
# array for descriptors for this transition state
m = Array{Float32}(undef, 60, 6)
# pull out the augmented descriptors
split_gibbs = split("$(row.gibbs)")
split_steps = split("$(row.steps)")
split_e0_zpe = split("$(row.e0_zpe)")
split_descriptors = [split_gibbs, split_steps, split_e0_zpe]
for i in 1:3
for j in 1:3
temp = String(split_descriptors[i][j])
temp = replace(temp, "]" => "")
temp = replace(temp, "[" => "")
temp = replace(temp, "," => "")
m[i, j] = parse(Float32, temp)
end
m[i, 4] = Float32(0.0)
m[i, 5] = Float32(0.0)
m[i, 6] = Float32(0.0)
end
# get the final coordinates of the atoms
split_array = split("$(row.std_xyz)")
n_atoms = Int(length(split_array) / 6)
row_counter = 4
column_counter = 1
for value in split_array
temp = String(value)
temp = replace(temp, "]" => "")
temp = replace(temp, "[" => "")
temp = replace(temp, "," => "")
m[row_counter, column_counter] = parse(Float32, temp)
column_counter += 1
if column_counter > 6
column_counter = 1
row_counter += 1
end
end
# zero-padding: atoms fill rows 4:(n_atoms + 3), so pad only the rows after them
for i in n_atoms+4:60
m[i, 1:6] = [0, 0, 0, 0, 0, 0]
end
x_data[1:60, 1:6, iter] = m
iter += 1
end
remainingsamples = length(labels)
println("downsampled to $remainingsamples samples, $number_converged converged $number_failed failed")
println("...loading done, partitioning data.")
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
return (DataLoader(collect.((x_train, y_train)); batchsize=2^6, shuffle=true),
DataLoader(collect.((x_val, y_val)); batchsize=2^6, shuffle=false))
end
get_dataloaders (generic function with 1 method)
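Two steps inside `get_dataloaders`, the string-to-float casting and the zero-padding, can be sketched in isolation. This is a minimal illustration under stated assumptions: `parse_float_list` and `fill_padded` are hypothetical helpers that do not appear in the notebook, and the input string is made up, assuming the CSV stores arrays as bracketed, comma-separated text.

```julia
# Hypothetical helpers (not part of the notebook) isolating two steps of
# get_dataloaders: parsing a bracketed list and zero-padding to 60 rows.

# Replace brackets and commas with spaces, then parse each remaining token.
function parse_float_list(s::AbstractString)
    cleaned = replace(s, r"[\[\],]" => " ")
    return parse.(Float32, split(cleaned))
end

# Write `values` row by row into a zero matrix, starting at row `first_row`.
function fill_padded(values::Vector{Float32}; n_rows=60, n_cols=6, first_row=4)
    m = zeros(Float32, n_rows, n_cols)
    for (k, v) in enumerate(values)
        row = first_row + (k - 1) ÷ n_cols
        col = (k - 1) % n_cols + 1
        m[row, col] = v
    end
    return m
end

# made-up string describing two atoms with six values each
xyz = parse_float_list("[[1.0, 0.0, 0.0, 0.1, 0.2, 0.3], [6.0, 1.5, 0.0, 0.0, 0.0, 0.0]]")
m = fill_padded(xyz)
println(size(m), " ", m[4, :], " ", m[5, 2])
```

Starting from a matrix of zeros means the padding rows need no separate loop; rows 1 to 3 stay free for the augmented descriptors.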
For ease of debugging and as a reference, the original tutorial dataloading function is included below.
function get_tutorial_dataloaders()
dataset_size = 1000
sequence_length = 50
data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
# Get the labels
labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
for d in data[1:(dataset_size÷2)]]
anticlockwise_spirals = [reshape(d[1][:, (sequence_length+1):end], :, sequence_length,
1) for d in data[((dataset_size÷2)+1):end]]
x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
# Split the dataset
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
return (DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
get_tutorial_dataloaders (generic function with 1 method)
Configure the Neural Network
Following the tutorial in the Lux documentation, we write a series of functions that create our NN.
using Lux, Random, Optimisers, Zygote, NNlib, Statistics
We seed the random number generator for consistent results.
rng = Random.default_rng()
Random.seed!(rng, 42)
TaskLocalRNG()
Define a new struct that extends the Lux Container type and holds the LSTM model and the classifier.
struct StateClassifier{L,C} <:
Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)}
lstm_cell::L
classifier::C
end
function StateClassifier(in_dims, hidden_dims, out_dims)
return StateClassifier(LSTMCell(in_dims => hidden_dims),
Dense(hidden_dims => out_dims, sigmoid))
end
StateClassifier
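The forward pass defined next walks over the second dimension of the input, feeding one slice at a time to the LSTM cell. As a shape-only sketch using a plain random array (no Lux involved, and a made-up batch size of 4), the slicing looks like this:

```julia
# Shape-only sketch with a plain array (no Lux): features × steps × batch.
x = rand(Float32, 60, 6, 4)

# Peel off the first slice along dims=2, as the forward pass does.
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))

# The first slice plus the remaining ones give the sequence length.
n_steps = 1 + count(_ -> true, x_rest)
println("first slice: ", size(x_init), ", total steps: ", n_steps)
```

With the 60 × 6 × batch arrays built by `get_dataloaders`, each slice is a 60 × batch matrix, which is why the model is later constructed with an input dimension of 60 and the LSTM sees a sequence of 6 steps.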
function (s::StateClassifier)(x::AbstractArray{T,3}, ps::NamedTuple,
st::NamedTuple) where {T}
x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
(y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
for x in x_rest
(y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
end
y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
return vec(y), st
end
Define the loss function, using binarycrossentropy for simplicity.
function xlogy(x, y)
result = x * log(y)
return ifelse(iszero(x), zero(result), result)
end
function binarycrossentropy(y_pred, y_true)
y_pred = y_pred .+ eps(eltype(y_pred))
return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end
function compute_loss(x, y, model, ps, st)
y_pred, st = model(x, ps, st)
return binarycrossentropy(y_pred, y), y_pred, st
end
matches(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
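To get a feel for the scale of these two metrics, the functions above can be evaluated on a few hand-written predictions. This is a small worked example, not part of the original notebook; the definitions are restated so the block runs on its own, and the prediction vectors are made up.

```julia
using Statistics

# Restated from the notebook so this block runs on its own.
function xlogy(x, y)
    result = x * log(y)
    return ifelse(iszero(x), zero(result), result)
end

function binarycrossentropy(y_pred, y_true)
    y_pred = y_pred .+ eps(eltype(y_pred))
    return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred))
end

accuracy(y_pred, y_true) = sum((y_pred .> 0.5) .== y_true) / length(y_pred)

y_true = Float32[1, 0, 1, 0]
confident = Float32[0.9, 0.1, 0.8, 0.2]  # correct and fairly sure of it
guessing = Float32[0.5, 0.5, 0.5, 0.5]   # uninformative predictions

println("confident: loss=", binarycrossentropy(confident, y_true),
        " accuracy=", accuracy(confident, y_true))
println("guessing:  loss=", binarycrossentropy(guessing, y_true),
        " accuracy=", accuracy(guessing, y_true))
```

A model that always outputs 0.5 has a loss of ln 2 ≈ 0.693 and, on a balanced dataset, an accuracy of 0.5. These are useful reference points when reading the training log later in the notebook.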
function create_optimiser(ps)
opt = Optimisers.ADAM(0.0001f0)
return Optimisers.setup(opt, ps)
end
create_optimiser (generic function with 1 method)
Train the NN
Now with all of the boilerplate work out of the way, the NN can actually be trained!
Extensive debugging and iteration showed that a very low learning rate (0.0001 here) was critical for training the NN successfully. The number of epochs is set to 55, the point at which the per-epoch improvement in the loss drops off substantially and the model begins to overfit. The best batch size was found to be 2^6; substantially larger or smaller batches caused parabolic loss curves or numerical instability (which can be partially blamed on the loss function).
Begin by loading the data from the file and partitioning it:
(train_loader, val_loader) = get_dataloaders(smaller_label)
Progress:
- row 2050 of 3274
- row 4100 of 3274
- row 6150 of 3274
downsampled to 6548 samples, 3274 converged 3274 failed
...loading done, partitioning data.
(DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, shuffle=true, batchsize=64), DataLoader(::Tuple{Array{Float32, 3}, Vector{Float32}}, batchsize=64))
Initialize the model and the optimizer (using ADAM):
model = StateClassifier(60, 6, 1)
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, model)
opt_state = create_optimiser(ps)
(lstm_cell = (weight_i = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), weight_h = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999)))), classifier = (weight = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.0001, (0.9, 0.999), 1.19209f-7), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999)))))
Finally, train the model, printing an occasional update.
loss_vector = Float64[]
accuracy_vector = Float64[]
for epoch in 1:55
# Train the model
epoch_loss = Float64[]
for (x, y) in train_loader
(loss, y_pred, st), back = pullback(p -> compute_loss(x, y, model, p, st), ps)
gs = back((one(loss), nothing, nothing))[1]
opt_state, ps = Optimisers.update(opt_state, ps, gs)
push!(epoch_loss, loss)
end
avg_loss = mean(epoch_loss)
push!(loss_vector, avg_loss)
# Validate the model
epoch_accuracy = Float64[]
st_ = Lux.testmode(st)
for (x, y) in val_loader
(loss, y_pred, st_) = compute_loss(x, y, model, ps, st_)
acc = accuracy(y_pred, y)
push!(epoch_accuracy, acc)
end
avg_accuracy = mean(epoch_accuracy)
if epoch == 1 || mod(epoch, 5) == 0
println("Epoch # $epoch:\n - loss of $avg_loss")
println(" - accuracy of $avg_accuracy")
flush(stdout)
end
push!(accuracy_vector, avg_accuracy)
end
Epoch # 1:
- loss of 0.7282473619391279
- accuracy of 0.5146825396825397
Epoch # 5:
- loss of 0.7089660269458119
- accuracy of 0.5236111111111111
Epoch # 10:
- loss of 0.6943259777092352
- accuracy of 0.5287202380952382
Epoch # 15:
- loss of 0.686766859961719
- accuracy of 0.5387896825396825
Epoch # 20:
- loss of 0.6815697768839394
- accuracy of 0.5472718253968254
Epoch # 25:
- loss of 0.6774497729976002
- accuracy of 0.5570436507936508
Epoch # 30:
- loss of 0.6739311807039308
- accuracy of 0.5617063492063492
Epoch # 35:
- loss of 0.6705420882236667
- accuracy of 0.5556547619047618
Epoch # 40:
- loss of 0.6674438722249938
- accuracy of 0.5660714285714286
Epoch # 45:
- loss of 0.6641433820491884
- accuracy of 0.5712797619047618
Epoch # 50:
- loss of 0.661274130751447
- accuracy of 0.5734126984126984
Epoch # 55:
- loss of 0.6587260307335272
- accuracy of 0.5696924603174602
using Plots
plot(loss_vector, label="loss", legend=:bottom, color=:red, rightmargin=1.5Plots.cm, bottommargin=0.5Plots.cm, box=:on, fmt=:png)
plot!(twinx(), accuracy_vector, label="accuracy", legend=:top, xlabel="epoch", rightmargin=1.5Plots.cm, bottommargin=0.5Plots.cm, box=:on)
While this final accuracy is not outstanding, it is worthwhile for this use case, since every successful prediction could potentially save weeks of time. Taken as a proof of concept, the result shows that a model given only the molecular coordinates does no better than random guessing, whereas augmenting the data with electronic descriptors retrieved from the same calculations lifted validation accuracy to roughly 57%. In the future, more descriptors could be added to further increase performance, or alternative network architectures could be explored.
using Dates
timestamp = now()
savefig("results/result-$timestamp.png")
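Interpolating `now()` directly produces a file name containing `:` characters, which some filesystems reject. A minimal sketch of one alternative follows; `result_path` is a hypothetical helper that is not part of the notebook, and the date passed to it is only an example.

```julia
using Dates

# Hypothetical helper (not part of the notebook): build an output path whose
# timestamp contains no ':' characters.
function result_path(dir::AbstractString, t::DateTime)
    stamp = Dates.format(t, "yyyy-mm-dd_HHMMSS")
    return joinpath(dir, "result-$stamp.png")
end

# example timestamp; in the notebook this would be now()
path = result_path("results", DateTime(2023, 4, 25, 13, 5, 9))
println(path)
# mkpath(dirname(path)) would create the directory before savefig(path)
```

Passing the timestamp in as an argument keeps the helper deterministic, so it can be checked without depending on the current time.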