Transfer Learning with ResNet-50 in MATLAB

Here’s a structured guide to transfer learning with ResNet-50 in MATLAB, including key steps, code examples, and best practices:


1. Overview

  • ResNet-50: A 50-layer deep convolutional neural network (CNN) pre-trained on the ImageNet dataset.
  • Transfer Learning: Reuse the feature extraction layers of ResNet-50 and retrain the final layers on your custom dataset.

2. Key Steps

Step 1: Load Pre-trained ResNet-50

```matlab
net = resnet50;                       % Load pre-trained ResNet-50
inputSize = net.Layers(1).InputSize;  % Typically [224 224 3]
```
  • Requires Deep Learning Toolbox and the Deep Learning Toolbox Model for ResNet-50 Network support package (installable from the Add-On Explorer).
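
To confirm the model loaded correctly and see which layers Step 3 will replace, you can inspect the end of the layer array:

```matlab
net.Layers(end-2:end)  % fc1000, fc1000_softmax, ClassificationLayer_fc1000
```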

Step 2: Prepare Your Dataset

Use imageDatastore to organize labeled images:

```matlab
imds = imageDatastore('path/to/data', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% Split into training and validation sets (e.g., 70/30 split)
[imdsTrain, imdsVal] = splitEachLabel(imds, 0.7, 'randomized');
```
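
Before training, it is worth checking that the classes are reasonably balanced; countEachLabel returns a table of labels and image counts:

```matlab
countEachLabel(imdsTrain)  % one row per class with its image count
```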

Step 3: Modify Network Architecture

  1. Remove Final Layers (the original fully connected, softmax, and classification output layers):

     ```matlab
     lgraph = layerGraph(net);
     lgraph = removeLayers(lgraph, {'fc1000', 'fc1000_softmax', 'ClassificationLayer_fc1000'});
     ```
  2. Add New Layers for your task:

     ```matlab
     numClasses = numel(categories(imdsTrain.Labels));

     newLayers = [
         fullyConnectedLayer(numClasses, 'Name', 'fcNew', ...
             'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10)
         softmaxLayer('Name', 'softmax')
         classificationLayer('Name', 'classOutput')];

     lgraph = addLayers(lgraph, newLayers);
     lgraph = connectLayers(lgraph, 'avg_pool', 'fcNew');
     ```
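
After rewiring, a quick check helps catch mistakes: plot draws the layer graph, and analyzeNetwork reports disconnected or incompatible layers:

```matlab
figure
plot(lgraph)            % should end avg_pool -> fcNew -> softmax -> classOutput
analyzeNetwork(lgraph)  % flags disconnected or mismatched layers
```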

Step 4: Data Augmentation

Prevent overfitting with augmented images:

```matlab
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandRotation', [-20 20], ...
    'RandXScale', [0.8 1.2]);

augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain, ...
    'DataAugmentation', augmenter);

% Resize (but do not augment) the validation set; Step 5 uses this
augimdsVal = augmentedImageDatastore(inputSize, imdsVal);
```
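
To sanity-check the augmentation settings, preview one batch; for an augmentedImageDatastore the preview is a table whose input column holds the transformed images (column names here are the datastore defaults):

```matlab
batch = preview(augimdsTrain);
montage(batch.input)   % inspect a few augmented training images
```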

Step 5: Training Options

Configure hyperparameters for training:

```matlab
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 1e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsVal, ...
    'ValidationFrequency', 50, ...
    'Verbose', true, ...
    'Plots', 'training-progress');
```

Step 6: Train the Network

```matlab
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
```
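
If you plan to deploy the network later (Section 4), save it to a MAT-file after training:

```matlab
save('netTransfer.mat', 'netTransfer');  % reload later with load('netTransfer.mat')
```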

Step 7: Evaluate Performance

```matlab
[YPred, probs] = classify(netTransfer, augimdsVal);
accuracy = mean(YPred == imdsVal.Labels);

% Confusion matrix
confusionchart(imdsVal.Labels, YPred);
```
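
Overall accuracy can hide weak classes. As a quick sketch, per-class recall can be read off the confusion matrix (confusionmat requires Statistics and Machine Learning Toolbox):

```matlab
cm = confusionmat(imdsVal.Labels, YPred);
perClassRecall = diag(cm) ./ sum(cm, 2)  % fraction of each class predicted correctly
```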

3. Fine-Tuning (Optional)

To update weights in earlier layers (most useful for larger datasets, or when your images differ substantially from ImageNet):

```matlab
% Unfreeze layers from 'res5a_branch2a' onward; lgraph.Layers is
% read-only, so modify a copy of each layer and put it back with replaceLayer
layers = lgraph.Layers;
startIdx = find(strcmp({layers.Name}, 'res5a_branch2a'));
for i = startIdx:numel(layers)
    if isprop(layers(i), 'WeightLearnRateFactor')
        layers(i).WeightLearnRateFactor = 1;
        layers(i).BiasLearnRateFactor = 1;
        lgraph = replaceLayer(lgraph, layers(i).Name, layers(i));
    end
end

% Retrain with a lower learning rate (e.g., 1e-5) so the pre-trained
% weights change only gradually
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-5, ...
    'MaxEpochs', 10, ...
    'ValidationData', augimdsVal);
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
```

4. Deployment

  • Export the model to ONNX (requires the Deep Learning Toolbox Converter for ONNX Model Format support package):

    ```matlab
    exportONNXNetwork(netTransfer, 'resnet50_custom.onnx');
    ```
  • Generate C/C++ code with MATLAB Coder (see the entry-point sketch below).
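
A common code-generation pattern is a small entry-point function that loads the network into a persistent variable. This is a minimal sketch, assuming the trained network was saved to netTransfer.mat as in Step 6; the function name classifyImage is illustrative, not a shipped API:

```matlab
% classifyImage.m -- hypothetical entry-point function for MATLAB Coder
function label = classifyImage(in)
persistent net
if isempty(net)
    net = coder.loadDeepLearningNetwork('netTransfer.mat');
end
label = classify(net, in);
end
```

Generate code with, e.g., codegen classifyImage -args {ones(224,224,3,'single')} (exact flags depend on your target).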

5. Tips

  • Small Datasets: Use heavy data augmentation and fine-tune only the final layers.
  • Large Datasets: Fine-tune deeper layers for better accuracy.
  • GPU Acceleration: Training runs on a supported GPU automatically when one is available (the default 'ExecutionEnvironment' is 'auto'); force it with trainingOptions('sgdm', ..., 'ExecutionEnvironment', 'gpu'). Requires Parallel Computing Toolbox.

6. Example Use Cases

  • Medical image classification (e.g., X-rays).
  • Object detection in autonomous vehicles.
  • Custom image recognition for industrial quality control.

Summary

Transfer learning with ResNet-50 in MATLAB involves:

  1. Loading the pre-trained network.
  2. Replacing the final layers with custom layers.
  3. Training on augmented data.
  4. Fine-tuning (optional) for improved performance.

For large datasets or domain-specific tasks, fine-tuning deeper layers often yields better results. Because imageDatastore and augmentedImageDatastore read images from disk on demand, even large datasets do not need to fit in memory.