Преобразуйте параметры неведомой сети в ONNXParameters
к learnable
params = unfreezeParameters(
размораживает параметры сети, заданные params
,names
)names
в ONNXParameters
params объекта
. Функция перемещает указанные параметры из params.Nonlearnables
в входной параметр params
на params.Learnables
в выходном аргументе params
.
Импортируйте squeezenet
свертка нейронной сети как функции и подстройте предварительно обученную сеть с передачей обучения, чтобы выполнить классификацию на новом наборе изображений.
Этот пример использует несколько вспомогательных функций. Чтобы просмотреть код для этих функций, смотрите Вспомогательные функции.
Разархивируйте и загружайте новые изображения как image datastore. imageDatastore
автоматически помечает изображения на основе имен папок и сохраняет данные как ImageDatastore
объект. image datastore позволяет вам хранить большие данные изображения, включая данные, которые не помещаются в памяти, и эффективно считывать пакеты изображений во время обучения сверточной нейронной сети. Укажите размер мини-пакета.
unzip('MerchData.zip'); miniBatchSize = 8; imds = imageDatastore('MerchData', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames',... 'ReadSize', miniBatchSize);
Этот набор данных небольшой, содержащий 75 обучающих изображений. Отобразите некоторые образцовые изображения.
numImages = numel(imds.Labels); idx = randperm(numImages,16); figure for i = 1:16 subplot(4,4,i) I = readimage(imds,idx(i)); imshow(I) end
Извлеките набор обучающих данных и однократно закодируйте категориальные классификационные метки.
XTrain = readall(imds); XTrain = single(cat(4,XTrain{:})); YTrain_categ = categorical(imds.Labels); YTrain = onehotencode(YTrain_categ,2)';
Определите количество классов в данных.
classes = categories(YTrain_categ); numClasses = numel(classes)
numClasses = 5
squeezenet
- сверточная нейронная сеть, которая обучена более чем на миллионе изображений из базы данных ImageNet. В результате сеть узнала представления богатых функций для широкой области значений изображений. Сеть может классифицировать изображения по 1000 категориям объектов, таким как клавиатура, мышь, карандаш и многие животные.
Импортируйте предварительно обученную squeezenet
сеть как функция.
squeezenetONNX() params = importONNXFunction('squeezenet.onnx','squeezenetFcn')
A function containing the imported ONNX network has been saved to the file squeezenetFcn.m. To learn how to use this function, type: help squeezenetFcn.
params = ONNXParameters with properties: Learnables: [1×1 struct] Nonlearnables: [1×1 struct] State: [1×1 struct] NumDimensions: [1×1 struct] NetworkFunctionName: 'squeezenetFcn'
params
является ONNXParameters
объект, который содержит сетевые параметры. squeezenetFcn
является модельной функцией, которая содержит сетевую архитектуру. importONNXFunction
сохраняет squeezenetFcn
в текущей папке.
Вычислите классификационную точность предварительно обученной сети на новом наборе обучающих данных.
accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy before transfer learning\n',accuracyBeforeTraining);
0.01 accuracy before transfer learning
Точность очень низкая.
Отобразите настраиваемые параметры сети путем ввода params.Learnables
. Эти параметры, такие как веса (W
) и смещение (B
) свертки и полносвязные слои, обновляются сетью во время обучения. Непоследовательные параметры остаются постоянными во время обучения.
Последние два настраиваемых параметры предварительно обученной сети сконфигурированы для 1000 классов.
conv10_W: [1×1×512×1000 dlarray]
conv10_B: [1000×1 dlarray]
Параметры conv10_W
и conv10_B
должна быть уточнена для новой задачи классификации. Передайте параметры для классификации пяти классов путем инициализации параметров.
params.Learnables.conv10_W = rand(1,1,512,5); params.Learnables.conv10_B = rand(5,1);
Заморозите все параметры сети, чтобы преобразовать их в неучаемые параметры. Поскольку вам не нужно вычислять градиенты замороженных слоев, замораживание весов многих начальных слоев может значительно ускорить сетевое обучение.
params = freezeParameters(params,'all');
Размораживайте последние два параметра сети, чтобы преобразовать их в настраиваемые параметры.
params = unfreezeParameters(params,'conv10_W'); params = unfreezeParameters(params,'conv10_B');
Сейчас сеть готова к обучению. Инициализируйте график процесса обучения.
plots = "training-progress"; if plots == "training-progress" figure lineLossTrain = animatedline; xlabel("Iteration") ylabel("Loss") end
Задайте опции обучения.
velocity = []; numEpochs = 5; miniBatchSize = 16; numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
Обучите сеть.
iteration = 0; start = tic; executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU. % Loop over epochs. for epoch = 1:numEpochs % Shuffle data. idx = randperm(numObservations); XTrain = XTrain(:,:,:,idx); YTrain = YTrain(:,idx); % Loop over mini-batches. for i = 1:numIterationsPerEpoch iteration = iteration + 1; % Read mini-batch of data. idx = (i-1)*miniBatchSize+1:i*miniBatchSize; X = XTrain(:,:,:,idx); Y = YTrain(:,idx); % If training on a GPU, then convert data to gpuArray. if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" X = gpuArray(X); end % Evaluate the model gradients and loss using dlfeval and the % modelGradients function. [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params); params.State = state; % Determine the learning rate for the time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity); % Display the training progress. if plots == "training-progress" D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: " + epoch + ", Elapsed: " + string(D)) drawnow end end end
Вычислите классификационную точность сети после подстройки.
accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
fprintf('%.2f accuracy after transfer learning\n',accuracyAfterTraining);
1.00 accuracy after transfer learning
Вспомогательные функции
В этом разделе представлен код вспомогательных функций, используемых в этом примере.
The getNetworkAccuracy
функция оценивает эффективность сети путем вычисления точности классификации.
function accuracy = getNetworkAccuracy(X,Y,onnxParams) N = size(X,4); Ypred = squeezenetFcn(X,onnxParams,'Training',false); [~,YIdx] = max(Y,[],1); [~,YpredIdx] = max(Ypred,[],1); numIncorrect = sum(abs(YIdx-YpredIdx) > 0); accuracy = 1 - numIncorrect/N; end
The modelGradients
функция вычисляет потери и градиенты.
function [grad, loss, state] = modelGradients(X,Y,onnxParams) [y,state] = squeezenetFcn(X,onnxParams,'Training',true); loss = crossentropy(y,Y,'DataFormat','CB'); grad = dlgradient(loss,onnxParams.Learnables); end
The squeezenetONNX
функция генерирует модель ONNX squeezenet
сети.
function squeezenetONNX() exportONNXNetwork(squeezenet,'squeezenet.onnx'); end
params
- Параметры сетиONNXParameters
объектПараметры сети, заданные как ONNXParameters
объект. params
содержит сетевые параметры импортированной модели ONNX™.
names
- Имена параметров для разморозки'all'
| строковые массивыИмена параметров для разморозки, заданные как 'all'
или строковые массивы. Размораживайте все неисчерпаемые параметры путем установки names
на 'all'
. Размораживание k
неисчерпаемые параметры путем определения имен параметров в 1-by- k
строковые массивы names
.
Пример: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]
Типы данных: char
| string
params
- Параметры сетиONNXParameters
объектПараметры сети, возвращенные как ONNXParameters
объект. params
содержит параметры сети, обновленные unfreezeParameters
.
У вас есть измененная версия этого примера. Вы хотите открыть этот пример с вашими правками?
1. Если смысл перевода понятен, то лучше оставьте как есть и не придирайтесь к словам, синонимам и тому подобному. О вкусах не спорим.
2. Не дополняйте перевод комментариями “от себя”. В исправлении не должно появляться дополнительных смыслов и комментариев, отсутствующих в оригинале. Такие правки не получится интегрировать в алгоритме автоматического перевода.
3. Сохраняйте структуру оригинального текста - например, не разбивайте одно предложение на два.
4. Не имеет смысла однотипное исправление перевода какого-то термина во всех предложениях. Исправляйте только в одном месте. Когда Вашу правку одобрят, это исправление будет алгоритмически распространено и на другие части документации.
5. По иным вопросам, например если надо исправить заблокированное для перевода слово, обратитесь к редакторам через форму технической поддержки.