trainOCR

Train OCR model to recognize text in image

Since R2023a

    Description

    Training

    modelFileName = trainOCR(trainingData,modelName,baseModel,ocrOptions) trains a new OCR model by fine-tuning a pretrained base model using the hyperparameters specified in ocrOptions.

    [modelFileName,info] = trainOCR(___) also returns a structure, info, that contains information on training progress, such as the training root mean squared error (RMSE) and learning rate for each iteration, using the input arguments from the previous syntax. For a list of the returned fields, including the error rates, see the info output argument.

    Resume training

    [modelFileName,info] = trainOCR(trainingData,modelName,checkpoint,ocrOptions) resumes training from an OCR training checkpoint. Use this syntax to improve the accuracy of your OCR model by using additional training data or to perform more training iterations.

    Examples

    This example shows how to train an OCR model that can recognize fourteen-segment characters. The training data contains word samples of fourteen-segment characters.

    Unzip and extract training images.

    datasetZip = 'dseg14.zip';
    evalfiles = unzip(datasetZip);

    The training images were annotated with bounding boxes containing words, and text labels were added to these bounding boxes as an attribute using the Image Labeler app. The labels were exported from the app as a groundTruth object and saved in the dseg14Gtruth.mat file.

    ld = load("dseg14Gtruth.mat");
    gTruth = ld.gTruth;

    Create datastores that contain images, bounding boxes and text labels from the groundTruth object using the ocrTrainingData function with the label and attribute names used during labeling.

    labelName = "Text";
    attributeName = "Word";
    [imds,boxds,txtds] = ocrTrainingData(gTruth,labelName,attributeName);

    Combine the datastores.

    cds = combine(imds,boxds,txtds);

    Split the data into training and validation sets, assigning a fraction of 0.9 of the samples to training and the remainder to validation.

    % Set the random number seed for reproducibility.
    rng(0); 
    
    % Compute number of training and validation samples.
    trainingToValidationRatio = 0.9;
    numSamples = height(ld.gTruth.LabelData);
    numTrainSamples = ceil(trainingToValidationRatio*numSamples);
    
    % Divide the dataset into training and validation.
    indices = randperm(numSamples);
    trainIndices = indices(1:numTrainSamples);
    validationIndices = indices(numTrainSamples+1:end);
    
    cdsTrain = subset(cds, trainIndices);
    cdsValidation = subset(cds, validationIndices);

    Specify training options. Set the gradient decay factor for ADAM optimization to 0.9, and use an initial learning rate of 40e-4. Set the verbose frequency to 160 iterations and the maximum number of epochs for training to 5. Specify the checkpoint path to enable saving checkpoints and specify the validation data to enable validation.

    outputDir = "OCRModel";
    if ~exist(outputDir, "dir")
        mkdir(outputDir);
    end
    
    checkpointsDir = "Checkpoints";
    if ~exist(checkpointsDir, "dir")
        mkdir(checkpointsDir);
    end
    
    ocrOptions = ocrTrainingOptions(GradientDecayFactor=0.9, ...
        InitialLearnRate=40e-4, MaxEpochs=5, VerboseFrequency=160, ...
        CheckpointPath=checkpointsDir, ValidationData=cdsValidation, ...
        OutputLocation=outputDir);

    Train a new OCR model by fine-tuning the pretrained "english" model. The training will take about 3-4 minutes.

    outputModelName = "fourteenSegment";
    baseModel = "english";
    outputModel = trainOCR(cdsTrain, outputModelName, baseModel, ocrOptions);
    *************************************************************************
    Starting OCR training
    
    Model Name: fourteenSegment
    Base Model: english
    
    Preparing training data... 100.00 % completed.
    Preparing validation data... 100.00 % completed.
    
    Character Set: +,-./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ
    
    |======================================================================================================================================|
    | Epoch | Iteration | Time Elapsed |           Training Statistics           |          Validation Statistics          | Base Learning |
    |       |           |  (hh:mm:ss)  |   RMSE   | Character Error | Word Error |   RMSE   | Character Error | Word Error |     Rate      |
    |======================================================================================================================================|
    |   1   |     1     |   00:00:23   |   9.51   |     100.00      |   100.00   |   0.00   |      0.00       |    0.00    |    0.0040     |
    |   1   |    160    |   00:00:34   |   2.43   |      18.46      |   38.12    |   1.69   |      14.19      |   27.78    |    0.0040     |
    |   2   |    320    |   00:00:44   |   1.47   |      10.22      |   21.88    |   0.75   |      6.35       |   11.11    |    0.0040     |
    |   3   |    480    |   00:00:54   |   1.06   |      6.94       |   15.00    |   0.49   |      5.56       |    5.56    |    0.0040     |
    |   4   |    640    |   00:01:04   |   0.86   |      5.30       |   11.72    |   0.68   |      5.56       |    5.56    |    0.0040     |
    |   5   |    800    |   00:01:15   |   0.73   |      4.25       |    9.50    |   0.49   |      6.35       |   11.11    |    0.0040     |
    |   5   |    845    |   00:01:17   |   0.70   |      4.06       |    9.11    |   0.49   |      6.35       |   11.11    |    0.0040     |
    |======================================================================================================================================|
    
    OCR training complete.
    Exit condition: Reached maximum epochs.
    
    Model file name: OCRModel/fourteenSegment.traineddata
    *************************************************************************
    

    Use the trained model to perform OCR on a test image and visualize the results.

    I = imread("DSEG14.png");
    ocrResults = ocr(I,Model=outputModel);
    Iocr = insertObjectAnnotation(I,"rectangle",...
                ocrResults.WordBoundingBoxes,ocrResults.Words,...
                LineWidth=2,FontSize=17);
    imshow(Iocr)

    Input Arguments

    Ground truth data, specified as a datastore that returns a cell array or a table when input to the read function. The cell array or table must contain at least these three columns:

    • 1st column — A cell vector of logical, grayscale, or RGB images.

    • 2nd column — A cell vector in which each cell corresponds to an image and contains an M-by-4 matrix. M is the number of bounding boxes in the image, and each row of the matrix specifies a bounding box in the form [x,y,width,height].

    • 3rd column — A cell vector in which each cell corresponds to an image and contains N strings. N is the number of lines of text in the image, and each line must contain only text without newline characters.
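
    The three-column layout above can be sanity-checked on a single sample before training. The following is a minimal sketch, not part of trainOCR: the sample variable is a hypothetical stand-in built by hand, and the exact nesting of a value returned by read on a real datastore may differ.

```matlab
% Hypothetical stand-in for one row of training data:
% {image, M-by-4 boxes, text}. Values are for illustration only.
sample = {uint8(zeros(40,120)), [5 5 100 30], "HELLO"};

img   = sample{1};
boxes = sample{2};
txt   = string(sample{3});

% Column 1: logical, grayscale (1 channel), or RGB (3 channels) image.
isImageOK = (islogical(img) || isnumeric(img)) && any(size(img,3) == [1 3]);

% Column 2: one [x,y,width,height] row per bounding box.
isBoxOK = ismatrix(boxes) && size(boxes,2) == 4;

% Column 3: text without newline characters.
isTextOK = ~any(contains(txt,newline));

fprintf("image: %d, boxes: %d, text: %d\n",isImageOK,isBoxOK,isTextOK);
```

    A check like this only covers the constraints listed above; it does not confirm that the boxes and text agree with the image content.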

    New model name, specified as a string scalar or character vector. If the output folder already contains a model file with the name specified by the modelName argument, then the trainOCR function overwrites it during training.
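
    Because an existing model file is overwritten, it can help to check for one before training. This is a minimal sketch that assumes the model is written as <OutputLocation>/<modelName>.traineddata, as suggested by the "Model file name" line in the training log of the example; the folder and model names are taken from that example.

```matlab
% Values taken from the example above.
outputDir = "OCRModel";
modelName = "fourteenSegment";

% Assumed file layout: <OutputLocation>/<modelName>.traineddata.
target = fullfile(outputDir, modelName + ".traineddata");

% True if training with this name would replace an existing model file.
willOverwrite = isfile(target);
if willOverwrite
    fprintf("%s exists and would be overwritten by trainOCR.\n",target);
end
```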

    Pretrained base model, specified as a string scalar or character vector. You can specify any of these options:

    • Language models shipped in the Computer Vision Toolbox™, such as "english", "seven-segment", or "japanese".

    • One of the supported languages described in the Model argument of the ocr function. You cannot use quantized models, such as "english-fast", "seven-segment-fast", or "japanese-fast", as base models.

    • Full path to a custom trained model with a .traineddata extension.
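
    The options above can be illustrated with a small guard that filters out quantized models. This is a sketch only: identifying quantized models by the "-fast" suffix is a heuristic inferred from the three names listed above, not a documented rule, and the custom model path is the one produced by the example on this page.

```matlab
% Candidate base model names (illustrative list).
candidates = ["english" "english-fast" "seven-segment" "japanese-fast"];

% Heuristic (an assumption): quantized models end in "-fast".
isQuantized = endsWith(candidates,"-fast");
validBaseModels = candidates(~isQuantized);
disp(validBaseModels)

% A custom trained model is specified by its full path.
customModel = fullfile(pwd,"OCRModel","fourteenSegment.traineddata");
```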

    Hyperparameters for training, specified as an ocrTrainingOptions object.

    OCR training checkpoint, specified as a string scalar or character vector. You must specify a path to a file with a .checkpoint.traineddata extension, such as a checkpoint file saved in the location specified by the CheckpointPath property of an ocrTrainingOptions object. When you specify a value for the CheckpointPath property, the trainOCR function saves checkpoints at regular intervals during training. You can resume training from any one of these saved checkpoints.
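
    Resuming from the most recent checkpoint might look like the following sketch. It assumes that checkpoint files sit directly in the folder passed as CheckpointPath, and that cdsTrain and ocrOptions exist as in the example on this page; neither assumption is confirmed by the text above.

```matlab
% Folder passed as CheckpointPath in the example.
checkpointsDir = "Checkpoints";

% Assumption: checkpoint files are stored directly in this folder.
files = dir(fullfile(checkpointsDir,"*.checkpoint.traineddata"));

if isempty(files)
    disp("No checkpoints found; nothing to resume from.")
else
    % Pick the most recently written checkpoint.
    [~,newest] = max([files.datenum]);
    checkpoint = fullfile(files(newest).folder,files(newest).name);

    % cdsTrain and ocrOptions are assumed to exist from the example.
    [modelFileName,info] = trainOCR(cdsTrain,"fourteenSegment", ...
        checkpoint,ocrOptions);
end
```

    Selecting by modification time is one possible policy; any saved checkpoint can be passed instead.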

    Output Arguments

    Model filename, returned as a string scalar.

    Information on training progress, returned as a structure containing these fields:

    • BaseLearnRate — Learning rate at each iteration.

    • TrainingRMSE — Training RMSE at each iteration.

    • TrainingCharError — Training character error rate at each iteration.

    • TrainingWordError — Training word error rate at each iteration.

    • ValidationRMSE — Validation RMSE at each iteration.

    • ValidationCharError — Validation character error rate at each iteration.

    • ValidationWordError — Validation word error rate at each iteration.

    • FinalValidationRMSE — Final validation RMSE at the end of the training.

    • OutputModelIteration — Iteration number of the returned model.

    If you do not specify validation data, the structure contains empty ValidationRMSE, ValidationCharError, ValidationWordError, and FinalValidationRMSE fields.
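
    The fields above can be inspected after training. The sketch below uses a stand-in info structure so that it runs on its own: the RMSE and learning rate values are copied from the verbose table in the example (one entry per displayed row, whereas the real fields hold one entry per iteration), and the validation field is left empty to mimic training without validation data.

```matlab
% Stand-in for the info output of trainOCR (illustration only).
info = struct();
info.TrainingRMSE   = [9.51 2.43 1.47 1.06 0.86 0.73 0.70];
info.BaseLearnRate  = 0.0040*ones(1,7);
info.ValidationRMSE = [];

% Find the lowest training RMSE and where it occurred.
[bestRMSE,bestIdx] = min(info.TrainingRMSE);
fprintf("Lowest training RMSE: %.2f at entry %d\n",bestRMSE,bestIdx);

% Validation fields are empty when no validation data was specified.
hasValidation = ~isempty(info.ValidationRMSE);
```

    With a real info structure, plot(info.TrainingRMSE) gives a quick view of convergence.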

    Limitations

    • Training OCR models for right-to-left scripts, such as Arabic and Hebrew, is not supported.

    Algorithms

    • The trainOCR function creates a temporary folder, <modelName>Training/, where <modelName> is the value of the modelName argument, in the location specified by the OutputLocation property of the ocrTrainingOptions object. The folder contains training artifacts. If the folder does not exist before you run the trainOCR function, the function deletes it at the end of training. If the folder already exists before training, the function does not delete it.

    • Images read from trainingData must contain text of at least one-word length and up to a maximum of one-line length. The trainOCR function does not support images that contain multiple lines of text.

    • The trainOCR function does not support on-the-fly data augmentation using a datastore transform. All the image data is read once from the training datastores at the start of training.
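
    The temporary-folder behavior described in the first item can be anticipated with a short check before training. This is a minimal sketch using the folder and model names from the example on this page; it only reports whether the folder is already present.

```matlab
% Values taken from the example above.
outputDir = "OCRModel";
modelName = "fourteenSegment";

% Temporary folder that trainOCR uses: <OutputLocation>/<modelName>Training
tempDir = fullfile(outputDir, modelName + "Training");

% If the folder exists before training, trainOCR leaves it in place.
existedBefore = isfolder(tempDir);
if existedBefore
    fprintf("%s already exists and will not be deleted after training.\n",tempDir);
end
```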

    Version History

    Introduced in R2023a