Transfer Learning with ResNet-50 in MATLAB
Here’s a structured guide to transfer learning with ResNet-50 in MATLAB, including key steps, code examples, and best practices:
1. Overview
- ResNet-50: A 50-layer deep convolutional neural network (CNN) pre-trained on the ImageNet dataset.
- Transfer Learning: Reuse the feature extraction layers of ResNet-50 and retrain the final layers on your custom dataset.
2. Key Steps
Step 1: Load Pre-trained ResNet-50
```matlab
net = resnet50;                        % Load pre-trained ResNet-50
inputSize = net.Layers(1).InputSize;   % Typically [224 224 3]
```
- Requires the Deep Learning Toolbox and the free support package Deep Learning Toolbox Model for ResNet-50 Network.
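If the support package is missing, calling resnet50 raises an error with a link to the Add-On Explorer. A minimal availability check (the error text here is our own):
```matlab
% Check that the ResNet-50 support package is installed and on the path
if isempty(which('resnet50'))
    error(['Install "Deep Learning Toolbox Model for ResNet-50 Network" ' ...
           'from the Add-On Explorer.']);
end
```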
Step 2: Prepare Your Dataset
Use imageDatastore to organize labeled images:
```matlab
imds = imageDatastore('path/to/data', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% Split into training and validation sets (e.g., 70-30 split)
[imdsTrain, imdsVal] = splitEachLabel(imds, 0.7, 'randomized');
```
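A quick optional sanity check that each class has enough images; class imbalance is worth catching before training:
```matlab
% Tabulate the number of training images per class
countEachLabel(imdsTrain)
```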
Step 3: Modify Network Architecture
- Remove the final layers (the fully connected, softmax, and classification layers):
```matlab
lgraph = layerGraph(net);
% Layer names can vary slightly by release; list them with {lgraph.Layers.Name}
lgraph = removeLayers(lgraph, {'fc1000', 'fc1000_softmax', 'ClassificationLayer_fc1000'});
```
- Add New Layers for your task:
```matlab
numClasses = numel(categories(imdsTrain.Labels));

newLayers = [
    fullyConnectedLayer(numClasses, 'Name', 'fcNew', ...
        'WeightLearnRateFactor', 10, 'BiasLearnRateFactor', 10)
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classOutput')];

lgraph = addLayers(lgraph, newLayers);
lgraph = connectLayers(lgraph, 'avg_pool', 'fcNew');
```
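Before training, it is worth confirming that the modified graph is wired correctly; analyzeNetwork flags disconnected or mis-sized layers:
```matlab
% Optional: interactive view of the modified graph; reports wiring errors
analyzeNetwork(lgraph);
```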
Step 4: Data Augmentation
Prevent overfitting with augmented images:
```matlab
augmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandRotation', [-20 20], ...
    'RandXScale', [0.8 1.2]);

augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain, ...
    'DataAugmentation', augmenter);

% Resize (but do not augment) the validation images; Step 5 references augimdsVal
augimdsVal = augmentedImageDatastore(inputSize, imdsVal);
```
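To eyeball what the augmenter produces, you can preview a few training images (this sketch assumes the preview table's image column is named input, as in recent releases):
```matlab
% Optional: display a sample of augmented training images
sample = preview(augimdsTrain);   % small table with an 'input' image column
montage(sample.input);
```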
Step 5: Training Options
Configure hyperparameters for training:
```matlab
options = trainingOptions('sgdm', ...
    'MiniBatchSize', 32, ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 1e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsVal, ...
    'ValidationFrequency', 50, ...
    'Verbose', true, ...
    'Plots', 'training-progress');
```
Step 6: Train the Network
```matlab
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
```
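Saving the trained network keeps it available for later evaluation and deployment (the file name is arbitrary; the Coder sketch in Section 4 assumes this MAT-file):
```matlab
% Save the fine-tuned network to disk
save('netTransfer.mat', 'netTransfer');
```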
Step 7: Evaluate Performance
```matlab
[YPred, probs] = classify(netTransfer, augimdsVal);
accuracy = mean(YPred == imdsVal.Labels);

% Confusion matrix
confusionchart(imdsVal.Labels, YPred);
```
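For inference on a single new image (the file path is a placeholder, and an RGB image is assumed):
```matlab
% Hypothetical example: classify one image with the fine-tuned network
img = imread('path/to/new_image.jpg');    % placeholder path
img = imresize(img, inputSize(1:2));      % match the network input size
[label, scores] = classify(netTransfer, img);
fprintf('Predicted: %s (%.1f%%)\n', char(label), 100 * max(scores));
```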
3. Fine-Tuning (Optional)
To also update weights in the deeper pre-trained layers, not just the new head (worthwhile when you have enough training data; see the tips below):
```matlab
% Freeze every layer before 'res5a_branch2a' (the start of the final
% residual block) so that only this block and the new head are updated.
% lgraph.Layers is read-only, so edited layers go back via replaceLayer.
lgraph = layerGraph(netTransfer);   % continue from the network trained in Step 6
layers = lgraph.Layers;
startIdx = find(strcmp({layers.Name}, 'res5a_branch2a'));
for i = 1:startIdx-1
    if isprop(layers(i), 'WeightLearnRateFactor')
        layers(i).WeightLearnRateFactor = 0;
        layers(i).BiasLearnRateFactor = 0;
        lgraph = replaceLayer(lgraph, layers(i).Name, layers(i));
    end
end

% Retrain with a lower learning rate (e.g., 1e-5)
options.InitialLearnRate = 1e-5;
netTransfer = trainNetwork(augimdsTrain, lgraph, options);
```
4. Deployment
- Export the model to ONNX (requires the Deep Learning Toolbox Converter for ONNX Model Format support package):
```matlab
exportONNXNetwork(netTransfer, 'resnet50_custom.onnx');
```
- Generate C/C++ code with MATLAB Coder.
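A minimal sketch of a Coder entry-point function, assuming the network was saved to netTransfer.mat as in Step 6 (the function name resnetPredict is our own):
```matlab
function out = resnetPredict(in) %#codegen
% Hypothetical entry point for MATLAB Coder / GPU Coder. Code generation
% requires loading the network once into a persistent variable.
persistent mynet
if isempty(mynet)
    mynet = coder.loadDeepLearningNetwork('netTransfer.mat', 'netTransfer');
end
out = predict(mynet, in);
end
```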
5. Tips
- Small Datasets: Use heavy data augmentation and fine-tune only the final layers.
- Large Datasets: Fine-tune deeper layers for better accuracy.
- GPU Acceleration: add 'ExecutionEnvironment', 'gpu' to trainingOptions for faster training (requires Parallel Computing Toolbox and a supported GPU).
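A quick optional check that MATLAB can see a supported GPU (requires Parallel Computing Toolbox):
```matlab
% Check for a visible CUDA GPU before requesting 'gpu' training
if gpuDeviceCount > 0
    disp(gpuDevice);   % details of the currently selected GPU
else
    disp('No supported GPU detected; training will run on the CPU.');
end
```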
6. Example Use Cases
- Medical image classification (e.g., X-rays).
- Object detection in autonomous vehicles.
- Custom image recognition for industrial quality control.
Summary
Transfer learning with ResNet-50 in MATLAB involves:
- Loading the pre-trained network.
- Replacing the final layers with custom layers.
- Training on augmented data.
- Fine-tuning (optional) for improved performance.
For large datasets or domain-specific tasks, fine-tuning deeper layers often yields better results. Datastores such as imageDatastore are handle objects that read images from disk on demand, so even very large image collections are never duplicated in memory.