Преобразуйте learnable сетевые параметры в ONNXParameters к nonlearnable
params = freezeParameters( замораживает сетевые параметры, заданные params,names)names в ONNXParameters объект params. Функция перемещает заданные параметры от params.Learnables во входном параметре params к params.Nonlearnables в выходном аргументе params.
Импортируйте alexnet нейронная сеть свертки как функция и подстройка предварительно обученная сеть с передачей обучения, чтобы выполнить классификацию на новом наборе изображений.
Этот пример использует несколько функций помощника. Чтобы просмотреть код для этих функций, смотрите Функции Помощника.
Разархивируйте и загрузите новые изображения как datastore изображений. imageDatastore автоматически помечает изображения на основе имен папок и хранит данные как ImageDatastore объект. Datastore изображений позволяет вам сохранить большие данные изображения, включая данные, которые не умещаются в памяти, и эффективно считать пакеты изображений во время обучения сверточной нейронной сети. Задайте мини-пакетный размер.
unzip('MerchData.zip'); miniBatchSize = 8; imds = imageDatastore('MerchData', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames',... 'ReadSize', miniBatchSize);
Этот набор данных является небольшим, содержа 75 учебных изображений. Отобразите некоторые демонстрационные изображения.
numImages = numel(imds.Labels); idx = randperm(numImages,16); figure for i = 1:16 subplot(4,4,i) I = readimage(imds,idx(i)); imshow(I) end

Извлеките набор обучающих данных, и одногорячий кодируют категориальные метки классификации.
XTrain = readall(imds);
XTrain = single(cat(4,XTrain{:}));
YTrain_categ = categorical(imds.Labels);
YTrain = onehotencode(YTrain_categ,2)';Определите количество классов в данных.
classes = categories(YTrain_categ); numClasses = numel(classes)
numClasses = 5
AlexNet является сверточной нейронной сетью, которая обучена больше чем на миллионе изображений от базы данных ImageNet. В результате сеть изучила богатые представления функции для широкого спектра изображений. Сеть может классифицировать изображения в 1 000 категорий объектов, таких как клавиатура, мышь, карандаш и многие животные.
Импортируйте предварительно обученный alexnet сеть как функция.
alexnetONNX() params = importONNXFunction('alexnet.onnx','alexnetFcn')
A function containing the imported ONNX network has been saved to the file alexnetFcn.m. To learn how to use this function, type: help alexnetFcn.
params =
ONNXParameters with properties:
Learnables: [1×1 struct]
Nonlearnables: [1×1 struct]
State: [1×1 struct]
NumDimensions: [1×1 struct]
NetworkFunctionName: 'alexnetFcn'
params ONNXParameters объект, который содержит сетевые параметры. alexnetFcn функция модели, которая содержит сетевую архитектуру. importONNXFunction сохраняет alexnetFcn в текущей папке.
Вычислите точность классификации предварительно обученной сети на новом наборе обучающих данных.
accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy before transfer learning\n',accuracyBeforeTraining);0.01 accuracy before transfer learning
Точность является очень низкой.
Отобразите настраиваемые параметры сети. Эти параметры, например, веса (W) и смещение (B) из свертки и полносвязных слоев, обновляются сетью во время обучения. Параметры Nonlearnable остаются постоянными во время обучения.
params.Learnables
ans = struct with fields:
data_Mean: [227×227×3 dlarray]
conv1_W: [11×11×3×96 dlarray]
conv1_B: [96×1 dlarray]
conv2_W: [5×5×48×256 dlarray]
conv2_B: [256×1 dlarray]
conv3_W: [3×3×256×384 dlarray]
conv3_B: [384×1 dlarray]
conv4_W: [3×3×192×384 dlarray]
conv4_B: [384×1 dlarray]
conv5_W: [3×3×192×256 dlarray]
conv5_B: [256×1 dlarray]
fc6_W: [6×6×256×4096 dlarray]
fc6_B: [4096×1 dlarray]
fc7_W: [1×1×4096×4096 dlarray]
fc7_B: [4096×1 dlarray]
fc8_W: [1×1×4096×1000 dlarray]
fc8_B: [1000×1 dlarray]
Последние два настраиваемых параметра предварительно обученной сети сконфигурированы для 1 000 классов. Параметры fc8_W и fc8_B должен быть подстроен для новой проблемы классификации. Передайте параметры, чтобы классифицировать 5 классов путем инициализации их.
params.Learnables.fc8_B = rand(5,1); params.Learnables.fc8_W = rand(1,1,4096,5);
Заморозьте все параметры сети, чтобы преобразовать их в nonlearnable параметры. Поскольку вы не должны вычислять градиенты блокированных слоев, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение.
params = freezeParameters(params,'all');Разморозьте последние два параметра сети, чтобы преобразовать их в настраиваемые параметры.
params = unfreezeParameters(params,'fc8_W'); params = unfreezeParameters(params,'fc8_B');
Теперь сеть готова к обучению. Инициализируйте график процесса обучения.
plots = "training-progress"; if plots == "training-progress" figure lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") end
Задайте опции обучения.
velocity = []; numEpochs = 5; miniBatchSize = 16; numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
Обучите сеть.
iteration = 0; start = tic; executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU. % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. idx = randperm(numObservations); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(:,idx); % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = YTrain(:,idx); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" X = gpuArray(X); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params); params.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end

Вычислите точность классификации сети после подстройки.
accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy after transfer learning\n',accuracyAfterTraining);0.99 accuracy after transfer learning
Функции помощника
Этот раздел предоставляет код функций помощника, используемых в этом примере.
getNetworkAccuracy функция оценивает производительность сети путем вычисления точности классификации.
function accuracy = getNetworkAccuracy(X,Y,onnxParams) N = size(X,4); Ypred = alexnetFcn(X,onnxParams,'Training',false); [~,YIdx] = max(Y,[],1); [~,YpredIdx] = max(Ypred,[],1); numIncorrect = sum(abs(YIdx-YpredIdx) > 0); accuracy = 1 - numIncorrect/N; end
modelGradients функция вычисляет потерю и градиенты.
function [grad, loss, state] = modelGradients(X,Y,onnxParams) [y,state] = alexnetFcn(X,onnxParams,'Training',true); loss = crossentropy(y,Y,'DataFormat','CB'); grad = dlgradient(loss,onnxParams.Learnables); end
alexnetONNX функция генерирует модель ONNX alexnet сеть. Вам нужна Модель Deep Learning Toolbox для Сетевой поддержки AlexNet, чтобы получить доступ к этой модели.
function alexnetONNX() exportONNXNetwork(alexnet,'alexnet.onnx'); end
params — Сетевые параметрыONNXParameters объектСетевые параметры в виде ONNXParameters объект. params содержит сетевые параметры импортированной модели ONNX™.
names — Имена параметров, чтобы заморозиться'all' | массив строкИмена параметров, чтобы заморозиться в виде 'all' или массив строк. Заморозьте все настраиваемые параметры установкой names к 'all'. Заморозьте k настраиваемые параметры путем определения названий параметра в 1 k массив строк names.
Пример: 'all'
Пример: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]
Типы данных: char | string
params — Сетевые параметрыONNXParameters объектСетевые параметры, возвращенные как ONNXParameters объект. params содержит сетевые параметры, обновленные freezeParameters.
У вас есть модифицированная версия этого примера. Вы хотите открыть этот пример со своими редактированиями?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.