exponenta event banner

Сеть классификации изображений Train Устойчивые к состязательности Примеры

Этот пример показывает, как обучить нейронную сеть, которая устойчива к состязательным примерам с использованием метода быстрого градиентного знака (FGSM) состязательного обучения.

Нейронные сети могут быть восприимчивы к явлению, известному как состязательные примеры [1], где очень небольшие изменения входных данных могут привести к его неправильной классификации. Эти изменения часто незаметны для людей.

Методы создания состязательных примеров включают FGSM [2] и основной итеративный метод (BIM) [3], также известный как проектируемый градиентный спуск [4]. Эти методы могут значительно ухудшить точность сети.

Вы можете использовать обучение состязательности [5] для обучения сетей, которые являются надежными для состязательности примеров. В этом примере показано, как:

  1. Обучение сети классификации изображений.

  2. Исследование надежности сети путем создания примеров состязательности.

  3. Обучение сети классификации изображений, которая является надежной для состязательных примеров.

Загрузка данных обучения

digitTrain4DArrayData функция загружает изображения рукописных цифр и их цифровые метки. Создание arrayDatastore объект для изображений и меток, а затем используйте combine создание единого хранилища данных, содержащего все обучающие данные.

rng default
[XTrain,TTrain] = digitTrain4DArrayData;

dsXTrain = arrayDatastore(XTrain,'IterationDimension',4);
dsTTrain = arrayDatastore(TTrain);

dsTrain = combine(dsXTrain,dsTTrain);

Извлеките имена классов.

classes = categories(TTrain);

Создание сетевой архитектуры

Определите сеть классификации изображений.

layers = [
    imageInputLayer([28 28 1],'Normalization','none','Name','input')
    convolution2dLayer(3,32,'Padding',1,'Name','conv1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,64,'Padding',1,'Name','conv3')
    reluLayer('Name','relu3')
    maxPooling2dLayer(2,'Stride',2,'Name','pool')
    fullyConnectedLayer(10,'Name','fc2')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

Создать dlnetwork объект из графа слоев.

dlnet = dlnetwork(lgraph);

Определение функции градиентов модели

Создание функции modelGradients, перечисленных в конце примера, который принимает в качестве входных данных dlnetwork объект и мини-пакет входных данных с соответствующими метками и возвращает градиенты потерь относительно обучаемых параметров в сети и соответствующих потерь.

Железнодорожная сеть

Обучение сети с использованием пользовательского цикла обучения.

Укажите параметры обучения. Обучение в течение 30 эпох с размером мини-партии 100 и скоростью обучения 0,01.

numEpochs = 30;
miniBatchSize = 100;
learnRate = 0.01;
executionEnvironment = "auto";

Создать minibatchqueue объект, обрабатывающий и управляющий мини-партиями изображений во время обучения. Для каждой мини-партии:

  • Использование пользовательской функции предварительной обработки мини-партии preprocessMiniBatch (определяется в конце этого примера) для преобразования меток в однокодированные переменные.

  • Форматирование данных изображения с метками размеров 'SSCB' (пространственный, пространственный, канальный, пакетный).

  • Обучение на GPU, если он доступен. По умолчанию minibatchqueue объект преобразует каждый вывод в gpuArray если графический процессор доступен. Для использования графического процессора требуется Toolbox™ параллельных вычислений и поддерживаемое устройство графического процессора. Сведения о поддерживаемых устройствах см. в разделе Поддержка графического процессора по выпуску (Parallel Computing Toolbox).

mbq = minibatchqueue(dsTrain, ...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

Инициализируйте график хода обучения.

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Инициализируйте параметр скорости для решателя SGDM.

velocity = [];

Обучение сети с использованием пользовательского цикла обучения. Для каждой эпохи тасуйте данные и закольцовывайте мини-пакеты данных. Для каждой мини-партии:

  • Оцените градиенты модели, состояние и потери с помощью dlfeval и modelGradients и обновить состояние сети.

  • Обновление параметров сети с помощью sgdmupdate функция.

  • Просмотрите ход обучения.

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration +1;

        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlT = gpuArray(dlT);
        end

        % Evaluate the model gradients, state, and loss.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlT);
        dlnet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестовая сеть

Проверка точности классификации сети путем оценки сетевых прогнозов на наборе тестовых данных.

Создать minibatchqueue объект, содержащий тестовые данные.

[XTest,TTest] = digitTest4DArrayData;

dsXTest = arrayDatastore(XTest,'IterationDimension',4);
dsTTest = arrayDatastore(TTest);

dsTest = combine(dsXTest,dsTTest);

mbqTest = minibatchqueue(dsTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatch, ...
    'MiniBatchFormat','SSCB');

Прогнозирование классов тестовых данных с использованием обученной сети и modelPredictions функция, определенная в конце этого примера.

YPred = modelPredictions(dlnet,mbqTest,classes);
acc = mean(YPred == TTest)
acc = 0.9866

Точность сети очень высока.

Тестовая сеть с состязательными входами

Примените состязательные возмущения к входным изображениям и посмотрите, как это влияет на точность сети.

Можно создать примеры состязательности с использованием таких методов, как FGSM и BIM. FGSM - это простая техника, которая делает один шаг в направлении градиентного ∇XL (X, T) функции потерь L, относительно изображения X, для которого вы хотите найти пример состязательности, и метки класса T. Пример состязательности вычисляется как

Xadv=X+ϵ.sign (∇XL (X, T)).

Параметр ϵ определяет, как отличаются примеры состязательности от исходных изображений. В этом примере значения пикселей находятся в диапазоне от 0 до 1, поэтому ϵ значение 0,1 изменяет каждое значение каждого отдельного пикселя до 10% диапазона. Значение ϵ зависит от масштаба изображения. Например, если изображение находится в диапазоне от 0 до 255, необходимо умножить это значение на 255.

BIM представляет собой простое усовершенствование FGSM, которое применяет FGSM к нескольким итерациям и применяет пороговое значение. После каждой итерации BIM отсекает возмущение, чтобы гарантировать, что величина не превышает ϵ. Этот метод может давать состязательные примеры с меньшим искажением, чем FGSM. Дополнительные сведения о создании примеров состязательности см. в разделе Создание нецелевых и целевых примеров состязательности для классификации изображений.

Создайте примеры состязательности с помощью BIM. Набор epsilon до 0,1.

epsilon = 0.1;

Для BIM размер возмущения управляется параметром α, представляющим размер шага в каждой итерации. Это так, как BIM обычно делает много, меньших шагов FGSM в направлении градиента.

Определение размера шага alpha и количество итераций.

alpha = 0.01;
numAdvIter = 20;

Используйте adversarialExamples функция (определенная в конце этого примера) для вычисления состязательных примеров с использованием BIM в наборе тестовых данных. Эта функция также возвращает новые прогнозы для состязательных образов.

reset(mbqTest)
[XAdv,YPredAdv] = adversarialExamples(dlnet,mbqTest,epsilon,alpha,numAdvIter,classes);

Вычислите точность сети на основе данных состязательного примера.

accAdversarial = mean(YPredAdv == TTest)
accAdversarial = 0.0114

Постройте график результатов.

visualizePredictions(XAdv,YPredAdv,TTest);

Вы можете видеть, что точность сильно ухудшается BIM, даже несмотря на то, что возмущение изображения почти не видно.

Надежная сеть поездов

Можно обучить сеть быть надежной против состязательных примеров. Одним из популярных методов является состязательная подготовка. Состязательное обучение включает применение состязательных возмущений к данным обучения в процессе обучения [4] [5].

Состязательное обучение FGSM является быстрым и эффективным методом обучения сети быть надежной к состязательным примерам. FGSM подобен BIM, но он делает один больший шаг в направлении градиента, чтобы сформировать состязательное изображение.

Состязательное обучение включает применение метода FGSM к каждой мини-партии данных обучения. Однако для того, чтобы обучение было эффективным, должны применяться следующие критерии:

  • Метод обучения FGSM должен использовать произвольно инициализированное возмущение вместо возмущения, которое инициализировано равным нулю.

  • Чтобы сеть была устойчива к возмущениям размера ϵ, выполните обучение FGSM со значением, немного большим ϵ. В этом примере во время состязательного обучения изображения возмущаются с помощью α=1.25ϵ размера шага.

Обучение новой сети с состязательным обучением FGSM. Начните с использования той же необученной сетевой архитектуры, что и в исходной сети.

dlnetRobust = dlnetwork(lgraph);     

Определите параметры состязательного обучения. Установите число итераций равным 1, поскольку FGSM эквивалентен BIM с одной итерацией. Случайная инициализация возмущения и возмущение изображений с помощью alpha.

numIter = 1;
initialization = "random";
alpha = 1.25*epsilon;

Инициализируйте график хода обучения.

figure
lineLossRobustTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Обучение надежной сети с использованием пользовательского цикла обучения и тех же параметров обучения, которые были определены ранее. Этот цикл тот же, что и в предыдущем пользовательском обучении, но с добавленным состязательным возмущением.

velocity = [];
iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq)

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        [dlX,dlT] = next(mbq);

        %  If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
            dlT = gpuArray(dlT);
        end

        % Apply adversarial perturbations to the data.
        dlX = basicIterativeMethod(dlnetRobust,dlX,dlT,alpha,epsilon, ...
            numIter,initialization);

        % Evaluate the model gradients, state, and loss.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnetRobust,dlX,dlT);
        dlnet.State = state;

        % Update the network parameters using the SGDM optimizer.
        [dlnetRobust,velocity] = sgdmupdate(dlnetRobust,gradients,velocity,learnRate);

        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossRobustTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

Тестирование надежной сети

Вычислите точность надежной сети по данным тестирования цифр. Точность надежной сети может быть несколько ниже, чем ненадежная сеть на стандартных данных.

reset(mbqTest)
YPred = modelPredictions(dlnetRobust,mbqTest,classes);
accRobust = mean(YPred == TTest)
accRobust = 0.9972

Вычислите состязательную точность.

reset(mbqTest)
[XAdv,YPredAdv] = adversarialExamples(dlnetRobust,mbqTest,epsilon,alpha,numAdvIter,classes);
accRobustAdv = mean(YPredAdv == TTest)
accRobustAdv = 0.7558

Состязательная точность надежной сети намного лучше, чем у исходной сети.

Вспомогательные функции

Функция градиентов модели

modelGradients функция принимает в качестве входного значения a dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Т и возвращает градиенты потерь относительно обучаемых параметров в dlnet, состояние сети и потери. Для автоматического вычисления градиентов используйте dlgradient функция.

function [gradients,state,loss] = modelGradients(dlnet,dlX,T)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

Функция входных градиентов

modelGradientsInput функция принимает в качестве входного значения a dlnetwork объект dlnet и мини-пакет входных данных dlX с соответствующими метками Т и возвращает градиенты потерь по отношению к входным данным dlX.

function gradient = modelGradientsInput(dlnet,dlX,T)

T = squeeze(T);
T = dlarray(T,'CB');

[dlYPred] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,T);
gradient = dlgradient(loss,dlX);

end

Функция предварительной обработки мини-партий

preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов и меток, используя следующие шаги:

  1. Извлеките данные изображения из входящего массива ячеек и объедините их в четырехмерный массив.

  2. Извлеките данные метки из входящего массива ячеек и объедините их в категориальный массив вдоль второго измерения.

  3. Одноконтурное кодирование категориальных меток в числовые массивы. Кодирование в первом измерении создает закодированный массив, который соответствует форме сетевого вывода.

function [X,T] = preprocessMiniBatch(XCell,TCell)

% Concatenate.
X = cat(4,XCell{1:end});

X = single(X);

% Extract label data from the cell and concatenate.
T = cat(2,TCell{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end

Функция прогнозирования модели

modelPredictions функция принимает в качестве входного значения a dlnetwork объект dlnet, a minibatchqueue входных данных mbqи сетевые классы, и вычисляет предсказания модели путем итерации по всем данным в minibatchqueue объект. Функция использует onehotdecode функция для поиска прогнозируемого класса с наивысшим баллом.

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

Функция состязательных примеров

Создание примеров состязательности для minibatchqueue объект с использованием основного итеративного метода (BIM) и прогнозирование класса состязательных примеров с использованием обученной сети dlnet.

function [XAdv,predictions] = adversarialExamples(dlnet,mbq,epsilon,alpha,numIter,classes)

XAdv = {};
predictions = [];
iteration = 0;

% Generate adversarial images for each mini-batch.
while hasdata(mbq)

    iteration = iteration +1;
    [dlX,dlT] = next(mbq);

    initialization = "zero";
    
    % Generate adversarial images.
    XAdvMBQ = basicIterativeMethod(dlnet,dlX,dlT,alpha,epsilon, ...
        numIter,initialization);

    % Predict the class of the adversarial images.
    dlYPred = predict(dlnet,XAdvMBQ);
    YPred = onehotdecode(dlYPred,classes,1)';

    XAdv{iteration} = XAdvMBQ;
    predictions = [predictions; YPred];
end

% Concatenate.
XAdv = cat(4,XAdv{:});

end

Функция основного итеративного метода

Создание состязательных примеров с использованием основного итеративного метода (BIM). Этот метод выполняется для нескольких итераций с порогом в конце каждой итерации, чтобы гарантировать, что значения не превышают epsilon. Когда numIter имеет значение 1, что эквивалентно использованию метода быстрых градиентных знаков (FGSM).

function XAdv = basicIterativeMethod(dlnet,dlX,dlT,alpha,epsilon,numIter,initialization)

% Initialize the perturbation.
if initialization == "zero"
    delta = zeros(size(dlX),'like',dlX);
else
    delta = epsilon*(2*rand(size(dlX),'like',dlX) - 1);
end

for i = 1:numIter
  
    % Apply adversarial perturbations to the data.
    gradient = dlfeval(@modelGradientsInput,dlnet,dlX+delta,dlT);
    delta = delta + alpha*sign(gradient);
    delta(delta > epsilon) = epsilon;
    delta(delta < -epsilon) = -epsilon;
end

XAdv = dlX + delta;

end

Визуализация функции результатов прогнозирования

Визуализация изображений вместе с их прогнозируемыми классами. В правильных прогнозах используется зеленый текст. В неправильных прогнозах используется красный текст.

function visualizePredictions(XTest,YPred,TTest)

figure
height = 4;
width = 4;
numImages = height*width;

% Select random images from the data.
indices = randperm(size(XTest,4),numImages);

XTest = extractdata(XTest);
XTest = XTest(:,:,:,indices);
YPred = YPred(indices);
TTest = TTest(indices);

% Plot images with the predicted label.
for i = 1:(numImages)
    subplot(height,width,i)
    imshow(XTest(:,:,:,i))

    % If the prediction is correct, use green. If the prediction is false,
    % use red.
    if YPred(i) == TTest(i)
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    title("Prediction: " + color + string(YPred(i)))
end

end

Ссылки

[1] Сегеди, Кристиан, Войцех Заремба, Илья Суцкевер, Йоан Бруна, Думитру Эрхан, Иэн Гудфеллоу и Роб Фергус. «Интригующие свойства нейронных сетей». Препринт, представлен 19 февраля 2014 года. https://arxiv.org/abs/1312.6199.

[2] Гудфеллоу, Йен Дж., Джонатон Шленс и Кристиан Сегеди. «Объяснение и использование состязательных примеров». Препринт, представлен 20 марта 2015 года. https://arxiv.org/abs/1412.6572.

[3] Куракин, Алексей, Ян Гудфеллоу и Сами Бенгио. «Состязательные примеры в физическом мире». Препринт, представлен 10 февраля 2017 года. https://arxiv.org/abs/1607.02533.

[4] Мэдри, Александр, Александр Мэкелов, Людвиг Шмидт, Димитрис Ципрас и Эдриан Влэду. «К моделям глубокого обучения, устойчивым к противоборствующим атакам». Препринт, представлен 4 сентября 2019 года. https://arxiv.org/abs/1706.06083.

[5] Вонг, Эрик, Лесли Райс и Дж. Зико Колтер. «Быстрее лучше, чем бесплатно: повторное посещение состязательного обучения». Препринт, представлен 12 января 2020 года. https://arxiv.org/abs/2001.03994.

См. также

| | | |

Связанные темы