Обучите сеть классификации изображений устойчивой к состязательным примерам

Этот пример показывает, как обучить нейронную сеть, которая является устойчивой к состязательным примерам, используя состязательное обучение метода FGSM.

Нейронные сети могут быть восприимчивы к явлению, известному как состязательные примеры [1], где очень маленькие изменения входа могут привести к его неправильной классификации. Эти изменения часто незаметны для людей.

Методы создания состязательных примеров включают FGSM [2] и основной итерационный метод (BIM) [3], также известный как проективный градиентный спуск [4]. Эти методы могут значительно снизить точность сети.

Можно использовать состязательное обучение [5] для обучения сетей, устойчивых к состязательным примерам. В этом примере показано, как:

  1. Обучите сеть классификации изображений.

  2. Исследуйте робастность сети, генерируя состязательные примеры.

  3. Обучите сеть классификации изображений, которая является устойчивой к состязательным примерам.

Загрузка обучающих данных

The digitTrain4DArrayData функция загружает изображения рукописных цифр и их цифровых меток. Создайте arrayDatastore объект для изображений и меток, а затем используйте combine функция для создания одного datastore, содержащего все обучающие данные.

rng default
[XTrain,TTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);

dsTrain = combine(dsXTrain,dsTTrain);

Извлеките имена классов.

classes = categories(TTrain);

Конструкция сетевой архитектуры

Определите сеть классификации изображений.

layers = [
    imageInputLayer([28 28 1],'Normalization','none','Name','input')
    convolution2dLayer(3,32,'Padding',1,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,64,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    maxPooling2dLayer(2,'Stride',2,'Name','pool')
    fullyConnectedLayer(10,'Name','fc2')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создайте dlnetwork объект из графика слоев.

dlnet = dlnetwork(lgraph);

Задайте функцию градиентов модели

Создайте функцию modelGradients, перечисленный в конце примера, который принимает как вход a dlnetwork объект и мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно настраиваемых параметров в сети и соответствующих потерь.

Обучите сеть

Обучите сеть с помощью пользовательского цикла обучения.

Задайте опции обучения. Обучайте на 30 эпох с размером мини-пакета 100 и скоростью обучения 0,01.

numEpochs = 30;
miniBatchSize = 100;
learnRate = 0.01;
executionEnvironment = "auto";

Создайте minibatchqueue объект, который обрабатывает и управляет мини-пакетами изображений во время обучения. Для каждого мини-пакета:

  • Используйте пользовательскую функцию мини-пакетной предварительной обработки preprocessMiniBatch (определено в конце этого примера), чтобы преобразовать метки в переменные с кодировкой с одним горячим контактом.

  • Форматируйте данные изображения с помощью меток размерностей 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Обучите на графическом процессоре, если он доступен. По умолчанию в minibatchqueue объект преобразует каждый выход в gpuArray при наличии графический процессор. Для использования графический процессор требуется Parallel Computing Toolbox™ и поддерживаемый графический процессор. Для получения информации о поддерживаемых устройствах смотрите Поддержку GPU by Release (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain, ...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график процесса обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Обучите сеть с помощью пользовательского цикла обучения. Для каждой эпохи перетасуйте данные и закольцовывайте по мини-пакетам данных. Для каждого мини-пакета:

  • Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients функционирует и обновляет состояние сети.

  • Обновляйте параметры сети с помощью sgdmupdate функция.

  • Отображение процесса обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration +1;

        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlT = gpuArray(dlT);
        end

        % Evaluate the model gradients, state, and loss.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT);
        dlnet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестирование сети

Проверьте точность классификации сети путем оценки предсказаний сети на тестовые данные наборе.

Создайте minibatchqueue объект, содержащий тестовые данные.

[XTest,TTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsTTest = arrayDatastore(TTest);

dsTest = combine(dsXTest,dsTTest);

mbqTest = minibatchqueue(dsTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatch, ...
    'MiniBatchFormat','SSCB');

Спрогнозируйте классы тестовых данных, используя обученную сеть и modelPredictions функция, заданная в конце этого примера.

YPred = modelPredictions(dlnet,mbqTest,classes);
acc = mean(YPred == TTest)
acc = 0.9866

Точность сети очень высока.

Тестирование сети с состязательными входами

Примените состязательные возмущения к входу изображениям и посмотрите, как это влияет на точность сети.

Можно сгенерировать состязательные примеры, используя такие методы, как FGSM и BIM. FGSM является простым методом, который делает один шаг в направлении градиента XL(X,T) функции потерь L, относительно изображения X вы хотите найти состязательный пример для и метки класса T. Состязательный пример вычисляется как

Xadv=X+ϵ.sign(XL(X,T)).

Параметр ϵ управляет тем, как отличаются состязательные примеры от оригинальных изображений. В этом примере значения пикселей находятся в диапазоне от 0 до 1, поэтому ϵ значение 0,1 изменяет каждое отдельное значение пикселя на до 10% от области значений. Значение ϵ зависит от шкалы изображения. Для примера, если ваше изображение находится между 0 и 255, вам нужно умножить это значение на 255.

BIM является простым улучшением FGSM, которое применяет FGSM в нескольких итерациях и применяет порог. После каждой итерации BIM зажимает возмущение, чтобы убедиться, что величина не превышает ϵ. Этот способ может привести к состязательным примерам с меньшими искажениями, чем FGSM. Для получения дополнительной информации о генерации состязательных примеров, смотрите Сгенерировать неотключенные и целенаправленные состязательные примеры для классификации изображений.

Создайте состязательные примеры с помощью BIM. Задайте epsilon до 0,1.

epsilon = 0.1;

Для BIM, размер возмущения управляется параметром α представление размера шага в каждой итерации. Это так, когда BIM обычно принимает много, меньших, шагов FGSM в направлении градиента.

Определите размер шага alpha и количество итераций.

alpha = 0.01;
numAdvIter = 20;

Используйте adversarialExamples функция (определенная в конце этого примера) для вычисления состязательных примеров с использованием BIM на наборе тестовых данных. Эта функция также возвращает новые предсказания для состязательных изображений.

reset(mbqTest)
[XAdv,YPredAdv] = adversarialExamples(dlnet,mbqTest,epsilon,alpha,numAdvIter,classes);

Вычислите точность сети на данных состязательного примера.

accAdversarial = mean(YPredAdv == TTest)
accAdversarial = 0.0114

Постройте график результатов.

visualizePredictions(XAdv,YPredAdv,TTest);

Вы можете увидеть, что точность сильно ухудшается BIM, хотя возмущение изображения вряд ли видно.

Обучите робастную сеть

Можно обучить сеть быть устойчивой против состязательных примеров. Одним из популярных методов является состязательная подготовка. Состязательное обучение включает применение состязательных возмущений к обучающим данным в процессе обучения [4] [5].

Состязательное обучение FGSM является быстрым и эффективным методом для обучения сети, чтобы быть устойчивой к состязательным примерам. FGSM аналогичен BIM, но он делает один больший шаг в направлении градиента, чтобы сгенерировать состязательное изображение.

Состязательное обучение включает применение метода FGSM к каждому мини-пакету обучающих данных. Однако для того, чтобы обучение было эффективным, эти критерии должны применяться:

  • Метод обучения FGSM должен использовать случайным образом инициализированное возмущение вместо возмущения, которое инициализировано как нуль.

  • Чтобы сеть была устойчива к возмущениям размера ϵ, выполните обучение FGSM со значением, немного большим, чем ϵ. В данном примере во время состязательного обучения вы возмущаете изображения, используя размер шага α=1.25ϵ.

Обучите новую сеть с состязательным обучением FGSM. Начните с использования той же необученной сетевой архитектуры, что и в исходной сети.

dlnetRobust = dlnetwork(lgraph);     

Определите параметры состязательного обучения. Установите количество итераций равным 1, так как FGSM эквивалентен BIM с одной итерацией. Случайным образом инициализируйте возмущение и возмущайте изображения, используя alpha.

numIter = 1;
initialization = "random";
alpha = 1.25*epsilon;

Инициализируйте график процесса обучения.

figure
lineLossRobustTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Обучите устойчивую сеть с помощью пользовательского цикла обучения и тех же опций обучения, которые были определены ранее. Этот цикл такой же, как и в предыдущем пользовательском обучении, но с добавлением состязательных возмущений.

velocity = [];
iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);

        %  If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlT = gpuArray(dlT);
        end

        % Apply adversarial perturbations to the data.
        dlX = basicIterativeMethod(dlnetRobust,dlX,dlT,alpha,epsilon, ...
            numIter,initialization);

        % Evaluate the model gradients, state, and loss.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnetRobust,dlX,dlT);
        dlnet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [dlnetRobust,velocity] = sgdmupdate(dlnetRobust,gradients,velocity,learnRate);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossRobustTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестируйте робастную сеть

Вычислите точность устойчивой сети по тестовым данным цифр. Точность устойчивой сети может быть немного ниже, чем неробустовая сеть на стандартных данных.

reset(mbqTest)
YPred = modelPredictions(dlnetRobust,mbqTest,classes);
accRobust = mean(YPred == TTest)
accRobust = 0.9972

Вычислите точность состязания.

reset(mbqTest)
[XAdv,YPredAdv] = adversarialExamples(dlnetRobust,mbqTest,epsilon,alpha,numAdvIter,classes);
accRobustAdv = mean(YPredAdv == TTest)
accRobustAdv = 0.7558

Состязательная точность устойчивой сети намного лучше, чем у исходной сети.

Вспомогательные функции

Функция градиентов модели

The modelGradients функция принимает как вход dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками T и возвращает градиенты потерь относительно настраиваемых параметров в dlnet, состояние сети и потери. Чтобы вычислить градиенты автоматически, используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,T)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Входная функция градиентов

The modelGradientsInput функция принимает как вход dlnetwork dlnet объекта и мини-пакет входных данных dlX с соответствующими метками T и возвращает градиенты потерь относительно входных данных dlX.

function gradient = modelGradientsInput(dlnet,dlX,T)

T = squeeze(T);
T = dlarray(T,'CB');

[dlYPred] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradient = dlgradient(loss,dlX);

end

Функция мини-пакетной предварительной обработки

The preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток с помощью следующих шагов:

  1. Извлеките данные изображения из входящего массива ячеек и соедините в четырехмерный массив.

  2. Извлеките данные метки из входящего массива ячеек и сгруппируйте в категориальный массив по второму измерению.

  3. Однократное кодирование категориальных меток в числовые массивы. Кодирование в первую размерность создает закодированный массив, который совпадает с формой выходного сигнала сети.

function [X,T] = preprocessMiniBatch(XCell,TCell)

% Concatenate.
X = cat(4,XCell{1:end});

X = single(X);

% Extract label data from the cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end

Функция предсказаний модели

The modelPredictions функция принимает как вход dlnetwork dlnet объекта, а minibatchqueue входных данных mbq, и сетевых классов, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска предсказанного класса с самым высоким счетом.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

Функция состязательных примеров

Сгенерируйте состязательные примеры для minibatchqueue объект с использованием базового итерационного метода (BIM) и предсказание класса состязательных примеров с помощью обученной сети dlnet.

function [XAdv,predictions] = adversarialExamples(dlnet,mbq,epsilon,alpha,numIter,classes)

XAdv = {};
predictions = [];
iteration = 0;

% Generate adversarial images for each mini-batch.
while hasdata(mbq)

    iteration = iteration +1;
    [dlX,dlT] = next(mbq);

    initialization = "zero";
    
    % Generate adversarial images.
    XAdvMBQ = basicIterativeMethod(dlnet,dlX,dlT,alpha,epsilon, ...
        numIter,initialization);

    % Predict the class of the adversarial images.
    dlYPred = predict(dlnet,XAdvMBQ);
    YPred = onehotdecode(dlYPred,classes,1)';

    XAdv{iteration} = XAdvMBQ;
    predictions = [predictions; YPred];
end

% Concatenate.
XAdv = cat(4,XAdv{:});

end

Функция базового итерационного метода

Сгенерируйте состязательные примеры, используя основной итерационный метод (BIM). Этот метод запускается для нескольких итераций с порогом в конце каждой итерации, чтобы убедиться, что значения не превышают epsilon. Когда numIter установлено равным 1, это эквивалентно использованию метода знака быстрого градиента (FGSM).

function XAdv = basicIterativeMethod(dlnet,dlX,dlT,alpha,epsilon,numIter,initialization)

% Initialize the perturbation.
if initialization == "zero"
    delta = zeros(size(dlX),'like',dlX);
else
    delta = epsilon*(2*rand(size(dlX),'like',dlX) - 1);
end

for i = 1:numIter
  
    % Apply adversarial perturbations to the data.
    gradient = dlfeval(@modelGradientsInput,dlnet,dlX+delta,dlT);
    delta = delta + alpha*sign(gradient);
    delta(delta > epsilon) = epsilon;
    delta(delta < -epsilon) = -epsilon;
end

XAdv = dlX + delta;

end

Функция визуализации результатов предсказания

Визуализируйте изображения вместе с их предсказанными классами. Для правильных предсказаний используется зеленый текст. В неправильных предсказаниях используется красный текст.

function visualizePredictions(XTest,YPred,TTest)

figure
height = 4;
width = 4;
numImages = height*width;

% Select random images from the data.
indices = randperm(size(XTest,4),numImages);

XTest = extractdata(XTest);
XTest = XTest(:,:,:,indices);
YPred = YPred(indices);
TTest = TTest(indices);

% Plot images with the predicted label.
for i = 1:(numImages)
    subplot(height,width,i)
    imshow(XTest(:,:,:,i))

    % If the prediction is correct, use green. If the prediction is false,
    % use red.
    if YPred(i) == TTest(i)
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    title("Prediction: " + color + string(YPred(i)))
end

end

Ссылки

[1] Сегеди, Кристиан, Войцех Заремба, Илья Суцкевер, Жоан Бруна, Думитру Эрхан, Иан Гудфеллоу и Роб Фергус. «Интригующие свойства нейронных сетей». Препринт, представленный 19 февраля 2014 года. https://arxiv.org/abs/1312.6199.

[2] Гудфеллоу, Ян Дж., Джонатон Шленс и Кристиан Сегеди. «Объяснение и использование состязательных примеров». Препринт, представленный 20 марта 2015 года. https://arxiv.org/abs/1412.6572.

[3] Куракин, Алексей, Ян Гудфеллоу и Сами Бенгио. «Состязательные примеры в физическом мире». Препринт, представленный 10 февраля 2017 года. https://arxiv.org/abs/1607.02533.

[4] Мадри, Александер, Александар Макелов, Людвиг Шмидт, Димитрис Ципрас, и Адриан Владу. «К моделям глубокого обучения, устойчивым к состязательным атакам». Препринт, представленный 4 сентября 2019 года. https://arxiv.org/abs/1706.06083.

[5] Вонг, Эрик, Лесли Райс и Дж. Зико Колтер. Fast Is Better than Free: Revisiting Conversarial Training (неопр.) (недоступная ссылка). Препринт, представленный 12 января 2020 года. https://arxiv.org/abs/2001.03994.

См. также

| | | |

Похожие темы