Main Content

Custom Training Loops

Customize deep learning training loops and loss functions

If the trainingOptions function does not provide the training options that you need for your task, or custom output layers do not support the loss functions that you need, then you can define a custom training loop. For models that layer graphs do not support, you can define a custom model as a function. To learn more, see Define Custom Training Loops, Loss Functions, and Networks.

Funktionen

alle erweitern

dlnetworkDeep learning network for custom training loops (Seit R2019b)
resetStateReset state parameters of neural network
plotPlot neural network architecture
addInputLayerAdd input layer to network (Seit R2022b)
addLayersAdd layers to layer graph or network
removeLayersRemove layers from layer graph or network
connectLayersConnect layers in layer graph or network
disconnectLayersDisconnect layers in layer graph or network
replaceLayerReplace layer in layer graph or network
summaryPrint network summary (Seit R2022b)
initializeInitialize learnable and state parameters of a dlnetwork (Seit R2021a)
networkDataLayoutDeep learning network data layout for learnable parameter initialization (Seit R2022b)
layerGraphGraph of network layers for deep learning
setL2FactorSet L2 regularization factor of layer learnable parameter
getL2FactorGet L2 regularization factor of layer learnable parameter
setLearnRateFactorSet learn rate factor of layer learnable parameter
getLearnRateFactorGet learn rate factor of layer learnable parameter
forwardCompute deep learning network output for training (Seit R2019b)
predictCompute deep learning network output for inference (Seit R2019b)
adamupdateUpdate parameters using adaptive moment estimation (Adam) (Seit R2019b)
rmspropupdate Update parameters using root mean squared propagation (RMSProp) (Seit R2019b)
sgdmupdate Update parameters using stochastic gradient descent with momentum (SGDM) (Seit R2019b)
lbfgsupdateUpdate parameters using limited-memory BFGS (L-BFGS) (Seit R2023a)
lbfgsStateState of limited-memory BFGS (L-BFGS) solver (Seit R2023a)
dlupdate Update parameters using custom function (Seit R2019b)
trainingProgressMonitorMonitor and plot training progress for deep learning custom training loops (Seit R2022b)
updateInfoUpdate information values for custom training loops (Seit R2022b)
recordMetricsRecord metric values for custom training loops (Seit R2022b)
groupSubPlotGroup metrics in training plot (Seit R2022b)
padsequencesPad or truncate sequence data to same length (Seit R2021a)
minibatchqueueCreate mini-batches for deep learning (Seit R2020b)
onehotencodeEncode data labels into one-hot vectors (Seit R2020b)
onehotdecodeDecode probability vectors into class labels (Seit R2020b)
nextObtain next mini-batch of data from minibatchqueue (Seit R2020b)
resetReset minibatchqueue to start of data (Seit R2020b)
shuffleShuffle data in minibatchqueue (Seit R2020b)
hasdataDetermine if minibatchqueue can return mini-batch (Seit R2020b)
partitionPartition minibatchqueue (Seit R2020b)
dlarrayDeep learning array for customization (Seit R2019b)
dlgradientCompute gradients for custom training loops using automatic differentiation (Seit R2019b)
dlfevalEvaluate deep learning model for custom training loops (Seit R2019b)
dimsDimension labels of dlarray (Seit R2019b)
finddimFind dimensions with specified label (Seit R2019b)
stripdimsRemove dlarray data format (Seit R2019b)
extractdataExtract data from dlarray (Seit R2019b)
isdlarrayCheck if object is dlarray (Seit R2020b)
crossentropyCross-entropy loss for classification tasks (Seit R2019b)
l1lossL1 loss for regression tasks (Seit R2021b)
l2lossL2 loss for regression tasks (Seit R2021b)
huberHuber loss for regression tasks (Seit R2021a)
mseHalf mean squared error (Seit R2019b)
ctcConnectionist temporal classification (CTC) loss for unaligned sequence classification (Seit R2021a)
dlaccelerateAccelerate deep learning function for custom training loops (Seit R2021a)
AcceleratedFunctionAccelerated deep learning function (Seit R2021a)
clearCacheClear accelerated deep learning function trace cache (Seit R2021a)

Themen

Custom Training Loops

Automatic Differentiation

Deep Learning Function Acceleration

Verwandte Informationen