Training and Forecasting
Training RNAForecaster requires two expression count matrices, formatted with genes as rows and cells as columns. The two matrices should represent two different time points from the same cells. They can be obtained from transcriptomic profiling using spliced and unspliced counts, or from labeled and unlabeled counts produced by metabolic labeling scRNA-seq protocols such as scEU-seq (see Battich et al. 2020).
Here, we will generate two random matrices for illustrative purposes.
testT0 = log1p.(Float32.(abs.(randn(10,1000))))
testT1 = log1p.(0.5f0 .* testT0)
Note that the input matrices are expected to be of type Float32 and log transformed, as shown above.
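As a minimal sketch of preparing real data in this format, the snippet below converts integer count matrices to Float32 and applies a log(1 + x) transform. The names rawT0 and rawT1 are hypothetical stand-ins for count matrices loaded from your own pipeline, and any library-size normalization you use would need to be applied separately.

```julia
# Hypothetical raw count matrices (genes x cells); real counts would be loaded from file.
rawT0 = rand(0:20, 10, 1000)
rawT1 = rand(0:20, 10, 1000)

# Convert to Float32 and apply log(1 + x) elementwise.
exprT0 = log1p.(Float32.(rawT0))
exprT1 = log1p.(Float32.(rawT1))

# Both matrices must be Matrix{Float32} and cover the same genes and cells.
@assert exprT0 isa Matrix{Float32}
@assert size(exprT0) == size(exprT1)
```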
Training
To train the neural ODE network, we call the trainRNAForecaster function.
testForecaster = trainRNAForecaster(testT0, testT1);
In the simplest case, we only need to input the matrices, but there are several options provided to modify the training of the neural network, as shown below.
RNAForecaster.trainRNAForecaster - Function
trainRNAForecaster(expressionDataT0::Matrix{Float32}, expressionDataT1::Matrix{Float32}; trainingProp::Float64 = 0.8, hiddenLayerNodes::Int = 2*size(expressionDataT0)[1], shuffleData::Bool = true, seed::Int = 123, learningRate::Float64 = 0.005, nEpochs::Int = 10, batchsize::Int = 100, checkStability::Bool = false, iterToCheck::Int = 50, stabilityThreshold::Float32 = 2*maximum(expressionDataT0), stabilityChecksBeforeFail::Int = 5, useGPU::Bool = false)
Function to train RNAForecaster based on expression data. The main input is two matrices representing expression data from two different time points in the same cells. Currently, these can be based on either splicing or metabolic labeling. Each should be log normalized and have genes as rows and cells as columns.
Required Arguments
- expressionDataT0 - Float32 Matrix of log-normalized expression counts in the format of genes x cells
- expressionDataT1 - Float32 Matrix of log-normalized expression counts in the format
of genes x cells from a time after expressionDataT0
Keyword Arguments
- trainingProp - proportion of the data to use for training the model, the rest will be
used for a validation set. If you don't want a validation set, this value can be set to 1.0
- hiddenLayerNodes - number of nodes in the hidden layer of the neural network
- shuffleData - should the cells be randomly shuffled before training
- seed - random seed
- learningRate - learning rate for the neural network during training
- nEpochs - how many times should the neural network be trained on the data?
Additional epochs generally yield small gains in performance, so this can be lowered to speed up the training process
- batchsize - batch size for training
- checkStability - should the stability of the network's future time predictions be checked,
retraining the network if unstable?
- iterToCheck - when checking stability, how many future time steps should be predicted?
- stabilityThreshold - when checking stability, what is the maximum gene variance allowable across predictions?
- stabilityChecksBeforeFail - when checking stability, how many times should the network
be allowed to retrain before an error is thrown? Used to prevent an infinite loop.
- useGPU - use a GPU to train the neural network? highly recommended for large data sets, if available
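A call that overrides several of these keyword arguments might look like the sketch below. The keyword names come from the signature above; the values are arbitrary illustrations rather than recommended settings.

```julia
using RNAForecaster

# Same illustrative random input as at the top of this page.
testT0 = log1p.(Float32.(abs.(randn(10, 1000))))
testT1 = log1p.(0.5f0 .* testT0)

# Illustrative values only: larger training share, fewer epochs, smaller batches.
testForecaster = trainRNAForecaster(testT0, testT1;
    trainingProp = 0.9,
    nEpochs = 5,
    batchsize = 50,
    learningRate = 0.001);
```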
For example, by default RNAForecaster partitions the input data into a training set and a validation set. If we want the neural network to be trained on the entire data set, we can set trainingProp = 1.0.
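To make the effect of trainingProp, shuffleData and seed concrete, here is a minimal sketch of a train/validation split over cells. The helper splitCells is hypothetical and is not part of RNAForecaster; the package's internal partitioning may differ in detail.

```julia
using Random

# Hypothetical helper: split cell (column) indices into training and validation sets.
function splitCells(nCells::Int; trainingProp::Float64 = 0.8,
                    shuffleData::Bool = true, seed::Int = 123)
    idx = collect(1:nCells)
    if shuffleData
        # A seeded generator makes the shuffle reproducible.
        shuffle!(MersenneTwister(seed), idx)
    end
    nTrain = round(Int, trainingProp * nCells)
    return idx[1:nTrain], idx[nTrain+1:end]
end

trainIdx, valIdx = splitCells(1000)                      # 800 training, 200 validation cells
allTrain, noVal = splitCells(1000; trainingProp = 1.0)   # no validation set
```

With trainingProp = 1.0 the validation index vector is empty, which corresponds to training on the entire data set.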
When using larger data sets, such as a matrix from a typical scRNA-seq experiment, which may contain thousands of variable genes and tens of thousands of cells, it becomes inefficient to train the network on a CPU. If a GPU is available, setting useGPU = true can substantially speed up the training process.
Forecasting
Once we have trained the neural network, we can use it to forecast future expression states. For example, to predict the next fifty time points from our test data, we could run:
testOut1 = predictCellFutures(testForecaster[1], testT0, 50);
The predictions can also be conditioned on arbitrary perturbations in gene expression.
geneNames = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
testOut2 = predictCellFutures(testForecaster[1], testT0, 50, perturbGenes = ["A", "B", "F"],
geneNames = geneNames, perturbationLevels = [1.0f0, 2.0f0, 0.0f0]);
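The three perturbation arguments are aligned by position: each entry of perturbGenes is looked up in geneNames to find its row, and the entry at the same position in perturbationLevels gives its value. The sketch below illustrates that alignment on a small matrix. It is an assumption-based illustration of holding a gene at a constant level, not the package's internal code.

```julia
geneNames = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
perturbGenes = ["A", "B", "F"]
perturbationLevels = [1.0f0, 2.0f0, 0.0f0]

# Small illustrative expression matrix (genes x cells).
expr = log1p.(Float32.(abs.(randn(10, 5))))

# Row index of each perturbed gene within geneNames.
rows = [findfirst(==(g), geneNames) for g in perturbGenes]

# Set each perturbed gene to its constant level in every cell.
perturbed = copy(expr)
for (r, level) in zip(rows, perturbationLevels)
    perturbed[r, :] .= level
end
```

Here gene "F" is set to 0.0f0, which in log space corresponds to simulating a complete knockout.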
All options for predictCellFutures are shown here:
RNAForecaster.predictCellFutures - Function
predictCellFutures(trainedNetwork, expressionData::Matrix{Float32}, tSteps::Int; perturbGenes::Vector{String} = Vector{String}(undef,0), geneNames::Vector{String} = Vector{String}(undef,0), perturbationLevels::Vector{Float32} = Vector{Float32}(undef,0), enforceMaxPred::Bool = true, maxPrediction::Float32 = 2*maximum(expressionData))
Function to make future expression predictions using a trained neural ODE. Outputs a 3D tensor containing a predicted expression counts matrix for the cells at each time step.
Required Arguments
- trainedNetwork - the trained neural ODE, from the trainRNAForecaster function
- expressionData - the initial expression states that should be used to make predictions from
- tSteps - how many future time steps should be predicted. (Error will propagate
with each prediction, so predictions will eventually become highly inaccurate at high numbers of time steps)
Keyword Arguments
- perturbGenes - a vector of gene names that will have their values set to a constant 'perturbed' level.
- geneNames - a vector of gene names in the order of the rows of the expressionData.
Used only when simulating perturbations.
- perturbationLevels - a vector of Float32, corresponding to the level each perturbed
gene's expression should be set at.
- enforceMaxPred - should a maximum allowed prediction be enforced? This is used
to represent prior knowledge about what sort of expression values are remotely reasonable predictions.
- maxPrediction - if enforcing a maximum prediction, what should the value be?
2 times the maximum of the input expression data by default (in log space).
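As a rough illustration of what enforcing a maximum prediction means, the sketch below caps a hypothetical prediction vector at the default maxPrediction. Exactly how and when RNAForecaster applies the cap internally is not documented here, so treat the elementwise min as an assumption.

```julia
expressionData = log1p.(Float32.(abs.(randn(10, 5))))

# Default cap: twice the largest input value (in log space).
maxPrediction = 2 * maximum(expressionData)

# Hypothetical predicted values, one of which is implausibly large.
prediction = Float32[0.1f0, maxPrediction / 2, 10 * maxPrediction]

# Values above the cap are replaced by the cap; all others are left unchanged.
capped = min.(prediction, maxPrediction)
```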
Once we have forecast expression levels for each gene, we may want to know which genes' expression levels change the most over time, as these are likely to be important in the ongoing biological process we are attempting to model. To assess this, we simply run mostTimeVariableGenes, which outputs a table of genes ordered from most to least variable over the predicted time points.
geneOutputTable = mostTimeVariableGenes(testOut1, geneNames)
RNAForecaster.mostTimeVariableGenes - Function
mostTimeVariableGenes(cellFutures::AbstractArray{Float32}, geneNames::Vector{String}; statType = "mean")
For each cell, takes the predicted expression levels of each gene over time and finds the variance with respect to the predicted time points. Then takes the mean or median of each gene's variance across cells.
Outputs a sorted DataFrame containing gene names and the variances over predicted time.
Required Arguments
- cellFutures - a 3D tensor of gene expression over time; the output from predictCellFutures
- geneNames - a vector of gene names corresponding to the order of the genes in cellFutures
Optional Arguments
- statType - How to summarize the gene variances. Valid options are "mean" or "median"
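The computation described above can be sketched in a few lines of plain Julia. This assumes the tensor is laid out as genes x cells x time steps, which is not stated explicitly in this section, and it returns a vector of tuples rather than the DataFrame produced by mostTimeVariableGenes.

```julia
using Statistics

# Assumed layout: genes x cells x time steps.
cellFutures = rand(Float32, 3, 4, 6)
cellFutures[2, :, :] .= 1.0f0          # gene "B" is constant over time
geneNames = ["A", "B", "C"]

# Variance over time for each gene in each cell (genes x cells).
timeVar = dropdims(var(cellFutures, dims = 3), dims = 3)

# Summarize each gene's variance across cells (mean here; median is the alternative).
geneScore = vec(mean(timeVar, dims = 2))

# Rank genes from most to least time-variable.
order = sortperm(geneScore, rev = true)
ranked = [(geneNames[i], geneScore[i]) for i in order]
```

Gene "B" has zero variance over time in every cell, so it ends up last in the ranking.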