exponenta event banner

Сеть поездов по изображениям и данным функций

В этом примере показано, как обучить сеть, которая классифицирует рукописные цифры, используя как изображение, так и входные данные функции.

Загрузка данных обучения

Загрузить изображения цифр XTrain, этикетки YTrain, и углы поворота по часовой стрелке anglesTrain. Создание arrayDatastore объект для изображений, меток и углов, а затем используйте combine для создания единого хранилища данных, содержащего все данные обучения. Извлеките имена классов и высоту, ширину, количество каналов и количество недискретных ответов.

[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsAnglesTrain = arrayDatastore(anglesTrain);
dsYTrain = arrayDatastore(YTrain);

dsTrain = combine(dsXTrain,dsAnglesTrain,dsYTrain);

classes = categories(YTrain);
[h,w,c,numObservations] = size(XTrain);

Отображение 20 случайных тренировочных изображений.

numTrainImages = numel(YTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
    title("Angle: " + anglesTrain(idx(i)))
end

Определение сети

Определите размер входного изображения, количество элементов каждого наблюдения, количество классов, а также размер и количество фильтров слоя свертки.

imageInputSize = [h w c];
numFeatures = 1;
numClasses = numel(classes);
filterSize = 5;
numFilters = 16;

Чтобы создать сеть с двумя входными слоями, необходимо определить сеть из двух частей и соединить их, например, с помощью слоя конкатенации.

Определите первую часть сети. Определите слои классификации изображений и включите слой конкатенации перед последним полностью соединенным слоем.

layers = [
    imageInputLayer(imageInputSize,'Normalization','none','Name','images')
    convolution2dLayer(filterSize,numFilters,'Name','conv')
    reluLayer('Name','relu')
    fullyConnectedLayer(50,'Name','fc1')
    concatenationLayer(1,2,'Name','concat')
    fullyConnectedLayer(numClasses,'Name','fc2')
    softmaxLayer('Name','softmax')];

Преобразование слоев в график слоев.

lgraph = layerGraph(layers);

Для второй части сети добавьте уровень ввода элемента и подключите его ко второму входу уровня конкатенации.

featInput = featureInputLayer(numFeatures,'Name','features');
lgraph = addLayers(lgraph, featInput);
lgraph = connectLayers(lgraph, 'features', 'concat/in2');

Визуализация сети.

figure
plot(lgraph)

Создать dlnetwork объект.

dlnet = dlnetwork(lgraph);

При использовании функций predict и forward на dlnetwork , входные аргументы должны соответствовать порядку, заданному InputNames имущества dlnetwork объект. Проверьте имя и порядок входных слоев.

dlnet.InputNames
ans = 1×2 cell
    {'images'}    {'features'}

Определение функции градиентов модели

modelGradients функция, перечисленная в разделе «Функция градиентов модели» примера, принимает в качестве входного значения dlnetwork объект dlnet, мини-пакет входных данных изображения dlX1, мини-пакет входных данных о характеристиках dlX2и соответствующие метки dlYи возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери.

Укажите параметры обучения

Поезд с размером мини-партии 128 на 15 эпох.

numEpochs = 15;
miniBatchSize = 128;

Укажите параметры оптимизации SGDM. Укажите начальную скорость обучения 0.01 с распадом 0.01 и импульс 0.9.

learnRate = 0.01;
decay = 0.01;
momentum = 0.9;

Для контроля хода обучения можно построить график потерь обучения после каждой итерации. Создание переменной plots который содержит "training-progress". Если график хода обучения не требуется, установите для этого значения значение "none".

plots = "training-progress";

Модель поезда

Обучение модели с помощью пользовательского цикла обучения. Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Использовать minibatchqueue обрабатывать и управлять мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessData (определяется в конце этого примера) для одноконтактного кодирования меток класса.

  • По умолчанию minibatchqueue объект преобразует данные в dlarray объекты с базовым типом single. Форматирование изображений с метками размеров 'SSCB' (пространственные, пространственные, канальные, пакетные) и углы с метками размеров 'CB' (канал, партия). Не добавляйте формат к меткам класса.

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','CB',''});

Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. В конце каждой эпохи отобразите ход обучения. Для каждой мини-партии:

  • Оценка градиентов, состояния и потерь модели с помощью dlfeval и modelGradients и обновить состояние сети.

  • Обновление параметров сети с помощью sgdmupdate функция.

Инициализируйте график хода обучения.

if plots == "training-progress"
    figure
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Тренируйте модель.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Shuffle data.
    shuffle(mbq)
    
    % Loop over mini-batches.
    while hasdata(mbq)

        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX1,dlX2,dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX1,dlX2,dlY);
        dlnet.State = state;
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
        
        if plots == "training-progress"
            % Display the training progress.
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            %completionPercentage = round(iteration/numIterations*100,0);
            title("Epoch: " + epoch + ", Elapsed: " + string(D));
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            drawnow
        end
    end
end

Тестовая модель

Проверьте точность классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками. Проверка точности классификации модели путем сравнения прогнозов на тестовом наборе с истинными метками и углами. Управление набором тестовых данных с помощью minibatchqueue объект с той же настройкой, что и данные обучения.

[XTest,YTest,anglesTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsAnglesTest = arrayDatastore(anglesTest);
dsYTest = arrayDatastore(YTest);

dsTest = combine(dsXTest,dsAnglesTest,dsYTest);

mbqTest = minibatchqueue(dsTest,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn', @preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB','CB',''});

Закольцовывать мини-пакеты и классифицировать изображения с помощью modelPredictions функция, перечисленная в конце примера.

[predictions,predCorr] = modelPredictions(dlnet,mbqTest,classes); 

Оцените точность классификации.

accuracy = mean(predCorr)
accuracy = 0.9818

Просмотрите некоторые изображения с их прогнозами.

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)

    label = string(predictions(idx(i)));
    title("Predicted Label: " + label)
end

Функция градиентов модели

modelGradients функция принимает в качестве входного значения a dlnetwork объект dlnet, мини-пакет входных данных изображения dlX1, мини-пакет входных данных о характеристиках dlX2, и соответствующие метки Yи возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient функция.

При использовании forward функция на dlnetwork , входные аргументы должны соответствовать порядку, заданному InputNames имущества dlnetwork объект.

function [gradients,state,loss] = modelGradients(dlnet,dlX1,dlX2,Y)

[dlYPred,state] = forward(dlnet,dlX1,dlX2);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

end

Функция прогнозирования модели

modelPredictions функция принимает в качестве входного значения a dlnetwork объект dlnet, a minibatchqueue входных данных mbqи сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция, чтобы найти прогнозируемый класс с наивысшим баллом, а затем сравнивает прогноз с истинной меткой. Функция возвращает предсказания и вектор единиц и нулей, который представляет правильные и неправильные предсказания.

function [classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes)

    classesPredictions = [];    
    classCorr = [];  
    
    while hasdata(mbq)
        [dlX1,dlX2,dlY] = next(mbq);
        
        % Make predictions using the model function.
        dlYPred = predict(dlnet,dlX1,dlX2);
        
        % Determine predicted classes.
        YPredBatch = onehotdecode(dlYPred,classes,1);
        classesPredictions = [classesPredictions YPredBatch];
                
        % Compare predicted and true classes.
        Y = onehotdecode(dlY,classes,1);
        classCorr = [classCorr YPredBatch == Y];
                
    end

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция выполняет предварительную обработку данных с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в числовой массив. Конкатенация данных изображения над четвертым размером добавляет к каждому изображению третий размер, который будет использоваться в качестве размера одиночного канала.

  2. Извлеките данные метки и угла из входящих массивов ячеек и объедините их вдоль второго размера в категориальный массив и числовой массив соответственно.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,angle,Y] = preprocessMiniBatch(XCell,angleCell,YCell)
    
    % Extract image data from cell and concatenate.
    X = cat(4,XCell{:});
    % Extract angle data from cell and concatenate.
    angle = cat(2,angleCell{:});
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{:});    
        
    % One-hot encode labels.
    Y = onehotencode(Y,1);
    
end

См. также

| | | | | | | |

Связанные примеры

Подробнее