RNAForecaster.createEnsembleForecaster — Method
createEnsembleForecaster(expressionDataT0::Matrix{Float32}, expressionDataT1::Matrix{Float32}; nNetworks::Int = 5, trainingProp::Float64 = 0.8, hiddenLayerNodes::Int = 2*size(expressionDataT0)[1], shuffleData::Bool = true, seed::Int = 123, learningRate::Float64 = 0.005, nEpochs::Int = 10, batchsize::Int = 100, checkStability::Bool = false, iterToCheck::Int = 50, stabilityThreshold::Float32 = 2*maximum(expressionDataT0), stabilityChecksBeforeFail::Int = 5, useGPU::Bool = false)
Function to train multiple neural ODEs to predict expression so that their predictions can be ensembled, which tends to yield more accurate forecasts. Stochastic gradient descent converges to slightly different solutions when given different random seeds. On the training data these solutions yield almost identical results, but when generalizing to future predictions the results can diverge substantially. Averaging across multiple forecasters accounts for this.
It is recommended to run this function on a GPU (useGPU = true) or, if a GPU is not available, in parallel. To train the neural networks on separate processes, first call:
    using Distributed
    addprocs(desiredNumberOfParallelProcesses)
    @everywhere using RNAForecaster
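The averaging idea described above can be illustrated without the package. In this minimal sketch, `toyForecaster` is a hypothetical stand-in for one trained network (a linear map whose weights depend on the seed). It is not part of RNAForecaster, which trains neural ODEs instead, and the exact way the package combines its networks is not shown here.

```julia
using Random

# Hypothetical stand-in for one trained forecaster: a seed-dependent linear map.
# It only mimics "same task, different random seed".
function toyForecaster(seed::Int, nGenes::Int)
    rng = MersenneTwister(seed)
    W = 0.05f0 .* randn(rng, Float32, nGenes, nGenes)
    for i in 1:nGenes
        W[i, i] += 1f0          # close to the identity, plus seed-specific noise
    end
    return x -> W * x
end

nGenes, nCells = 5, 8
x0 = rand(MersenneTwister(1), Float32, nGenes, nCells)   # genes x cells

# One prediction per ensemble member, each built from a different seed
predictions = [toyForecaster(seed, nGenes)(x0) for seed in 1:5]

# Ensemble prediction: element-wise mean across the members
ensemblePrediction = sum(predictions) ./ length(predictions)
```

Each member gives a slightly different prediction for the same input, and the mean smooths out the seed-specific differences.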
Required Arguments
- expressionDataT0 - Float32 Matrix of log-normalized expression counts in the format of genes x cells
- expressionDataT1 - Float32 Matrix of log-normalized expression counts in the format
of genes x cells from a time after expressionDataT0
Keyword Arguments
- nNetworks - number of networks to train
- trainingProp - proportion of the data to use for training the model, the rest will be
used for a validation set. If you don't want a validation set, this value can be set to 1.0
- hiddenLayerNodes - number of nodes in the hidden layer of the neural network
- shuffleData - should the cells be randomly shuffled before training
- seed - random seed
- learningRate - learning rate for the neural network during training
- nEpochs - how many times the neural network should be trained on the data.
Additional epochs generally yield small gains in performance; this can be lowered to speed up the training process
- batchsize - batch size for training
- checkStability - should the stability of the network's future time predictions be checked,
retraining the network if unstable?
- iterToCheck - when checking stability, how many future time steps should be predicted?
- stabilityThreshold - when checking stability, what is the maximum gene variance allowable across predictions?
- stabilityChecksBeforeFail - when checking stability, how many times should the network
be allowed to retrain before an error is thrown? Used to prevent an infinite loop.
- useGPU - use a GPU to train the neural network? Highly recommended for large data sets, if available
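To make the stability-related keywords concrete, here is a minimal sketch of what such a check could look like: roll a predictor forward `iterToCheck` steps, compute the variance across the predicted steps, and compare the largest value with `stabilityThreshold`. The helpers `stableStep`, `unstableStep`, and `isStable` are hypothetical, and how the package aggregates the variance is an assumption.

```julia
using Statistics

# Hypothetical one-step predictors standing in for a trained network
stableStep(x)   = 0.9f0 .* x      # shrinks toward zero
unstableStep(x) = 1.5f0 .* x      # grows without bound

# Roll `step` forward and report whether the largest variance of any
# gene/cell entry across the predicted time steps stays within the threshold.
function isStable(step, x0::Matrix{Float32}, iterToCheck::Int, stabilityThreshold::Float32)
    preds = Array{Float32}(undef, size(x0, 1), size(x0, 2), iterToCheck)
    x = x0
    for t in 1:iterToCheck
        x = step(x)
        preds[:, :, t] = x
    end
    maxVar = maximum(var(preds; dims = 3))
    return maxVar <= stabilityThreshold
end

x0 = ones(Float32, 3, 4)                 # genes x cells
threshold = 2 * maximum(x0)              # mirrors the documented default

stableResult   = isStable(stableStep, x0, 50, threshold)
unstableResult = isStable(unstableStep, x0, 50, threshold)
```

A network that fails a check like this would be retrained, up to `stabilityChecksBeforeFail` times according to the keyword description.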
RNAForecaster.estimateT0LabelingData — Method
estimateT0LabelingData(labeledData::Matrix{Float32}, totalData::Matrix{Float32}, unlabeledData::Matrix{Float32}, labelingTime::AbstractVector)
Function to predict total expression level before labeling based on degradation rate estimates. Outputs the estimated time 1 counts matrix.
Required Arguments
- labeledData - Float32 counts matrix of the labeled counts
- totalData - Float32 counts matrix of combined counts
- unlabeledData - Float32 counts matrix of unlabeled counts
- labelingTime - Vector with the amount of time each cell was labeled for
RNAForecaster.filterByGeneVar — Method
filterByGeneVar(t0Counts::Matrix{Float32}, t1Counts::Matrix{Float32}, topGenes::Int)
Filter by gene variance, measured across both count matrices.
Required Arguments
- t0Counts - Counts matrix for time 1. Should be genes x cells and Float32.
- t1Counts - Counts matrix for time 2. Should be genes x cells and Float32.
- topGenes - the number of top most variable genes to use
Examples
    mat0 = Float32.(randn(50,50))
    mat1 = Float32.(randn(50,50))
    hvgFilteredCounts = filterByGeneVar(mat0, mat1, 20)
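For intuition, here is a plain-Julia sketch of variance-based gene filtering. It assumes that "measured across both count matrices" means pooling the cells of both matrices before computing each gene's variance; the package may combine them differently. `topVariableGenes` is a hypothetical helper, not part of RNAForecaster.

```julia
using Statistics, Random

# Hypothetical helper: indices of the `topGenes` most variable rows (genes),
# with variance computed over the cells of both matrices pooled together.
function topVariableGenes(t0::Matrix{Float32}, t1::Matrix{Float32}, topGenes::Int)
    geneVars = vec(var(hcat(t0, t1); dims = 2))
    return sortperm(geneVars; rev = true)[1:topGenes]
end

rng = MersenneTwister(42)
mat0 = randn(rng, Float32, 50, 50)
mat1 = randn(rng, Float32, 50, 50)
mat0[7, :] .*= 10f0     # make gene 7 clearly the most variable
mat1[7, :] .*= 10f0

keep = topVariableGenes(mat0, mat1, 20)
filtered0, filtered1 = mat0[keep, :], mat1[keep, :]
```

Applying the same row indices to both matrices keeps the gene order aligned between the two time points.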
RNAForecaster.filterByZeroProp — Method
filterByZeroProp(t0Counts::Matrix{Float32}, t1Counts::Matrix{Float32}, zeroProp::Float32)
Filter by zero proportion for both genes and cells. Very high sparsity prevents the neural network from achieving a stable solution.
Required Arguments
- t0Counts - Counts matrix for time 1. Should be genes x cells and Float32.
- t1Counts - Counts matrix for time 2. Should be genes x cells and Float32.
Keyword Arguments
- zeroProp - proportion of zeroes allowed for a gene or a cell. 0.98f0 by default
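A sparsity filter of this kind can be sketched in plain Julia. The helpers `zeroFraction` and `sparsityFilter` below are hypothetical, and the rule used here (keep a gene or cell only if its zero fraction is at most `zeroProp` in both matrices, filtering genes first) is an assumption about how the two matrices are combined.

```julia
# Fraction of exact zeros per row (dims = 2) or per column (dims = 1)
zeroFraction(m::Matrix{Float32}; dims) = vec(sum(m .== 0; dims = dims)) ./ size(m, dims)

# Hypothetical helper: drop genes, then cells, that are too sparse in either matrix
function sparsityFilter(t0::Matrix{Float32}, t1::Matrix{Float32}, zeroProp::Float32 = 0.98f0)
    keepGenes = (zeroFraction(t0; dims = 2) .<= zeroProp) .& (zeroFraction(t1; dims = 2) .<= zeroProp)
    t0g, t1g = t0[keepGenes, :], t1[keepGenes, :]
    keepCells = (zeroFraction(t0g; dims = 1) .<= zeroProp) .& (zeroFraction(t1g; dims = 1) .<= zeroProp)
    return t0g[:, keepCells], t1g[:, keepCells]
end

t0 = ones(Float32, 4, 5)    # genes x cells
t1 = ones(Float32, 4, 5)
t0[2, :] .= 0f0             # gene 2 is all zeros at the first time point
t0[:, 5] .= 0f0             # cell 5 is all zeros at both time points
t1[:, 5] .= 0f0

f0, f1 = sparsityFilter(t0, t1, 0.5f0)
```

In this toy input, gene 2 and cell 5 are removed, leaving two aligned 3 x 4 matrices.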
RNAForecaster.findSigRegulation — Method
findSigRegulation(perturbData, geneNames::Vector{String}; pvalCut::Float64 = 0.05)
Function to get a sorted data frame of whether there exists a statistically significant difference between predicted gene expression before and after perturbation.
Required Arguments
- perturbData - results from perturbEffectPredictions function
- geneNames - vector of gene names in the order of the input expression data
Keyword Arguments
- pvalCut - p-value threshold. Default is 0.05
RNAForecaster.genePerturbExpressionChanges — Method
genePerturbExpressionChanges(perturbData, geneNames::Vector{String}, perturbGene::String; genesperturbd::Vector{String} = geneNames)
Function to get a sorted data frame of the predicted effect of a gene perturbation on all other genes.
Required Arguments
- perturbData - results from perturbEffectPredictions function
- geneNames - vector of gene names in the order of the input expression data
- perturbGene - name of the perturbed gene whose predicted effect on expression is queried
Optional Arguments
- genesPerturbed - if only a subset of genes was perturbed, the names of the perturbed genes, in order, must be supplied
RNAForecaster.geneResponseToPerturb — Method
geneResponseToPerturb(perturbData, geneNames::Vector{String}, geneOfInterest::String; genesPerturbed::Vector{String} = geneNames)
Function to get a sorted data frame of the predicted effect of all other gene perturbations on a particular gene of interest.
Required Arguments
- perturbData - results from perturbEffectPredictions function
- geneNames - vector of gene names in the order of the input expression data
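- geneOfInterest - a gene name to query
Optional Arguments
- genesPerturbed - if only a subset of genes was perturbed, the names of the perturbed genes, in order, must be supplied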
RNAForecaster.loadForecaster — Method
loadForecaster(fileName::String, inputNodes::Int, hiddenLayerNodes::Int)
Recreates a previously saved neural network.
Note: the DiffEqFlux and DifferentialEquations packages must be loaded before the network is loaded (this normally happens when RNAForecaster.jl is loaded). Otherwise the network will not work, even if the required packages are loaded afterwards.
Required Arguments
- fileName - file name where the parameters are saved
- inputNodes - number of input nodes in the network. Should be the same as the number
of genes in the data the network was trained on
- hiddenLayerNodes - number of hidden layer nodes in the network
RNAForecaster.mostTimeVariableGenes — Method
mostTimeVariableGenes(cellFutures::AbstractArray{Float32}, geneNames::Vector{String}; statType = "mean")
For each cell, takes the predicted expression levels of each gene over time and computes the variance across the predicted time points. Each gene's variances are then summarized across cells by their mean or median.
Outputs a sorted DataFrame containing gene names and the variances over predicted time.
Required Arguments
- cellFutures - a 3D tensor of gene expression over time; the output from predictCellFutures
- geneNames - a vector of gene names corresponding to the order of the genes in cellFutures
Optional Arguments
- statType - How to summarize the gene variances. Valid options are "mean" or "median"
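The computation described above can be sketched in plain Julia. `timeVariability` is a hypothetical helper that returns plain vectors rather than a DataFrame, and it assumes the tensor is laid out as genes x cells x time steps.

```julia
using Statistics

# Hypothetical helper mirroring the description: variance over time for each
# gene and cell, then a mean or median across cells, sorted from most to least
# variable. Assumes cellFutures is laid out as genes x cells x time steps.
function timeVariability(cellFutures::AbstractArray{Float32,3}, geneNames::Vector{String};
                         statType::String = "mean")
    varOverTime = dropdims(var(cellFutures; dims = 3); dims = 3)   # genes x cells
    summarize = statType == "median" ? median : mean
    geneScores = [summarize(varOverTime[g, :]) for g in 1:size(varOverTime, 1)]
    order = sortperm(geneScores; rev = true)
    return geneNames[order], geneScores[order]
end

# Toy tensor: gene "B" alternates between 1 and 0 over time, "A" and "C" stay constant
futures = zeros(Float32, 3, 4, 6)
futures[2, :, 1:2:end] .= 1f0

rankedGenes, rankedScores = timeVariability(futures, ["A", "B", "C"])
```

Gene "B" is ranked first because it is the only gene whose predicted expression changes over time.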
RNAForecaster.perturbEffectPredictions — Method
perturbEffectPredictions(trainedNetwork, splicedData::Matrix{Float32}, nCells::Int; perturbGenes::Vector{String} = Vector{String}(undef, 0), geneNames::Vector{String} = Vector{String}(undef,0), seed::Int=123)
Based on spliced/unspliced counts, predict the immediate transcriptomic effect of any or all single-gene perturbations. Outputs a tuple containing the cells used for prediction, the expression predictions, the gene-wise differences, and the cell-wise Euclidean distances for each perturbation.
This function is capable of running on multiple parallel processes using Distributed.jl. Call addprocs(n) before running the function to add parallel workers, where n is the number of additional processes desired.
Required Arguments
- trainedNetwork - trained neuralODE from trainRNAForecaster
- splicedData - log-normalized spliced counts matrix. Must be in Float32 format
- nCells - how many cells from the data should be used for prediction of perturb effect.
Higher values will increase computational time required.
Optional Arguments
- perturbGenes - list of genes to simulate a perturbation of. By default all genes are used
- geneNames - if providing a subset of the genes to perturb, a vector of gene names to
match against, in the order of splicedData
- perturbLevels - list of perturbation levels to use for each perturbed gene. By default
all genes are set to zero, simulating a KO.
- seed - Random seed for reproducibility on the cells chosen for prediction
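The gene-wise differences and cell-wise Euclidean distances mentioned above can be illustrated with a toy predictor. In this sketch, `predictStep` and its weight matrix `W` are hypothetical stand-ins for a trained network, and the knockout is simulated by setting one gene to zero before predicting.

```julia
# Hypothetical one-step predictor standing in for a trained network
W = Float32[0.5 0.5 0.0;
            0.0 1.0 0.0;
            0.0 0.0 1.0]
predictStep(x) = W * x

cells = Float32[1 2;
                3 4;
                5 6]                  # 3 genes x 2 cells

# Simulate a knockout of gene 2 by setting its expression to zero
perturbed = copy(cells)
perturbed[2, :] .= 0f0

basePred    = predictStep(cells)
perturbPred = predictStep(perturbed)

# Gene-wise differences between perturbed and unperturbed predictions
geneDiffs = perturbPred .- basePred

# Euclidean distance between the two predictions, one value per cell
cellDists = vec(sqrt.(sum(abs2, geneDiffs; dims = 1)))
```

Gene 1 depends on gene 2 in this toy model, so its prediction drops after the knockout, while gene 3 is unaffected.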
RNAForecaster.predictCellFutures — Method
predictCellFutures(trainedNetwork, expressionData::Matrix{Float32}, tSteps::Int; perturbGenes::Vector{String} = Vector{String}(undef,0), geneNames::Vector{String} = Vector{String}(undef,0), perturbationLevels::Vector{Float32} = Vector{Float32}(undef,0), enforceMaxPred::Bool = true, maxPrediction::Float32 = 2*maximum(expressionData))
Function to make future expression predictions using a trained neural ODE. Outputs a 3D tensor containing a predicted expression counts matrix for the cells at each time step.
Required Arguments
- trainedNetwork - the trained neural ODE, from the trainRNAForecaster function
- expressionData - the initial expression states that should be used to make predictions from
- tSteps - how many future time steps should be predicted. (Error propagates
with each prediction, so predictions will eventually become highly inaccurate at high numbers of time steps)
Keyword Arguments
- perturbGenes - a vector of gene names that will have their values set to a constant 'perturbed' level.
- geneNames - a vector of gene names in the order of the rows of the expressionData.
Used only when simulating perturbations.
- perturbationLevels - a vector of Float32, corresponding to the level each perturbed
gene's expression should be set at.
- enforceMaxPred - should a maximum allowed prediction be enforced? This is used
to represent prior knowledge about what sort of expression values are remotely reasonable predictions.
- maxPrediction - if enforcing a maximum prediction, what should the value be?
2 times the maximum of the input expression data by default (in log space).
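The interplay of the keywords above can be sketched as an iterative rollout. `growStep` and `rollForward` are hypothetical and not part of RNAForecaster. For brevity, perturbed genes are identified by row index here rather than by name, and the order in which the cap and the perturbation are applied is an assumption.

```julia
# Hypothetical one-step predictor standing in for a trained network
growStep(x) = 2f0 .* x

# Roll the predictor forward tSteps times. Perturbed genes are held at their
# fixed level, and every prediction is capped at maxPrediction.
function rollForward(step, x0::Matrix{Float32}, tSteps::Int;
                     perturbRows::Vector{Int} = Int[],
                     perturbationLevels::Vector{Float32} = Float32[],
                     maxPrediction::Float32 = 2 * maximum(x0))
    function applyPerturbation!(m)
        for (row, level) in zip(perturbRows, perturbationLevels)
            m[row, :] .= level
        end
        return m
    end
    out = Array{Float32}(undef, size(x0, 1), size(x0, 2), tSteps)
    x = applyPerturbation!(copy(x0))
    for t in 1:tSteps
        x = applyPerturbation!(min.(step(x), maxPrediction))   # cap, then re-apply perturbation
        out[:, :, t] = x
    end
    return out
end

x0 = Float32[1 1; 2 2; 3 3]     # 3 genes x 2 cells
futures = rollForward(growStep, x0, 4; perturbRows = [1], perturbationLevels = [0f0])
```

Without the cap, the doubling predictor would grow without bound; with it, genes 2 and 3 level off at twice the input maximum, and the perturbed gene 1 stays at zero throughout.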
RNAForecaster.saveForecaster — Method
saveForecaster(trainedModel, fileName::String)
Saves the parameters from the neural network after training.
Required Arguments
- trainedModel - the trained neural network from trainRNAForecaster. Pass only the network;
do not include the loss results
- fileName - file name to save the parameters to. Must end in .jld2
RNAForecaster.totalPerturbImpact — Method
totalPerturbImpact(perturbData, geneNames::Vector{String})
Function to yield a sorted data frame of the size of the predicted effect of a perturbation on the cellular transcriptome. Intended to serve as a measure of more or less impactful gene perturbations.
Required Arguments
- perturbData - results from perturbEffectPredictions function
- geneNames - vector of gene names in the order of the input expression data.
Should only include perturbed genes
RNAForecaster.trainRNAForecaster — Method
trainRNAForecaster(expressionDataT0::Matrix{Float32}, expressionDataT1::Matrix{Float32}; trainingProp::Float64 = 0.8, hiddenLayerNodes::Int = 2*size(expressionDataT0)[1], shuffleData::Bool = true, seed::Int = 123, learningRate::Float64 = 0.005, nEpochs::Int = 10, batchsize::Int = 100, checkStability::Bool = false, iterToCheck::Int = 50, stabilityThreshold::Float32 = 2*maximum(expressionDataT0), stabilityChecksBeforeFail::Int = 5, useGPU::Bool = false)
Function to train RNAForecaster based on expression data. The main input is two matrices representing expression data from two different time points in the same cells. Currently these can be derived from either splicing or metabolic labeling. Each should be log-normalized and have genes as rows and cells as columns.
Required Arguments
- expressionDataT0 - Float32 Matrix of log-normalized expression counts in the format of genes x cells
- expressionDataT1 - Float32 Matrix of log-normalized expression counts in the format
of genes x cells from a time after expressionDataT0
Keyword Arguments
- trainingProp - proportion of the data to use for training the model, the rest will be
used for a validation set. If you don't want a validation set, this value can be set to 1.0
- hiddenLayerNodes - number of nodes in the hidden layer of the neural network
- shuffleData - should the cells be randomly shuffled before training
- seed - random seed
- learningRate - learning rate for the neural network during training
- nEpochs - how many times the neural network should be trained on the data.
Additional epochs generally yield small gains in performance; this can be lowered to speed up the training process
- batchsize - batch size for training
- checkStability - should the stability of the network's future time predictions be checked,
retraining the network if unstable?
- iterToCheck - when checking stability, how many future time steps should be predicted?
- stabilityThreshold - when checking stability, what is the maximum gene variance allowable across predictions?
- stabilityChecksBeforeFail - when checking stability, how many times should the network
be allowed to retrain before an error is thrown? Used to prevent an infinite loop.
- useGPU - use a GPU to train the neural network? Highly recommended for large data sets, if available
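A minimal usage sketch follows, built on toy data. The keyword names come from the signature above. Two things are assumptions: that the returned value holds the network as its first element (the saveForecaster description implies the loss results are returned alongside it), and that random toy matrices are an acceptable stand-in for real log-normalized counts.

```julia
using RNAForecaster

nGenes, nCells = 10, 200

# Toy stand-ins for log-normalized counts (genes x cells, Float32)
t0 = log1p.(10f0 .* rand(Float32, nGenes, nCells))
t1 = log1p.(0.5f0 .* t0)

# nEpochs is lowered from its default of 10 to keep the toy run short
trained = trainRNAForecaster(t0, t1; nEpochs = 2, batchsize = 50, seed = 123)

# Assumption: the network is the first element of the returned value
network = trained[1]

# Predict 5 future time steps for every cell in t0
futures = predictCellFutures(network, t0, 5)
```

For real data, the defaults for `hiddenLayerNodes` and `stabilityThreshold` are derived from `expressionDataT0`, so they usually do not need to be set by hand.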