
Train Neural ODE Network

This example shows how to train an augmented neural ordinary differential equation (ODE) network.

A neural ODE [1] is a deep learning operation that returns the solution of an ODE. In particular, given an input, a neural ODE operation outputs the numerical solution of the ODE y′=f(t,y,θ) for the time horizon (t0,t1) and the initial condition y(t0)=y0, where t and y denote the ODE function inputs and θ is a set of learnable parameters. Typically, the initial condition y0 is either the network input or, as in the case of this example, the output of another deep learning operation.
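To make the definition concrete, here is a minimal sketch in base MATLAB that uses ode45 in place of a deep learning ODE solver. The ODE function f(t,y,θ)=θy is a hypothetical stand-in, chosen because its exact solution y(t1)=y0·exp(θ(t1−t0)) is known, so the numerical output at the end of the time horizon can be checked.

```matlab
% Hypothetical ODE function f(t,y,theta) = theta*y (t is unused).
theta = -2;                     % stand-in for the learnable parameters
f = @(t,y) theta*y;

% Time horizon (t0,t1) and initial condition y(t0) = y0.
tspan = [0 0.1];
y0 = [1; 2];

% Solve the ODE with tight tolerances and keep only the solution at t1,
% which is what the neural ODE operation returns.
opts = odeset('RelTol',1e-8,'AbsTol',1e-10);
[~,Y] = ode45(f,tspan,y0,opts);
y1 = Y(end,:)';

% Compare with the exact solution y0*exp(theta*(t1 - t0)).
yExact = y0*exp(theta*(tspan(2) - tspan(1)));
maxError = max(abs(y1 - yExact))
```

In a neural ODE, θ is learned during training instead of being fixed, and the solver must also provide gradients with respect to θ.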

An augmented neural ODE [2] operation improves upon a standard neural ODE by augmenting the input data with extra channels and then discarding the augmentation after the neural ODE operation. Empirically, augmented neural ODEs are more stable, generalize better, and have a lower computational cost than neural ODEs.

This example trains a simple convolutional neural network with an augmented neural ODE operation to classify images of digits.

The ODE function can be a collection of deep learning operations. In this example, the model uses a convolution-tanh block as the ODE function.

Load Training Data

Load the training images and labels using the digitTrain4DArrayData function.

[XTrain,TTrain] = digitTrain4DArrayData;

View the number of classes of the training data.

classNames = categories(TTrain);
numClasses = numel(classNames)
numClasses = 10

View some images from the training data.

numObservations = size(XTrain,4);
idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

Define Deep Learning Model

Define the following network, which classifies images.

  • A convolution-ReLU block with 8 3-by-3 filters with a stride of 2

  • An augmentation step that concatenates an array of zeros to the input such that the number of channels is doubled

  • A neural ODE operation with ODE function containing a convolution-tanh block with 16 3-by-3 filters

  • For classification output, a fully connect operation of size 10 (the number of classes) and a softmax operation

A neural ODE operation outputs the solution of a specified ODE function. For this example, specify a convolution-tanh block as the ODE function.

That is, specify the ODE function given by y′=f(t,y,θ), where f denotes the convolution-tanh operation, y is the input data, and θ contains the learnable parameters for the convolution operation. In this case, the variable t is unused.
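The following base-MATLAB sketch illustrates a convolution-tanh ODE function on a single-channel 14-by-14 image. It is an illustration only: conv2 with the 'same' option stands in for the deep learning convolution (conv2 flips the kernel, unlike the cross-correlation typically used in deep learning), the kernel and bias are arbitrary placeholder values rather than learned parameters, and ode45 solves the flattened state. Because 'same' padding preserves the spatial size, the ODE function output has the same size as its input, as an ODE requires.

```matlab
% Placeholder (not learned) 3-by-3 kernel and bias for a single channel.
rng(0)
w = 0.1*randn(3,3);
b = 0.01;

% Initial condition: a small 14-by-14 single-channel "image".
n = 14;
y0 = rand(n,n);

% ODE function f(t,y,theta): convolution followed by tanh; t is unused.
% ode45 works with column vectors, so reshape the state on the way in and out.
f = @(t,y) reshape(tanh(conv2(reshape(y,n,n),w,'same') + b),[],1);

% Solve over the time horizon and keep only the solution at t1.
[~,Y] = ode45(f,[0 0.1],y0(:));
y1 = reshape(Y(end,:),n,n);
size(y1)
```

Because tanh is bounded by 1 in magnitude, the state can change by at most 0.1 per element over the interval [0 0.1], which is one reason a short integration interval keeps this block well behaved.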

Define and Initialize Model Parameters

Define the learnable parameters for each of the operations and include them in a structure. Use the format parameters.OperationName.ParameterName, where parameters is the structure, OperationName is the name of the operation (for example, "conv1"), and ParameterName is the name of the parameter (for example, "Weights"). Initialize the learnable layer weights and biases using the initializeGlorot and initializeZeros example functions, respectively. The initialization example functions are attached to this example as supporting files. To access these functions, open this example as a live script. For more information about initializing learnable parameters for model functions, see Initialize Learnable Parameters for Model Function.
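If the supporting files are not at hand, the following is a minimal sketch of Glorot (Xavier) uniform initialization, assuming the common definition: weights drawn uniformly from [−bound, bound] with bound = sqrt(6/(numIn + numOut)). The name glorotSketch is hypothetical, and it returns a plain single array rather than a dlarray, so it may differ from the attached initializeGlorot function.

```matlab
% Hypothetical stand-in for initializeGlorot: uniform on [-bound, bound].
glorotSketch = @(sz,numOut,numIn) ...
    sqrt(6/(numIn + numOut)) * (2*rand(sz,'single') - 1);

% Example: 3-by-3 filters, 1 input channel, 8 filters.
% Fan-in counts input channels; fan-out counts output channels (filters).
filterSize = [3 3];
numChannels = 1;
numFilters = 8;
sz = [filterSize numChannels numFilters];
numIn = prod(filterSize)*numChannels;   % 9
numOut = prod(filterSize)*numFilters;   % 72
W = glorotSketch(sz,numOut,numIn);

bound = sqrt(6/(numIn + numOut));
size(W)
```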

Initialize the parameters structure.

parameters = struct;

Initialize the parameters for the first convolutional layer. Specify 8 3-by-3 filters. If you change these dimensions, then you must manually calculate the input size of the fully connect operation for its Glorot weights initialization.

filterSize = [3 3];
numFilters = 8;

numChannels = size(XTrain,3);
sz = [filterSize numChannels numFilters];

% Glorot fan-out and fan-in: filter area times output and input channels.
numOut = prod(filterSize) * numFilters;
numIn = prod(filterSize) * numChannels;

parameters.conv1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.conv1.Bias = initializeZeros([numFilters 1]);

Initialize the parameters for the convolution operation used in the neural ODE function. Because the augmentation step augments the input data with an array of zeros, the number of input channels is given by numFilters + numExtraChannels, where numExtraChannels is the number of channels in the augmentation. Because the output of an ODE function must have the same size as its input, the convolution operation in the neural ODE must also have numChannels + numExtraChannels filters, where numChannels is the desired number of output channels. After the neural ODE operation, the model discards the channels corresponding to the augmentation, leaving numChannels channels.
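A minimal sketch of the augment-then-discard bookkeeping on a plain numeric array, using sizes that match this example (14-by-14 feature maps, 8 channels) and a hypothetical batch of 2. The elementwise tanh is only a placeholder for the neural ODE operation, whose output has the same size as its input.

```matlab
% Hypothetical feature maps after the first convolution: 14x14x8, batch of 2.
numChannels = 8;
numExtraChannels = 8;
Y = rand(14,14,numChannels,2);

% Augment: append zero channels along the channel (third) dimension.
szAugmented = size(Y);
szAugmented(3) = numExtraChannels;
Y0 = cat(3,Y,zeros(szAugmented,'like',Y));

% Placeholder for the neural ODE output: same size as Y0 (16 channels).
Z = tanh(Y0);

% Discard the channels that correspond to the augmentation.
Z(:,:,numChannels+1:end,:) = [];
size(Z)
```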

Specify the same number of filters as the first convolution layer and a matching augmentation size.

numChannels = numFilters;
numExtraChannels = numFilters;

numFiltersAugmented = numChannels + numExtraChannels;
sz = [filterSize numFiltersAugmented numFiltersAugmented];

numOut = prod(filterSize) * numFiltersAugmented;
numIn = prod(filterSize) * numFiltersAugmented;

parameters.neuralode.Weights = initializeGlorot(sz,numOut,numIn);
parameters.neuralode.Bias = initializeZeros([numFiltersAugmented 1]);

Initialize the parameters for the fully connect operation. To initialize the weights of the fully connect operation using the Glorot initializer, first calculate the number of input elements to the operation.

For each operation in the model that changes the size of the data flowing through it, consider the output sizes when you pass 28-by-28 images through the model:

  • The first convolution has 8 filters with "same" padding and a stride of 2. This operation outputs 14-by-14 images with 8 channels.

  • The model then augments the data with an 8-channel array of zeros. This operation outputs 14-by-14 images with 16 channels.

  • The neural ODE operation has a convolution operation with 16 filters and "same" padding. This operation outputs 14-by-14 images with 16 channels.

  • The model then discards the channels corresponding to the augmentation. This operation outputs 14-by-14 images with 8 channels.

This means that the number of input elements to the fully connect operation is 14*14*8=1568.
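The same size bookkeeping can be checked with a few lines of arithmetic. This sketch assumes that "same" padding with stride s produces ceil(inputSize/s) outputs along each spatial dimension, which is the usual definition of "same" padding for strided convolutions.

```matlab
imageSize = [28 28];
stride = 2;
numFilters = 8;          % first convolution
numExtraChannels = 8;    % augmentation

% "same" padding with stride 2: ceil(28/2) = 14 in each spatial dimension.
convOut = ceil(imageSize/stride);

% Augmentation doubles the channels, the neural ODE keeps the size,
% and discarding the augmentation restores 8 channels.
numAugmented = numFilters + numExtraChannels;         % 16
numAfterDiscard = numAugmented - numExtraChannels;    % 8

fcInputSize = prod(convOut)*numAfterDiscard
```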

sz =  [14 14];
inputSize = prod(sz)*numChannels;
outputSize = numClasses;

sz = [outputSize inputSize];
numOut = outputSize;
numIn = inputSize;

parameters.fc1.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fc1.Bias = initializeZeros([outputSize 1]);

View the structure of parameters.

parameters
parameters = struct with fields:
        conv1: [1×1 struct]
    neuralode: [1×1 struct]
          fc1: [1×1 struct]

View the parameters for the neural ODE operation.

parameters.neuralode
ans = struct with fields:
    Weights: [3×3×16×16 dlarray]
       Bias: [16×1 dlarray]

Define Model Hyperparameters

Define the hyperparameters for the operations and include them in a structure. Use the format hyperparameters.OperationName.ParameterName, where hyperparameters is the structure, OperationName is the name of the operation (for example, "neuralode"), and ParameterName is the name of the hyperparameter (for example, "tspan").

Initialize the hyperparameters structure.

hyperparameters = struct;

For the neural ODE, specify an interval of integration of [0 0.1].

hyperparameters.neuralode.tspan = [0 0.1];

Define Neural ODE Function

Create the function odeModel, listed in the ODE Function section of the example, which takes as input the time input (unused), the ODE state, and the ODE function parameters. The function applies a convolution operation followed by a tanh operation to the input data using the weights and biases given by the parameters.

Define Model Function

Create the function model, listed in the Model Function section of the example, which computes the outputs of the deep learning model. The function model takes as input the model parameters, the input data, and the model hyperparameters, and outputs the predictions for the labels.

Define Model Gradients Function

Create the function modelGradients, listed in the Model Gradients Function section of the example, which takes as input the model parameters, a mini-batch of input data with corresponding targets containing the labels, and the model hyperparameters, and returns the gradients of the loss with respect to the learnable parameters and the corresponding loss.

Specify Training Options

Specify the training options. Train with a mini-batch size of 64 for 30 epochs.

miniBatchSize = 64;
numEpochs = 30;
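To get a feel for the training budget, the following sketch counts iterations. It assumes 5000 training images (check size(XTrain,4) in your release) and that partial mini-batches are discarded, as described in the Train Model section.

```matlab
numObservations = 5000;   % assumption: equals size(XTrain,4)
miniBatchSize = 64;
numEpochs = 30;

% With partial mini-batches discarded, each epoch uses only full batches.
numIterationsPerEpoch = floor(numObservations/miniBatchSize)
numIterations = numIterationsPerEpoch*numEpochs
```

Under these assumptions, each epoch skips 8 images (5000 − 78×64); shuffling at the start of every epoch means different images are skipped each time.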

Train Model

Train the model using a custom training loop.

Create a minibatchqueue object that processes and manages mini-batches of images during training. To create a minibatchqueue object, first create a datastore that returns the images and labels by creating array datastores and then combining them.

dsXTrain = arrayDatastore(XTrain,IterationDimension=4);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsXTrain,dsTTrain);

Create the mini-batch queue. For each mini-batch:

  • Use the custom mini-batch preprocessing function preprocessMiniBatch, defined in the Mini-Batch Preprocessing Function section of the example, to convert the labels to one-hot encoded variables.

  • Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single.

  • Discard partial mini-batches.

  • Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Support by Release (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" "CB"], ...
    PartialMiniBatch="discard");

Initialize the moving average of the parameter gradients and the element-wise squares of the gradients used by the Adam optimizer.

trailingAvg = [];
trailingAvgSq = [];
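For reference, the following base-MATLAB sketch applies one step of the standard Adam update with bias correction to a plain numeric array. The hyperparameter values are assumed to match the adamupdate defaults (learning rate 0.001, gradient decay 0.9, squared-gradient decay 0.999, epsilon 1e-8); adamupdate itself also handles structures of dlarray objects and may differ in small details, so treat this as an illustration of the role of trailingAvg and trailingAvgSq, not as its implementation.

```matlab
% Assumed hyperparameters (adamupdate defaults).
learnRate = 0.001;
gradDecay = 0.9;
sqGradDecay = 0.999;
epsilon = 1e-8;

% Toy parameters, gradients, and zero moving averages (first iteration).
theta = [1; -1];
grad = [0.5; -2];
avg = zeros(size(theta));
avgSq = zeros(size(theta));
iteration = 1;

% Update the moving averages of the gradients and squared gradients.
avg = gradDecay*avg + (1 - gradDecay)*grad;
avgSq = sqGradDecay*avgSq + (1 - sqGradDecay)*grad.^2;

% Bias-correct the averages and take the step.
avgHat = avg/(1 - gradDecay^iteration);
avgSqHat = avgSq/(1 - sqGradDecay^iteration);
theta = theta - learnRate*avgHat./(sqrt(avgSqHat) + epsilon)
```

On the first iteration the bias correction makes the step close to learnRate times the sign of each gradient, regardless of the gradient magnitude.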

Initialize the training plot.

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Train the model using a custom training loop. For each epoch, shuffle the data. For each mini-batch:

  • Evaluate the model gradients using the dlfeval and modelGradients functions.

  • Update the network parameters using the adamupdate function.

  • Update the training progress plot.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data (this also resets the mini-batch queue).
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)

        iteration = iteration + 1;

        % Read a mini-batch of data.
        [dlX,dlT] = next(mbq);

        % Evaluate the model gradients and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss] = dlfeval(@modelGradients,parameters,dlX,dlT,hyperparameters);

        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(gather(extractdata(loss)));
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end


Test Model

Test the classification accuracy of the model by comparing the predictions on a held-out test set with the true labels.

Load the test data.

[XTest,TTest] = digitTest4DArrayData;

After training, making predictions on new data does not require the labels. Create a minibatchqueue object containing only the predictors of the test data:

  • Set the number of outputs of the mini-batch queue to 1.

  • Specify the same mini-batch size used for training.

  • Preprocess the predictors using the preprocessPredictors function, listed in the Predictors Preprocessing Function section of the example.

  • For the single output of the datastore, specify the mini-batch format "SSCB" (spatial, spatial, channel, batch).

dsTest = arrayDatastore(XTest,IterationDimension=4);

mbqTest = minibatchqueue(dsTest,1, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFormat="SSCB", ...
    MiniBatchFcn=@preprocessPredictors);

Loop over the mini-batches and classify the images using the modelPredictions function, listed in the Model Predictions Function section of the example.

YPred = modelPredictions(parameters,hyperparameters,mbqTest,classNames);

Calculate the classification accuracy.

accuracy = mean(YPred == TTest)

Visualize the predictions in a confusion matrix.

figure
confusionchart(TTest,YPred)


Model Function

The function model takes as input the model parameters, the input data dlX, and the model hyperparameters, and outputs the predictions for the labels.

This diagram outlines the model structure.

For the neural ODE operation, use the dlode45 function and specify the odeModel function, listed in the ODE Function section of the example. Set the absolute and relative tolerances using the AbsoluteTolerance and RelativeTolerance name-value arguments, respectively. To calculate the gradients by solving the associated adjoint ODE system, set the GradientMode option to "adjoint".

function dlY = model(parameters,dlX,hyperparameters)

% Convolution, ReLU.
weights = parameters.conv1.Weights;
bias = parameters.conv1.Bias;
dlY = dlconv(dlX,weights,bias,Padding="same",Stride=2);

dlY = relu(dlY);

% Augment.
weights = parameters.neuralode.Weights;

numChannels = size(dlY,3);
szAugmented = size(dlY);
szAugmented(3) = size(weights,3) - numChannels;

dlY0 = cat(3, dlY, zeros(szAugmented,"like",dlY));

% Neural ODE.
tspan = hyperparameters.neuralode.tspan;
dlY = dlode45(@odeModel,tspan,dlY0,parameters.neuralode, ...
    GradientMode="adjoint", ...
    AbsoluteTolerance=1e-3, ...
    RelativeTolerance=1e-4);

% Discard augmentation.
dlY(:,:,numChannels+1:end,:) = [];

% Fully connect, softmax.
weights = parameters.fc1.Weights;
bias = parameters.fc1.Bias;
dlY = fullyconnect(dlY,weights,bias);

dlY = softmax(dlY);

end


ODE Function

The neural ODE operation consists of a convolution operation followed by a tanh operation.

The ODE function odeModel takes as input the function inputs t (unused) and y and the ODE function parameters p containing the convolution weights and biases, and returns the output of the convolution-tanh block operation.

function z = odeModel(t,y,p)

weights = p.Weights;
bias = p.Bias;

z = dlconv(y,weights,bias,Padding="same");
z = tanh(z);

end


Model Gradients Function

The modelGradients function takes as input the model parameters, a mini-batch of input data dlX with corresponding targets dlT, and model hyperparameters, and returns the gradients of the loss with respect to the learnable parameters and the corresponding loss. To compute the gradients using automatic differentiation, use the dlgradient function.
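A minimal sketch of what the softmax and loss compute, assuming the default crossentropy behavior for single-label classification: sum −T·log(Y) over classes and observations and divide by the number of observations. Plain arrays stand in for the formatted dlarray objects, with classes along the first dimension and observations along the second, as in the "CB" format; the scores and targets are made-up values.

```matlab
% Toy network outputs before softmax: 3 classes (rows), 2 observations (columns).
scores = [2 0; 1 0; 0 3];

% Softmax over the class dimension (subtract the max for numerical stability).
expScores = exp(scores - max(scores,[],1));
Y = expScores./sum(expScores,1);

% One-hot targets: observation 1 is class 1, observation 2 is class 3.
T = [1 0; 0 0; 0 1];

% Cross-entropy averaged over the observations.
numObs = size(Y,2);
loss = -sum(T.*log(Y),'all')/numObs
```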

function [gradients,loss] = modelGradients(parameters,dlX,dlT,hyperparameters)

dlY = model(parameters,dlX,hyperparameters);

loss = crossentropy(dlY,dlT);

gradients = dlgradient(loss,parameters);

end


Model Predictions Function

The modelPredictions function takes as input the model parameters, model hyperparameters, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted classes with the highest score.
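The decoding step amounts to picking, for each observation, the class with the highest score. A minimal base-MATLAB sketch with hypothetical class names and scores (classes along the first dimension, observations along the second):

```matlab
% Hypothetical class names and scores: 3 classes (rows), 4 observations (columns).
names = ["0" "1" "2"];
scores = [0.7 0.1 0.2 0.3
          0.2 0.8 0.1 0.3
          0.1 0.1 0.7 0.4];

% Row index of the highest score in each column, mapped to a categorical label.
[~,idx] = max(scores,[],1);
YPred = categorical(names(idx),names)'
```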

function predictions = modelPredictions(parameters,hyperparameters,mbq,classNames)

predictions = [];

while hasdata(mbq)
    dlX = next(mbq);
    dlYPred = model(parameters,dlX,hyperparameters);
    YPred = onehotdecode(dlYPred,classNames,1)';
    predictions = [predictions; YPred];
end

end

Mini-Batch Preprocessing Function

The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:

  1. Preprocess the images using the preprocessPredictors function.

  2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.

  3. One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
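The one-hot step can be sketched in base MATLAB by comparing each label's category index with the list of class indices. This assumes the labels form a categorical row vector whose categories are the class names; onehotencode also handles cases such as undefined labels, which this sketch ignores.

```matlab
% Hypothetical labels for a mini-batch of 4 observations (row vector).
Y = categorical(["2" "0" "1" "2"],["0" "1" "2"]);
K = numel(categories(Y));

% Compare each category index (1-by-N) with the class indices (K-by-1);
% implicit expansion gives a K-by-N array with a single 1 per column.
Yencoded = double((1:K)' == double(Y))
```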

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{:});

% One-hot encode labels.
Y = onehotencode(Y,1);

end


Predictors Preprocessing Function

The preprocessPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating the data into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image to use as a singleton channel dimension.
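A quick check of the concatenation behavior, using random 28-by-28 arrays as hypothetical grayscale images:

```matlab
% A cell array of three 28-by-28 grayscale images.
XCell = {rand(28,28), rand(28,28), rand(28,28)};

% Concatenating along dimension 4 leaves dimension 3 as a singleton channel.
X = cat(4,XCell{:});
size(X)
```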

function X = preprocessPredictors(XCell)

X = cat(4,XCell{:});

end



References

  1. Chen, Ricky T. Q., Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. “Neural Ordinary Differential Equations.” Preprint, submitted June 19, 2018.

  2. Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. “Augmented Neural ODEs.” Preprint, submitted October 26, 2019.
