В этом примере показано, как сократить параметры обученной нейронной сети с помощью двух метрик счета параметра: счет величины [1] и Синаптический счет Потока [2].
Во многих приложениях, где передача обучения используется, чтобы переобучить сеть классификации изображений для новой задачи или когда новая сеть обучена с нуля, оптимальная сетевая архитектура не известна и сетевая сила, которая будет сверхпараметрирована. Сверхпараметрированная сеть имеет избыточные связи.. Структурированное сокращение, также известное sparsification, является методом сжатия, который стремится идентифицировать избыточные, ненужные связи, которые можно удалить, не влияя на сетевую точность. Когда вы используете сокращение в сочетании с сетевым квантованием, можно уменьшить скорость вывода и объем потребляемой памяти сети, облегчающей развертываться.
В этом примере показано, как к:
Выполните постобучение, итеративное, неструктурированное сокращение без потребности в обучающих данных
Оцените эффективность двух различных алгоритмов сокращения
Исследуйте мудрую слоем разреженность, вызванную после сокращения
Оцените удар сокращения на точности классификации
Оцените удар квантования на точности классификации сокращенной сети
Этот пример использует простую сверточную нейронную сеть, чтобы классифицировать рукописные цифры от 0 до 9. Для получения дополнительной информации о подготовке данных, используемых для обучения и валидации, смотрите, Создают Простую сеть глубокого обучения для Классификации.
Загрузите данные об обучении и валидации. Обучите сверточную нейронную сеть задаче классификации.
[imdsTrain, imdsValidation] = loadDigitDataset; net = trainDigitDataNetwork(imdsTrain, imdsValidation); trueLabels = imdsValidation.Labels; classes = categories(trueLabels);
Создайте minibatchqueue
объект, содержащий данные о валидации. SetexecutionEnvironment к автоматическому, чтобы оценить сеть на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue
объект преобразует каждый выход в gpuArray
если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).
executionEnvironment = "auto"; miniBatchSize = 128; imdsValidation.ReadSize = miniBatchSize; mbqValidation = minibatchqueue (imdsValidation, 1,... 'MiniBatchSize', miniBatchSize,... 'MiniBatchFormat','SSCB',... 'MiniBatchFcn',@preprocessMiniBatch,... 'OutputEnvironment'Среда выполнения;
Цель сокращения нейронной сети состоит в том, чтобы идентифицировать и удалить неважные связи, чтобы уменьшать размер сети, не влияя на сетевую точность. В следующем рисунке, слева, сеть имеет связи, которые сопоставляют каждый нейрон с нейроном следующего слоя. После сокращения сеть имеет меньше связей, чем исходная сеть.
Алгоритм сокращения присваивает счет каждому параметру в сети. Счет оценивает важность каждой связи в сети. Можно использовать один из двух подходов сокращения, чтобы достигнуть целевой разреженности:
Сокращение одного выстрела - Удаляет заданный процент связей на основе их счета за один шаг. Этот метод подвержен коллапсу слоя, когда вы задаете высокое значение разреженности.
Итеративное сокращение - Достигает целевой разреженности в серии итеративных шагов. Можно использовать этот метод, когда оцененные баллы чувствительны к структуре сети. Баллы переоценены в каждой итерации, так использование серии шагов позволяет сети перемещаться к разреженности инкрементно.
Этот пример использует итеративный метод сокращения, чтобы достигнуть целевой разреженности.
dlnetwork Object
В этом примере вы используете Синаптический алгоритм Потока, который требует, чтобы вы создали пользовательскую функцию стоимости и оценили градиенты относительно функции стоимости, чтобы вычислить счет параметра. Чтобы создать пользовательскую функцию стоимости, сначала преобразуйте предварительно обученную сеть в dlnetwork
.
Преобразуйте сеть в график слоев и удалите слои, используемые для классификации с помощью removeLayers
.
lgraph = layerGraph(net.Layers); lgraph = removeLayers(lgraph,["softmax","classoutput"]); dlnet = dlnetwork(lgraph);
Используйте analyzeNetwork
анализировать сетевую архитектуру и настраиваемые параметры.
analyzeNetwork(dlnet)
Оцените точность сети перед сокращением.
accuracyOriginalNet = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
accuracyOriginalNet = 0.9900
Слои с настраиваемыми параметрами являются 3 слоями свертки и одним полносвязным слоем. Сеть первоначально состоит из общих 21 578 настраиваемых параметров.
numTotalParams = sum(cellfun(@numel,dlnet.Learnables.Value))
numTotalParams = 21578
numNonZeroPerParam = cellfun(@(w)nnz(extractdata(w)),dlnet.Learnables.Value)
numNonZeroPerParam = 8×1
72
8
1152
16
4608
32
15680
10
Разреженность задана как процент параметров в сети со значением нуля. Проверяйте разреженность сети.
initialSparsity = 1-sum(numNonZeroPerParam)/numTotalParams
initialSparsity = 0
Перед сокращением сеть имеет разреженность нуля.
Чтобы задать итеративную схему сокращения, задайте целевую разреженность и количество итераций. В данном примере использование линейно расположило итерации с интервалами, чтобы достигнуть целевой разреженности.
numIterations = 10; targetSparsity = 0.90; iterationScheme = linspace(0,targetSparsity,numIterations);
Для каждой итерации пользовательский цикл сокращения в этом примере выполняет следующие шаги:
Вычислите счет к каждой связи.
Оцените музыку ко всем связям в сети на основе выбранного алгоритма сокращения.
Определите порог для удаления связей с самыми низкими баллами.
Создайте маску сокращения с помощью порога.
Примените маску сокращения к настраиваемым параметрам сети.
Вместо того, чтобы установить записи в массивах весов непосредственно обнулять, алгоритм сокращения создает бинарную маску для каждого настраиваемого параметра, который задает, сокращена ли связь. Маска позволяет вам исследовать поведение сокращенной сети и пробовать различные схемы сокращения, не изменяя базовую структуру сети.
Например, рассмотрите следующие веса.
testWeight = [10.4 5.6 0.8 9];
Создайте бинарную маску для каждого параметра в testWeight.
testMask = [1 0 1 0];
Примените маску к testWeight
получить сокращенные веса.
testWeightsPruned = testWeight.*testMask
testWeightsPruned = 1×4
10.4000 0 0.8000 0
В итеративном сокращении вы создаете бинарную маску для каждой итерации, которая содержит информацию о сокращении. Применение маски к массиву весов не изменяет или размер массива или структуру нейронной сети. Поэтому шаг сокращения непосредственно не приводит ни к какому ускорению во время вывода или сжатия сетевого размера на диске.
Инициализируйте график, который сравнивает точность сокращенной сети к исходной сети.
figure plot(100*iterationScheme([1,end]),100*accuracyOriginalNet*[1 1],'*-b','LineWidth',2,"Color","b") ylim([0 100]) xlim(100*iterationScheme([1,end])) xlabel("Sparsity (%)") ylabel("Accuracy (%)") legend("Original Accuracy","Location","southwest") title("Pruning Accuracy") grid on
Величина, сокращающая [1] присвоения счет к каждому параметру, равняется своему абсолютному значению. Это принято, что абсолютное значение параметра соответствует своей относительной важности для точности обучившего сеть.
Инициализируйте маску. Для первой итерации вы не сокращаете параметров, и разреженность составляет 0%.
pruningMaskMagnitude = cell(1,numIterations); pruningMaskMagnitude{1} = dlupdate(@(p)true(size(p)), dlnet.Learnables);
Ниже реализация сокращения величины. Сеть сокращена к всевозможной целевой разреженности в цикле, чтобы обеспечить гибкость, чтобы выбрать сокращенную сеть на основе ее точности.
lineAccuracyPruningMagnitude = animatedline('Color','g','Marker','o','LineWidth',1.5); legend("Original Accuracy","Magnitude Pruning Accuracy","Location","southwest") % Compute magnitude scores scoresMagnitude = calculateMagnitudeScore(dlnet); for idx = 1:numel(iterationScheme) prunedNetMagnitude = dlnet; % Update the pruning mask pruningMaskMagnitude{idx} = calculateMask(scoresMagnitude,iterationScheme(idx)); % Check the number of zero entries in the pruning mask numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskMagnitude{idx}.Value)); sparsity = numPrunedParams/numTotalParams; % Apply pruning mask to network parameters prunedNetMagnitude.Learnables = dlupdate(@(W,M)W.*M, prunedNetMagnitude.Learnables, pruningMaskMagnitude{idx}); % Compute validation accuracy on pruned network accuracyMagnitude = evaluateAccuracy(prunedNetMagnitude,mbqValidation,classes,trueLabels); % Display the pruning progress addpoints(lineAccuracyPruningMagnitude,100*sparsity,100*accuracyMagnitude) drawnow end
Синаптическое сохранение потока (SynFlow) [2] баллы используется для сокращения. Можно использовать этот метод, чтобы сократить сети, которые используют линейные функции активации, такие как ReLU.
Инициализируйте маску. Для первой итерации не сокращены никакие параметры, и разреженность составляет 0%.
pruningMaskSynFlow = cell(1,numIterations); pruningMaskSynFlow{1} = dlupdate(@(p)true(size(p)),dlnet.Learnables);
Входные данные вы используетесь для расчета баллов, являются одним изображением, содержащим единицы. Если вы используете графический процессор, преобразуете данные в gpuArray
.
dlX = dlarray(ones(net.Layers(1).InputSize),'SSC'); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" dlX = gpuArray(dlX); end
Ниже цикла реализует итеративный синаптический счет потока к сокращению [2], где пользовательская функция стоимости оценивает счет SynFlow к каждому параметру, используемому для сетевого сокращения.
lineAccuracyPruningSynflow = animatedline('Color','r','Marker','o','LineWidth',1.5); legend("Original Accuracy","Magnitude Pruning Accuracy","Synaptic Flow Accuracy","Location","southwest") prunedNetSynFlow = dlnet; % Iteratively increase sparsity for idx = 1:numel(iterationScheme) % Compute SynFlow scores scoresSynFlow = calculateSynFlowScore(prunedNetSynFlow,dlX); % Update the pruning mask pruningMaskSynFlow{idx} = calculateMask(scoresSynFlow,iterationScheme(idx)); % Check the number of zero entries in the pruning mask numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskSynFlow{idx}.Value)); sparsity = numPrunedParams/numTotalParams; % Apply pruning mask to network parameters prunedNetSynFlow.Learnables = dlupdate(@(W,M)W.*M, prunedNetSynFlow.Learnables, pruningMaskSynFlow{idx}); % Compute validation accuracy on pruned network accuracySynFlow = evaluateAccuracy(prunedNetSynFlow,mbqValidation,classes,trueLabels); % Display the pruning progress addpoints(lineAccuracyPruningSynflow,100*sparsity,100*accuracySynFlow) drawnow end
Выбор, сколько сократить сеть, является компромиссом между точностью и разреженностью. Используйте разреженность по сравнению с графиком точности выбрать итерацию с желаемым уровнем разреженности и приемлемой точностью.
pruningMethod = "SynFlow"; selectedIteration = 8; prunedDLNet = createPrunedNet (dlnet, selectedIteration, pruningMaskSynFlow, pruningMaskMagnitude, pruningMethod); [sparsityPerLayer, prunedChannelsPerLayer, numOutChannelsPerLayer, layerNames] = pruningStatistics (prunedDLNet);
Более ранние слои свертки обычно сокращаются меньше, поскольку они содержат более релевантную информацию о базовой низкоуровневой структуре изображения (e.g. ребра и углы), которые важны для интерпретации изображения.
Постройте разреженность на слой для выбранного метода сокращения и итерации.
figure bar(sparsityPerLayer*100) title("Sparsity per layer") xlabel("Layer") ylabel("Sparsity (%)") xticks(1:numel(sparsityPerLayer)) xticklabels(layerNames) xtickangle(45) set(gca,'TickLabelInterpreter','none')
Алгоритм сокращения сокращает одну связи, когда вы задаете низкую целевую разреженность. Когда вы задаете высокую целевую разреженность, алгоритм сокращения может сократить целые фильтры и нейроны в сверточном или полносвязных слоях, позволив вам значительно уменьшать структурный размер сети.
figure bar([prunedChannelsPerLayer,numOutChannelsPerLayer-prunedChannelsPerLayer],"stacked") xlabel("Layer") ylabel("Number of filters") title("Number of filters per layer") xticks(1:(numel(layerNames))) xticklabels(layerNames) xtickangle(45) legend("Pruned number of channels/neurons" , "Original number of channels/neurons","Location","southoutside") set(gca,'TickLabelInterpreter','none')
Сравните точность сети до и после сокращения.
YPredOriginal = modelPredictions(dlnet,mbqValidation,classes); accOriginal = mean(YPredOriginal == trueLabels)
accOriginal = 0.9900
YPredPruned = modelPredictions(prunedDLNet,mbqValidation,classes); accPruned = mean(YPredPruned == trueLabels)
accPruned = 0.9400
Создайте матричный график беспорядка, чтобы исследовать истинные метки класса к предсказанным меткам класса для исходной и сокращенной сети.
figure
confusionchart(trueLabels,YPredOriginal);
title("Original Network")
Набор валидации данных о цифрах содержит 250 изображений для каждого класса, поэтому если сеть предсказывает класс каждого изображения отлично, все баллы на диагонали равняются 250, и никакие значения не находятся вне диагонали.
confusionchart(trueLabels,YPredPruned);
title("Pruned Network")
При сокращении сети сравните график беспорядка исходной сети и сокращенной сети, чтобы проверять, как точность для каждого класса помечает изменения для выбранного уровня разреженности. Если все числа на диагональном уменьшении примерно одинаково, никакое смещение не присутствует. Однако, если уменьшения не равны, вы можете должны быть выбрать сокращенную сеть из более ранней итерации путем сокращения значения переменной selectedIteration.
Глубокие нейронные сети, обученные в MATLAB, используют типы данных с плавающей запятой с одинарной точностью. Даже сети, которые малы в размере, требуют, чтобы значительный объем памяти и оборудование выполнили арифметические операции с плавающей точкой. Эти ограничения могут запретить развертывание моделей глубокого обучения, которые имеют низкую вычислительную силу и меньше ресурсов памяти. При помощи более низкой точности, чтобы сохранить веса и активации, можно уменьшать требования к памяти сети. Можно использовать Deep Learning Toolbox в тандеме с пакетом поддержки Библиотеки Квантования Модели Глубокого обучения, чтобы уменьшать объем потребляемой памяти глубокой нейронной сети путем квантования весов, смещений и активаций слоев свертки к 8-битным масштабированным целочисленным типам данных.
Сокращение сети влияет на статистику области значений параметров и активаций на каждом слое, таким образом, точность квантованной сети может измениться. Чтобы исследовать это различие, квантуйте сокращенную сеть и используйте квантованную сеть, чтобы выполнить вывод.
Разделите данные в наборы данных калибровки и валидации.
calibrationDataStore = splitEachLabel(imdsTrain,0.1,'randomize');
validationDataStore = imdsValidation;
Создайте dlquantizer
возразите и задайте сокращенную сеть как сеть, чтобы квантовать.
prunedNet = assembleNetwork([prunedDLNet.Layers ; net.Layers(end-1:end)]); quantObjPrunedNetwork = dlquantizer(prunedNet,'ExecutionEnvironment','GPU');
Используйте calibrate
функционируйте, чтобы осуществить сеть с калибровочными данными и собрать статистические данные области значений для весов, смещений и активаций на каждом слое.
calResults = calibrate(quantObjPrunedNetwork, calibrationDataStore)
calResults=18×5 table
Optimized Layer Name Network Layer Name Learnables / Activations MinValue MaxValue
____________________________ __________________ ________________________ _________ ________
{'conv_1_relu_1_Weights' } {'relu_1' } "Weights" -0.45058 1.1127
{'conv_1_relu_1_Bias' } {'relu_1' } "Bias" -0.025525 0.071941
{'conv_2_relu_2_Weights' } {'relu_2' } "Weights" -0.50476 0.56775
{'conv_2_relu_2_Bias' } {'relu_2' } "Bias" -0.080291 0.27024
{'conv_3_relu_3_Weights' } {'relu_3' } "Weights" -0.42412 0.46518
{'conv_3_relu_3_Bias' } {'relu_3' } "Bias" -0.20088 0.20468
{'fc_Weights' } {'fc' } "Weights" -0.30704 0.28832
{'fc_Bias' } {'fc' } "Bias" -0.23249 0.1647
{'imageinput' } {'imageinput'} "Activations" 0 255
{'imageinput_normalization'} {'imageinput'} "Activations" 0 1
{'conv_1_relu_1' } {'relu_1' } "Activations" 0 7.3311
{'maxpool_1' } {'maxpool_1' } "Activations" 0 7.3311
{'conv_2_relu_2' } {'relu_2' } "Activations" 0 27.143
{'maxpool_2' } {'maxpool_2' } "Activations" 0 27.143
{'conv_3_relu_3' } {'relu_3' } "Activations" 0 35.807
{'fc' } {'fc' } "Activations" -52.839 56.185
⋮
Используйте validate
функция, чтобы сравнить результаты сети до и после квантования с помощью набора данных валидации.
valResults = validate(quantObjPrunedNetwork, validationDataStore);
Исследуйте MetricResults.Result
поле валидации выход, чтобы видеть точность квантованного network.
valResults.MetricResults.Result
ans=2×2 table
NetworkImplementation MetricOutput
_____________________ ____________
{'Floating-Point'} 0.94
{'Quantized' } 0.9396
valResults.Statistics
ans=2×2 table
NetworkImplementation LearnableParameterMemory(bytes)
_____________________ _______________________________
{'Floating-Point'} 86320
{'Quantized' } 68824
preprocessMiniBatch
функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из входного массива ячеек, и конкатенируйте в числовой массив. Для полутонового входа, конкатенируя данные по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использовать в качестве одноэлементной размерности канала.
function X = preprocessMiniBatch(XCell) % Extract image data from cell and concatenate. X = cat(4,XCell{:}); end
Оцените точность классификации dlnetwork
. Точность является процентом меток, правильно классифицированных сетью.
function accuracy = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels) YPred = modelPredictions(dlnet,mbqValidation,classes); accuracy = mean(YPred == trueLabels); end
calculateSynFlowScore
функция вычисляет Синаптический Поток (SynFlow) баллы. Синаптический выступ [2] описан как класс основанных на градиенте баллов, заданных продуктом градиента потери, умноженной на значение параметров:
*
Счет SynFlow является синаптическим счетом выступа, который использует сумму всех сетевых выходных параметров как функция потерь:
функция, представленная нейронной сетью
параметры сети
входной массив к сети
Чтобы вычислить градиенты параметра относительно этой функции потерь, используйте dlfeval
и функция градиентов модели.
function score = calculateSynFlowScore(dlnet,dlX) dlnet.Learnables = dlupdate(@abs, dlnet.Learnables); gradients = dlfeval(@modelGradients,dlnet,dlX); score = dlupdate(@(g,w)g.*w, gradients, dlnet.Learnables); end
function gradients = modelGradients(dlNet,inputArray) % Evaluate the gradients on a given input to the dlnetwork dlYPred = predict(dlNet,inputArray); pseudoloss = sum(dlYPred,'all'); gradients = dlgradient(pseudoloss,dlNet.Learnables); end
calculateMagnitudeScore
функция возвращает счет величины, заданный как поэлементное абсолютное значение параметров.
function score = calculateMagnitudeScore(dlnet) score = dlupdate(@abs, dlnet.Learnables); end
calculateMask
функция возвращает бинарную маску для сетевых параметров на основе данных баллов и целевой разреженности.
function mask = calculateMask(scoresMagnitude,sparsity) % Compute a binary mask based on the parameter-wise scores such that the mask contains a percentage of zeros as specified by sparsity. % Flatten the cell array of scores into one long score vector flattenedScores = cell2mat(cellfun(@(S)extractdata(gather(S(:))),scoresMagnitude.Value,'UniformOutput',false)); % Rank the scores and determine the threshold for removing connections for the % given sparsity flattenedScores = sort(flattenedScores); k = round(sparsity*numel(flattenedScores)); if k==0 thresh = 0; else thresh = flattenedScores(k); end % Create a binary mask mask = dlupdate( @(S)S>thresh, scoresMagnitude); end
modelPredictions
функционируйте берет в качестве входа dlnetwor
k возражают dlnet, minibatchqueue
из входных данных mbq
, и сетевые классы, и вычисляют предсказания модели путем итерации по всем данным в объекте minibatchqueue. Функция использует onehotdecode
функционируйте, чтобы найти предсказанный класс с самым высоким счетом.
function predictions = modelPredictions(dlnet,mbq,classes) predictions = []; while hasdata(mbq) dlXTest = next(mbq); dlYPred = softmax(predict(dlnet,dlXTest)); YPred = onehotdecode(dlYPred,classes,1)'; predictions = [predictions; YPred]; end reset(mbq) end
createPrunedNet
функция возвращает сокращенный dlnetwork для заданного алгоритма сокращения и итерации.
function prunedNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod) switch pruningMethod case "Magnitude" prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskMagnitude{selectedIteration}); case "SynFlow" prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskSynFlow{selectedIteration}); end end
pruningStatistics
функционируйте извлечения подробная статистика сокращения уровня слоя такой разреженность уровня слоя и количество фильтров или сокращаемых нейронов.
sparsityPerLayer - percentage of parameters pruned in each layer
prunedChannelsPerLayer - количество каналов/нейронов в каждом слое, который может быть удален в результате сокращения
numOutChannelsPerLayer - количество каналов/нейронов в каждом слое
function [sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(dlnet) layerNames = unique(dlnet.Learnables.Layer,'stable'); numLayers = numel(layerNames); layerIDs = zeros(numLayers,1); for idx = 1:numel(layerNames) layerIDs(idx) = find(layerNames(idx)=={dlnet.Layers.Name}); end sparsityPerLayer = zeros(numLayers,1); prunedChannelsPerLayer = zeros(numLayers,1); numOutChannelsPerLayer = zeros(numLayers,1); numParams = zeros(numLayers,1); numPrunedParams = zeros(numLayers,1); for idx = 1:numLayers layer = dlnet.Layers(layerIDs(idx)); % Calculate the sparsity paramIDs = strcmp(dlnet.Learnables.Layer,layerNames(idx)); paramValue = dlnet.Learnables.Value(paramIDs); for p = 1:numel(paramValue) numParams(idx) = numParams(idx) + numel(paramValue{p}); numPrunedParams(idx) = numPrunedParams(idx) + nnz(extractdata(paramValue{p})==0); end % Calculate channel statistics sparsityPerLayer(idx) = numPrunedParams(idx)/numParams(idx); switch class(layer) case "nnet.cnn.layer.FullyConnectedLayer" numOutChannelsPerLayer(idx) = layer.OutputSize; prunedChannelsPerLayer(idx) = nnz(all(layer.Weights==0,2)&layer.Bias(:)==0); case "nnet.cnn.layer.Convolution2DLayer" numOutChannelsPerLayer(idx) = layer.NumFilters; prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0); case "nnet.cnn.layer.GroupedConvolution2DLayer" numOutChannelsPerLayer(idx) = layer.NumGroups*layer.NumFiltersPerGroup; prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0); otherwise error("Unknown layer: "+class(layer)) end end end
loadDigitDataset
функционируйте загружает набор данных Цифр и разделяет данные в данные об обучении и валидации.
function [imdsTrain, imdsValidation] = loadDigitDataset() digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true,'LabelSource','foldernames'); [imdsTrain, imdsValidation] = splitEachLabel(imds,0.75,"randomized"); end
trainDigitDataNetwork
функция обучает сверточную нейронную сеть классифицировать цифры на полутоновые изображения.
function net = trainDigitDataNetwork(imdsTrain,imdsValidation) layers = [ imageInputLayer([28 28 1],"Normalization","rescale-zero-one") convolution2dLayer(3,8,'Padding','same') reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') reluLayer fullyConnectedLayer(10) softmaxLayer classificationLayer]; % Specify the training options options = trainingOptions('sgdm', ... 'InitialLearnRate',0.01, ... 'MaxEpochs',10, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency',30, ... 'Verbose',false, ... 'Plots','none',"ExecutionEnvironment","auto"); % Train network net = trainNetwork(imdsTrain,layers,options); end
[1] Ханьцы песни, пул Джеффа, Джон Трэн и Уильям Дж. Развлечься. 2015. "Учась и веса и связи для эффективных нейронных сетей". Усовершенствования в нейронных системах обработки информации 28 (NIPS 2015): 1135–1143.
[2] Иденори Танака, Дэниел Кунин, Дэниел Л. К. Яминс и Сурья Гэнгули 2020. "Сокращая нейронные сети без любых данных путем итеративного сохранения синаптического потока". 34-я конференция по нейронным системам обработки информации (NeurlPS 2020)