exponenta event banner

freezeParameters

Преобразовать доступные для изучения сетевые параметры в ONNXParameters к nonlearnable

    Описание

    пример

    params = freezeParameters(params,names) замораживает параметры сети, указанные names в ONNXParameters объект params. Функция перемещает указанные параметры из params.Learnables во входном аргументе params кому params.Nonlearnables в выходном аргументе params.

    Примеры

    свернуть все

    Импорт squeezenet нейронная сеть свертки как функция и точная настройка предварительно обученной сети с обучением передаче для выполнения классификации по новой коллекции изображений.

    В этом примере используется несколько вспомогательных функций. Для просмотра кода этих функций см. раздел Вспомогательные функции.

    Распакуйте и загрузите новые образы как хранилище данных образов. imageDatastore автоматически помечает изображения на основе имен папок и сохраняет данные в виде ImageDatastore объект. Хранилище данных изображения позволяет хранить большие данные изображения, включая данные, которые не помещаются в память, и эффективно считывать партии изображений во время обучения сверточной нейронной сети. Укажите размер мини-партии.

    unzip('MerchData.zip');
    miniBatchSize = 8;
    imds = imageDatastore('MerchData', ...
        'IncludeSubfolders',true, ...
        'LabelSource','foldernames',...
        'ReadSize', miniBatchSize);

    Этот набор данных невелик и содержит 75 тренировочных изображений. Отображение некоторых образцов изображений.

    numImages = numel(imds.Labels);
    idx = randperm(numImages,16);
    figure
    for i = 1:16
        subplot(4,4,i)
        I = readimage(imds,idx(i));
        imshow(I)
    end

    Извлеките обучающий набор и одноконтактно закодируйте классификационные метки категорий.

    XTrain = readall(imds);
    XTrain = single(cat(4,XTrain{:}));
    YTrain_categ = categorical(imds.Labels);
    YTrain = onehotencode(YTrain_categ,2)';

    Определите количество классов в данных.

    classes = categories(YTrain_categ);
    numClasses = numel(classes)
    numClasses = 5
    

    squeezenet является сверточной нейронной сетью, которая обучена на более чем миллионе изображений из базы данных ImageNet. В результате сеть получила богатые представления элементов для широкого спектра изображений. Сеть может классифицировать изображения на 1000 категорий объектов, таких как клавиатура, мышь, карандаш и многие животные.

    Импорт предварительно подготовленных squeezenet сеть как функция.

    squeezenetONNX()
    params = importONNXFunction('squeezenet.onnx','squeezenetFcn')
    A function containing the imported ONNX network has been saved to the file squeezenetFcn.m.
    To learn how to use this function, type: help squeezenetFcn.
    
    params = 
      ONNXParameters with properties:
    
                 Learnables: [1×1 struct]
              Nonlearnables: [1×1 struct]
                      State: [1×1 struct]
              NumDimensions: [1×1 struct]
        NetworkFunctionName: 'squeezenetFcn'
    
    

    params является ONNXParameters объект, содержащий параметры сети. squeezenetFcn - функция модели, содержащая сетевую архитектуру. importONNXFunction экономит squeezenetFcn в текущей папке.

    Рассчитайте точность классификации предварительно обученной сети на новом обучающем наборе.

    accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params);
    fprintf('%.2f accuracy before transfer learning\n',accuracyBeforeTraining);
    0.01 accuracy before transfer learning
    

    Точность очень низкая.

    Отображение обучаемых параметров сети путем ввода params.Learnables. Эти параметры, такие как веса (W) и предвзятость (B) свёртки и полностью соединенных слоев, обновляются сетью во время обучения. Нечеткие параметры остаются постоянными во время обучения.

    Последние два обучаемых параметра предварительно обученной сети сконфигурированы для 1000 классов.

    conv10_W: [1×1×512×1000 dlarray]

    conv10_B: [1000×1 dlarray]

    Параметры conv10_W и conv10_B необходимо выполнить точную настройку для новой проблемы классификации. Передайте параметры для классификации пяти классов путем инициализации параметров.

    params.Learnables.conv10_W = rand(1,1,512,5);
    params.Learnables.conv10_B = rand(5,1);

    Заморозите все параметры сети, чтобы преобразовать их в неочищаемые параметры. Поскольку не нужно вычислять градиенты замороженных слоев, замораживание весов многих начальных слоев может значительно ускорить обучение сети.

    params = freezeParameters(params,'all');

    Разморозите последние два параметра сети, чтобы преобразовать их в обучаемые параметры.

    params = unfreezeParameters(params,'conv10_W');
    params = unfreezeParameters(params,'conv10_B');

    Сейчас сеть готова к тренировкам. Инициализируйте график хода обучения.

    plots = "training-progress";
    if plots == "training-progress"
        figure
        lineLossTrain = animatedline;
        xlabel("Iteration")
        ylabel("Loss")
    end

    Укажите параметры обучения.

    velocity = [];
    numEpochs = 5;
    miniBatchSize = 16;
    numObservations = size(YTrain,2);
    numIterationsPerEpoch = floor(numObservations./miniBatchSize);
    initialLearnRate = 0.01;
    momentum = 0.9;
    decay = 0.01;

    Обучение сети.

    iteration = 0;
    start = tic;
    executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU.
    
    % Loop over epochs.
    for epoch = 1:numEpochs
        
        % Shuffle data.
        idx = randperm(numObservations);
        XTrain = XTrain(:,:,:,idx);
        YTrain = YTrain(:,idx);
        
        % Loop over mini-batches.
        for i = 1:numIterationsPerEpoch
            iteration = iteration + 1;
            
            % Read mini-batch of data.
            idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
            X = XTrain(:,:,:,idx);        
            Y = YTrain(:,idx);
            
            % If training on a GPU, then convert data to gpuArray.
            if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
                X = gpuArray(X);         
            end
            
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params);
            params.State = state;
            
            % Determine the learning rate for the time-based decay learning rate schedule.
            learnRate = initialLearnRate/(1 + decay*iteration);
            
            % Update the network parameters using the SGDM optimizer.
            [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity);
            
            % Display the training progress.
            if plots == "training-progress"
                D = duration(0,0,toc(start),'Format','hh:mm:ss');
                addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
                title("Epoch: " + epoch + ", Elapsed: " + string(D))
                drawnow
            end
        end
    end

    Вычислите точность классификации сети после точной настройки.

    accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params);
    fprintf('%.2f accuracy after transfer learning\n',accuracyAfterTraining);
    1.00 accuracy after transfer learning
    

    Вспомогательные функции

    В этом разделе представлен код вспомогательных функций, используемых в данном примере.

    getNetworkAccuracy функция оценивает производительность сети, вычисляя точность классификации.

    function accuracy = getNetworkAccuracy(X,Y,onnxParams)
    
    N = size(X,4);
    Ypred = squeezenetFcn(X,onnxParams,'Training',false);
    
    [~,YIdx] = max(Y,[],1);
    [~,YpredIdx] = max(Ypred,[],1);
    numIncorrect = sum(abs(YIdx-YpredIdx) > 0);
    accuracy = 1 - numIncorrect/N;
    
    end

    modelGradients функция вычисляет потери и градиенты.

    function [grad, loss, state] = modelGradients(X,Y,onnxParams)
    
    [y,state] = squeezenetFcn(X,onnxParams,'Training',true);
    loss = crossentropy(y,Y,'DataFormat','CB');
    grad = dlgradient(loss,onnxParams.Learnables);
    
    end

    squeezenetONNX генерирует модель ONNX squeezenet сеть.

    function squeezenetONNX()
        
    exportONNXNetwork(squeezenet,'squeezenet.onnx');
    
    end
    

    Входные аргументы

    свернуть все

    Параметры сети, указанные как ONNXParameters объект. params содержит сетевые параметры импортированной модели ONNX™.

    Имена параметров для замораживания, указанные как 'all' или строковый массив. Замораживание всех обучаемых параметров с помощью параметра names кому 'all'. Замораживание k обучаемые параметры путем определения имен параметров в 1-by-k строковый массив names.

    Пример: 'all'

    Пример: ["gpu_0_sl_pred_b_0", "gpu_0_sl_pred_w_0"]

    Типы данных: char | string

    Выходные аргументы

    свернуть все

    Параметры сети, возвращенные как ONNXParameters объект. params содержит параметры сети, обновленные freezeParameters.

    Представлен в R2020b