Преобразуйте learnable сетевые параметры в ONNXParameters
к nonlearnable
params = freezeParameters(
замораживает сетевые параметры, заданные params
,names
)names
в ONNXParameters
объект params
. Функция перемещает заданные параметры от params.Learnables
во входном параметре params
к params.Nonlearnables
в выходном аргументе params
.
Импортируйте alexnet
нейронная сеть свертки как функция и подстройка предварительно обученная сеть с передачей обучения, чтобы выполнить классификацию на новом наборе изображений.
Этот пример использует несколько функций помощника. Чтобы просмотреть код для этих функций, смотрите Функции Помощника.
Разархивируйте и загрузите новые изображения как datastore изображений. imageDatastore
автоматически помечает изображения на основе имен папок и хранит данные как ImageDatastore
объект. Datastore изображений позволяет вам сохранить большие данные изображения, включая данные, которые не умещаются в памяти, и эффективно считать пакеты изображений во время обучения сверточной нейронной сети. Задайте мини-пакетный размер.
unzip('MerchData.zip'); miniBatchSize = 8; imds = imageDatastore('MerchData', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames',... 'ReadSize', miniBatchSize);
Этот набор данных является небольшим, содержа 75 учебных изображений. Отобразите некоторые демонстрационные изображения.
numImages = numel(imds.Labels); idx = randperm(numImages,16); figure for i = 1:16 subplot(4,4,i) I = readimage(imds,idx(i)); imshow(I) end
Извлеките набор обучающих данных, и одногорячий кодируют категориальные метки классификации.
XTrain = readall(imds); XTrain = single(cat(4,XTrain{:})); YTrain_categ = categorical(imds.Labels); YTrain = onehotencode(YTrain_categ,2)';
Определите количество классов в данных.
classes = categories(YTrain_categ); numClasses = numel(classes)
numClasses = 5
AlexNet является сверточной нейронной сетью, которая обучена больше чем на миллионе изображений от базы данных ImageNet. В результате сеть изучила богатые представления функции для широкого спектра изображений. Сеть может классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, мышь, карандаш и многие животные.
Импортируйте предварительно обученный alexnet
сеть как функция.
alexnetONNX() params = importONNXFunction('alexnet.onnx','alexnetFcn')
A function containing the imported ONNX network has been saved to the file alexnetFcn.m. To learn how to use this function, type: help alexnetFcn.
params = ONNXParameters with properties: Learnables: [1×1 struct] Nonlearnables: [1×1 struct] State: [1×1 struct] NumDimensions: [1×1 struct] NetworkFunctionName: 'alexnetFcn'
params
ONNXParameters
объект, который содержит сетевые параметры. alexnetFcn
функция модели, которая содержит сетевую архитектуру. importONNXFunction
сохраняет alexnetFcn
в текущей папке.
Вычислите точность классификации предварительно обученной сети на новом наборе обучающих данных.
accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy before transfer learning\n',accuracyBeforeTraining);
0.01 accuracy before transfer learning
Точность является очень низкой.
Отобразите настраиваемые параметры сети. Эти параметры, например, веса (W
) и смещение (B
) из свертки и полносвязных слоев, обновляются сетью во время обучения. Параметры Nonlearnable остаются постоянными во время обучения.
params.Learnables
ans = struct with fields:
data_Mean: [227×227×3 dlarray]
conv1_W: [11×11×3×96 dlarray]
conv1_B: [96×1 dlarray]
conv2_W: [5×5×48×256 dlarray]
conv2_B: [256×1 dlarray]
conv3_W: [3×3×256×384 dlarray]
conv3_B: [384×1 dlarray]
conv4_W: [3×3×192×384 dlarray]
conv4_B: [384×1 dlarray]
conv5_W: [3×3×192×256 dlarray]
conv5_B: [256×1 dlarray]
fc6_W: [6×6×256×4096 dlarray]
fc6_B: [4096×1 dlarray]
fc7_W: [1×1×4096×4096 dlarray]
fc7_B: [4096×1 dlarray]
fc8_W: [1×1×4096×1000 dlarray]
fc8_B: [1000×1 dlarray]
Последние два настраиваемых параметра предварительно обученной сети сконфигурированы для 1 000 классов. Параметры fc8_W
и fc8_B
должен быть подстроен для новой проблемы классификации. Передайте параметры, чтобы классифицировать 5 классов путем инициализации их.
params.Learnables.fc8_B = rand(5,1); params.Learnables.fc8_W = rand(1,1,4096,5);
Заморозьте все параметры сети, чтобы преобразовать их в nonlearnable параметры. Поскольку вы не должны вычислять градиенты блокированных слоев, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение.
params = freezeParameters(params,'all');
Разморозьте последние два параметра сети, чтобы преобразовать их в настраиваемые параметры.
params = unfreezeParameters(params,'fc8_W'); params = unfreezeParameters(params,'fc8_B');
Теперь сеть готова к обучению. Инициализируйте график процесса обучения.
plots = "training-progress"; if plots == "training-progress" figure lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") end
Задайте опции обучения.
velocity = []; numEpochs = 5; miniBatchSize = 16; numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
Обучите сеть.
iteration = 0; start = tic; executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU. % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. idx = randperm(numObservations); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(:,idx); % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = YTrain(:,idx); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" X = gpuArray(X); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params); params.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Вычислите точность классификации сети после подстройки.
accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy after transfer learning\n',accuracyAfterTraining);
0.99 accuracy after transfer learning
Функции помощника
Этот раздел предоставляет код функций помощника, используемых в этом примере.
getNetworkAccuracy
функция оценивает производительность сети путем вычисления точности классификации.
function accuracy = getNetworkAccuracy(X,Y,onnxParams) N = size(X,4); Ypred = alexnetFcn(X,onnxParams,'Training',false); [~,YIdx] = max(Y,[],1); [~,YpredIdx] = max(Ypred,[],1); numIncorrect = sum(abs(YIdx-YpredIdx) > 0); accuracy = 1 - numIncorrect/N; end
modelGradients
функция вычисляет потерю и градиенты.
function [grad, loss, state] = modelGradients(X,Y,onnxParams) [y,state] = alexnetFcn(X,onnxParams,'Training',true); loss = crossentropy(y,Y,'DataFormat','CB'); grad = dlgradient(loss,onnxParams.Learnables); end
alexnetONNX
функция генерирует модель ONNX alexnet
сеть. Вам нужна Модель Deep Learning Toolbox для Сетевой поддержки AlexNet, чтобы получить доступ к этой модели.
function alexnetONNX() exportONNXNetwork(alexnet,'alexnet.onnx'); end
params
— Сетевые параметрыONNXParameters
объектСетевые параметры в виде ONNXParameters
объект. params
содержит сетевые параметры импортированной модели ONNX™.
names
— Имена параметров, чтобы заморозиться'all'
| массив строкИмена параметров, чтобы заморозиться в виде 'all'
или массив строк. Заморозьте все настраиваемые параметры установкой names
к 'all'
. Заморозьте k
настраиваемые параметры путем определения названий параметра в 1 k
массив строк names
.
Пример: 'all'
Пример: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]
Типы данных: char |
string
params
— Сетевые параметрыONNXParameters
объектСетевые параметры, возвращенные как ONNXParameters
объект. params
содержит сетевые параметры, обновленные freezeParameters
.
У вас есть модифицированная версия этого примера. Вы хотите открыть этот пример со своими редактированиями?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.