Сокращение параметра и квантование сети классификации изображений

В этом примере показано, как сократить параметры обученной нейронной сети с помощью двух метрик счета параметра: счет величины [1] и Синаптический счет Потока [2].

Во многих приложениях, где передача обучения используется, чтобы переобучить сеть классификации изображений для новой задачи или когда новая сеть обучена с нуля, оптимальная сетевая архитектура не известна и сетевая сила, которая будет сверхпараметрирована. Сверхпараметрированная сеть имеет избыточные связи.. Структурированное сокращение, также известное sparsification, является методом сжатия, который стремится идентифицировать избыточные, ненужные связи, которые можно удалить, не влияя на сетевую точность. Когда вы используете сокращение в сочетании с сетевым квантованием, можно уменьшить скорость вывода и объем потребляемой памяти сети, облегчающей развертываться.

В этом примере показано, как к:

  • Выполните постобучение, итеративное, неструктурированное сокращение без потребности в обучающих данных

  • Оцените эффективность двух различных алгоритмов сокращения

  • Исследуйте мудрую слоем разреженность, вызванную после сокращения

  • Оцените удар сокращения на точности классификации

  • Оцените удар квантования на точности классификации сокращенной сети

Этот пример использует простую сверточную нейронную сеть, чтобы классифицировать рукописные цифры от 0 до 9. Для получения дополнительной информации о подготовке данных, используемых для обучения и валидации, смотрите, Создают Простую сеть глубокого обучения для Классификации.

Загрузите предварительно обученную сеть и данные

Загрузите данные об обучении и валидации. Обучите сверточную нейронную сеть задаче классификации.

[imdsTrain, imdsValidation] = loadDigitDataset;
net = trainDigitDataNetwork(imdsTrain, imdsValidation);
trueLabels = imdsValidation.Labels;
classes = categories(trueLabels);

Создайте minibatchqueue объект, содержащий данные о валидации. SetexecutionEnvironment к автоматическому, чтобы оценить сеть на графическом процессоре, если вы доступны. По умолчанию, minibatchqueue объект преобразует каждый выход в gpuArray если графический процессор доступен. Используя графический процессор требует Parallel Computing Toolbox™ и поддерживаемого устройства графического процессора. Для получения информации о поддерживаемых устройствах смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox).

executionEnvironment = "auto";
miniBatchSize = 128;
imdsValidation.ReadSize = miniBatchSize;
mbqValidation = minibatchqueue (imdsValidation, 1,...
    'MiniBatchSize', miniBatchSize,...
    'MiniBatchFormat','SSCB',...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'OutputEnvironment'Среда выполнения;

Сокращение нейронной сети

Цель сокращения нейронной сети состоит в том, чтобы идентифицировать и удалить неважные связи, чтобы уменьшать размер сети, не влияя на сетевую точность. В следующем рисунке, слева, сеть имеет связи, которые сопоставляют каждый нейрон с нейроном следующего слоя. После сокращения сеть имеет меньше связей, чем исходная сеть.

Алгоритм сокращения присваивает счет каждому параметру в сети. Счет оценивает важность каждой связи в сети. Можно использовать один из двух подходов сокращения, чтобы достигнуть целевой разреженности:

  • Сокращение одного выстрела - Удаляет заданный процент связей на основе их счета за один шаг. Этот метод подвержен коллапсу слоя, когда вы задаете высокое значение разреженности.

  • Итеративное сокращение - Достигает целевой разреженности в серии итеративных шагов. Можно использовать этот метод, когда оцененные баллы чувствительны к структуре сети. Баллы переоценены в каждой итерации, так использование серии шагов позволяет сети перемещаться к разреженности инкрементно.

Этот пример использует итеративный метод сокращения, чтобы достигнуть целевой разреженности.

Итеративное сокращение

Преобразуйте в dlnetwork Object

В этом примере вы используете Синаптический алгоритм Потока, который требует, чтобы вы создали пользовательскую функцию стоимости и оценили градиенты относительно функции стоимости, чтобы вычислить счет параметра. Чтобы создать пользовательскую функцию стоимости, сначала преобразуйте предварительно обученную сеть в dlnetwork.

Преобразуйте сеть в график слоев и удалите слои, используемые для классификации с помощью removeLayers.

lgraph = layerGraph(net.Layers);
lgraph = removeLayers(lgraph,["softmax","classoutput"]);
dlnet = dlnetwork(lgraph);

Используйте analyzeNetwork анализировать сетевую архитектуру и настраиваемые параметры.

analyzeNetwork(dlnet)

Оцените точность сети перед сокращением.

accuracyOriginalNet = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
accuracyOriginalNet = 0.9900

Слои с настраиваемыми параметрами являются 3 слоями свертки и одним полносвязным слоем. Сеть первоначально состоит из общих 21 578 настраиваемых параметров.

numTotalParams = sum(cellfun(@numel,dlnet.Learnables.Value))
numTotalParams = 21578
numNonZeroPerParam = cellfun(@(w)nnz(extractdata(w)),dlnet.Learnables.Value)
numNonZeroPerParam = 8×1

          72
           8
        1152
          16
        4608
          32
       15680
          10

Разреженность задана как процент параметров в сети со значением нуля. Проверяйте разреженность сети.

initialSparsity = 1-sum(numNonZeroPerParam)/numTotalParams
initialSparsity = 0

Перед сокращением сеть имеет разреженность нуля.

Создайте схему итерации

Чтобы задать итеративную схему сокращения, задайте целевую разреженность и количество итераций. В данном примере использование линейно расположило итерации с интервалами, чтобы достигнуть целевой разреженности.

numIterations = 10; 
targetSparsity = 0.90;
iterationScheme = linspace(0,targetSparsity,numIterations); 

Сокращение цикла

Для каждой итерации пользовательский цикл сокращения в этом примере выполняет следующие шаги:

  • Вычислите счет к каждой связи.

  • Оцените музыку ко всем связям в сети на основе выбранного алгоритма сокращения.

  • Определите порог для удаления связей с самыми низкими баллами.

  • Создайте маску сокращения с помощью порога.

  • Примените маску сокращения к настраиваемым параметрам сети.

Сетевая маска

Вместо того, чтобы установить записи в массивах весов непосредственно обнулять, алгоритм сокращения создает бинарную маску для каждого настраиваемого параметра, который задает, сокращена ли связь. Маска позволяет вам исследовать поведение сокращенной сети и пробовать различные схемы сокращения, не изменяя базовую структуру сети.

Например, рассмотрите следующие веса.

testWeight = [10.4 5.6 0.8 9];

Создайте бинарную маску для каждого параметра в testWeight.

testMask = [1 0 1 0];

Примените маску к testWeight получить сокращенные веса.

testWeightsPruned = testWeight.*testMask
testWeightsPruned = 1×4

   10.4000         0    0.8000         0

В итеративном сокращении вы создаете бинарную маску для каждой итерации, которая содержит информацию о сокращении. Применение маски к массиву весов не изменяет или размер массива или структуру нейронной сети. Поэтому шаг сокращения непосредственно не приводит ни к какому ускорению во время вывода или сжатия сетевого размера на диске.

Инициализируйте график, который сравнивает точность сокращенной сети к исходной сети.

figure
plot(100*iterationScheme([1,end]),100*accuracyOriginalNet*[1 1],'*-b','LineWidth',2,"Color","b")
ylim([0 100])
xlim(100*iterationScheme([1,end]))
xlabel("Sparsity (%)")
ylabel("Accuracy (%)")
legend("Original Accuracy","Location","southwest")
title("Pruning Accuracy")    
grid on

Сокращение величины

Величина, сокращающая [1] присвоения счет к каждому параметру, равняется своему абсолютному значению. Это принято, что абсолютное значение параметра соответствует своей относительной важности для точности обучившего сеть.

Инициализируйте маску. Для первой итерации вы не сокращаете параметров, и разреженность составляет 0%.

pruningMaskMagnitude = cell(1,numIterations); 
pruningMaskMagnitude{1} = dlupdate(@(p)true(size(p)), dlnet.Learnables);

Ниже реализация сокращения величины. Сеть сокращена к всевозможной целевой разреженности в цикле, чтобы обеспечить гибкость, чтобы выбрать сокращенную сеть на основе ее точности.

lineAccuracyPruningMagnitude = animatedline('Color','g','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Location","southwest")

% Compute magnitude scores
scoresMagnitude = calculateMagnitudeScore(dlnet);

for idx = 1:numel(iterationScheme)

    prunedNetMagnitude = dlnet;
    
    % Update the pruning mask
    pruningMaskMagnitude{idx} = calculateMask(scoresMagnitude,iterationScheme(idx));
    
    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskMagnitude{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;
    
    % Apply pruning mask to network parameters
    prunedNetMagnitude.Learnables = dlupdate(@(W,M)W.*M, prunedNetMagnitude.Learnables, pruningMaskMagnitude{idx});
    
    % Compute validation accuracy on pruned network
    accuracyMagnitude = evaluateAccuracy(prunedNetMagnitude,mbqValidation,classes,trueLabels);
    
    % Display the pruning progress
    addpoints(lineAccuracyPruningMagnitude,100*sparsity,100*accuracyMagnitude)
    drawnow
end

Сокращение SynFlow

Синаптическое сохранение потока (SynFlow) [2] баллы используется для сокращения. Можно использовать этот метод, чтобы сократить сети, которые используют линейные функции активации, такие как ReLU.

Инициализируйте маску. Для первой итерации не сокращены никакие параметры, и разреженность составляет 0%.

pruningMaskSynFlow = cell(1,numIterations); 
pruningMaskSynFlow{1} = dlupdate(@(p)true(size(p)),dlnet.Learnables);

Входные данные вы используетесь для расчета баллов, являются одним изображением, содержащим единицы. Если вы используете графический процессор, преобразуете данные в gpuArray.

dlX = dlarray(ones(net.Layers(1).InputSize),'SSC');
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlX = gpuArray(dlX);
end

Ниже цикла реализует итеративный синаптический счет потока к сокращению [2], где пользовательская функция стоимости оценивает счет SynFlow к каждому параметру, используемому для сетевого сокращения.

lineAccuracyPruningSynflow = animatedline('Color','r','Marker','o','LineWidth',1.5);
legend("Original Accuracy","Magnitude Pruning Accuracy","Synaptic Flow Accuracy","Location","southwest")

prunedNetSynFlow = dlnet;

% Iteratively increase sparsity
for idx = 1:numel(iterationScheme)
    % Compute SynFlow scores
    scoresSynFlow = calculateSynFlowScore(prunedNetSynFlow,dlX);
    
    % Update the pruning mask
    pruningMaskSynFlow{idx} = calculateMask(scoresSynFlow,iterationScheme(idx));
    
    % Check the number of zero entries in the pruning mask
    numPrunedParams = sum(cellfun(@(m)nnz(~extractdata(m)),pruningMaskSynFlow{idx}.Value));
    sparsity = numPrunedParams/numTotalParams;
    
    % Apply pruning mask to network parameters
    prunedNetSynFlow.Learnables = dlupdate(@(W,M)W.*M, prunedNetSynFlow.Learnables, pruningMaskSynFlow{idx});
    
    % Compute validation accuracy on pruned network
    accuracySynFlow = evaluateAccuracy(prunedNetSynFlow,mbqValidation,classes,trueLabels);
     
    % Display the pruning progress
    addpoints(lineAccuracyPruningSynflow,100*sparsity,100*accuracySynFlow)
    drawnow
end

Исследуйте структуру сокращенной сети

Выбор, сколько сократить сеть, является компромиссом между точностью и разреженностью. Используйте разреженность по сравнению с графиком точности выбрать итерацию с желаемым уровнем разреженности и приемлемой точностью.

pruningMethod = "SynFlow";
selectedIteration = 8;

prunedDLNet = createPrunedNet (dlnet, selectedIteration, pruningMaskSynFlow, pruningMaskMagnitude, pruningMethod);

[sparsityPerLayer, prunedChannelsPerLayer, numOutChannelsPerLayer, layerNames] = pruningStatistics (prunedDLNet);

Более ранние слои свертки обычно сокращаются меньше, поскольку они содержат более релевантную информацию о базовой низкоуровневой структуре изображения (e.g. ребра и углы), которые важны для интерпретации изображения.

Постройте разреженность на слой для выбранного метода сокращения и итерации.

figure
bar(sparsityPerLayer*100)
title("Sparsity per layer")
xlabel("Layer")
ylabel("Sparsity (%)")
xticks(1:numel(sparsityPerLayer))
xticklabels(layerNames)
xtickangle(45)
set(gca,'TickLabelInterpreter','none')

Алгоритм сокращения сокращает одну связи, когда вы задаете низкую целевую разреженность. Когда вы задаете высокую целевую разреженность, алгоритм сокращения может сократить целые фильтры и нейроны в сверточном или полносвязных слоях, позволив вам значительно уменьшать структурный размер сети.

figure
bar([prunedChannelsPerLayer,numOutChannelsPerLayer-prunedChannelsPerLayer],"stacked")
xlabel("Layer")
ylabel("Number of filters")
title("Number of filters per layer")
xticks(1:(numel(layerNames)))
xticklabels(layerNames)
xtickangle(45)
legend("Pruned number of channels/neurons" , "Original number of channels/neurons","Location","southoutside")
set(gca,'TickLabelInterpreter','none')

Оцените сетевую точность

Сравните точность сети до и после сокращения.

YPredOriginal = modelPredictions(dlnet,mbqValidation,classes);
accOriginal = mean(YPredOriginal == trueLabels)
accOriginal = 0.9900
YPredPruned = modelPredictions(prunedDLNet,mbqValidation,classes);
accPruned = mean(YPredPruned == trueLabels)
accPruned = 0.9400

Создайте матричный график беспорядка, чтобы исследовать истинные метки класса к предсказанным меткам класса для исходной и сокращенной сети.

figure
confusionchart(trueLabels,YPredOriginal);
title("Original Network")

Набор валидации данных о цифрах содержит 250 изображений для каждого класса, поэтому если сеть предсказывает класс каждого изображения отлично, все баллы на диагонали равняются 250, и никакие значения не находятся вне диагонали.

confusionchart(trueLabels,YPredPruned);
title("Pruned Network")

При сокращении сети сравните график беспорядка исходной сети и сокращенной сети, чтобы проверять, как точность для каждого класса помечает изменения для выбранного уровня разреженности. Если все числа на диагональном уменьшении примерно одинаково, никакое смещение не присутствует. Однако, если уменьшения не равны, вы можете должны быть выбрать сокращенную сеть из более ранней итерации путем сокращения значения переменной selectedIteration.

Квантуйте сокращенную сеть

Глубокие нейронные сети, обученные в MATLAB, используют типы данных с плавающей запятой с одинарной точностью. Даже сети, которые малы в размере, требуют, чтобы значительный объем памяти и оборудование выполнили арифметические операции с плавающей точкой. Эти ограничения могут запретить развертывание моделей глубокого обучения, которые имеют низкую вычислительную силу и меньше ресурсов памяти. При помощи более низкой точности, чтобы сохранить веса и активации, можно уменьшать требования к памяти сети. Можно использовать Deep Learning Toolbox в тандеме с пакетом поддержки Библиотеки Квантования Модели Глубокого обучения, чтобы уменьшать объем потребляемой памяти глубокой нейронной сети путем квантования весов, смещений и активаций слоев свертки к 8-битным масштабированным целочисленным типам данных.

Сокращение сети влияет на статистику области значений параметров и активаций на каждом слое, таким образом, точность квантованной сети может измениться. Чтобы исследовать это различие, квантуйте сокращенную сеть и используйте квантованную сеть, чтобы выполнить вывод.

Разделите данные в наборы данных калибровки и валидации.

calibrationDataStore = splitEachLabel(imdsTrain,0.1,'randomize');
validationDataStore = imdsValidation;

Создайте dlquantizer возразите и задайте сокращенную сеть как сеть, чтобы квантовать.

prunedNet  = assembleNetwork([prunedDLNet.Layers ; net.Layers(end-1:end)]);

quantObjPrunedNetwork = dlquantizer(prunedNet,'ExecutionEnvironment','GPU'); 

Используйте calibrate функционируйте, чтобы осуществить сеть с калибровочными данными и собрать статистические данные области значений для весов, смещений и активаций на каждом слое.

calResults = calibrate(quantObjPrunedNetwork, calibrationDataStore)
calResults=18×5 table
        Optimized Layer Name        Network Layer Name    Learnables / Activations    MinValue     MaxValue
    ____________________________    __________________    ________________________    _________    ________

    {'conv_1_relu_1_Weights'   }      {'relu_1'    }           "Weights"               -0.45058      1.1127
    {'conv_1_relu_1_Bias'      }      {'relu_1'    }           "Bias"                 -0.025525    0.071941
    {'conv_2_relu_2_Weights'   }      {'relu_2'    }           "Weights"               -0.50476     0.56775
    {'conv_2_relu_2_Bias'      }      {'relu_2'    }           "Bias"                 -0.080291     0.27024
    {'conv_3_relu_3_Weights'   }      {'relu_3'    }           "Weights"               -0.42412     0.46518
    {'conv_3_relu_3_Bias'      }      {'relu_3'    }           "Bias"                  -0.20088     0.20468
    {'fc_Weights'              }      {'fc'        }           "Weights"               -0.30704     0.28832
    {'fc_Bias'                 }      {'fc'        }           "Bias"                  -0.23249      0.1647
    {'imageinput'              }      {'imageinput'}           "Activations"                  0         255
    {'imageinput_normalization'}      {'imageinput'}           "Activations"                  0           1
    {'conv_1_relu_1'           }      {'relu_1'    }           "Activations"                  0      7.3311
    {'maxpool_1'               }      {'maxpool_1' }           "Activations"                  0      7.3311
    {'conv_2_relu_2'           }      {'relu_2'    }           "Activations"                  0      27.143
    {'maxpool_2'               }      {'maxpool_2' }           "Activations"                  0      27.143
    {'conv_3_relu_3'           }      {'relu_3'    }           "Activations"                  0      35.807
    {'fc'                      }      {'fc'        }           "Activations"            -52.839      56.185
      ⋮

Используйте validate функция, чтобы сравнить результаты сети до и после квантования с помощью набора данных валидации.

valResults = validate(quantObjPrunedNetwork, validationDataStore);

Исследуйте MetricResults.Result поле валидации выход, чтобы видеть точность квантованного network.

valResults.MetricResults.Result
ans=2×2 table
    NetworkImplementation    MetricOutput
    _____________________    ____________

     {'Floating-Point'}           0.94   
     {'Quantized'     }         0.9396   

valResults.Statistics
ans=2×2 table
    NetworkImplementation    LearnableParameterMemory(bytes)
    _____________________    _______________________________

     {'Floating-Point'}                   86320             
     {'Quantized'     }                   68824             

Мини-функция предварительной обработки пакета

preprocessMiniBatch функция предварительно обрабатывает мини-пакет предикторов путем извлечения данных изображения из входного массива ячеек, и конкатенируйте в числовой массив. Для полутонового входа, конкатенируя данные по четвертой размерности добавляет третью размерность в каждое изображение, чтобы использовать в качестве одноэлементной размерности канала.

function X = preprocessMiniBatch(XCell)
% Extract image data from cell and concatenate.
X = cat(4,XCell{:});
end

Функция точности модели

Оцените точность классификации dlnetwork. Точность является процентом меток, правильно классифицированных сетью.

function accuracy = evaluateAccuracy(dlnet,mbqValidation,classes,trueLabels)
YPred = modelPredictions(dlnet,mbqValidation,classes);
accuracy = mean(YPred == trueLabels);
end

Функция счета SynFlow

calculateSynFlowScore функция вычисляет Синаптический Поток (SynFlow) баллы. Синаптический выступ [2] описан как класс основанных на градиенте баллов, заданных продуктом градиента потери, умноженной на значение параметров:

synFlowScore=d(loss)dθ*θ

Счет SynFlow является синаптическим счетом выступа, который использует сумму всех сетевых выходных параметров как функция потерь:

loss=f(abs(θ),X)

f функция, представленная нейронной сетью

θ параметры сети

X входной массив к сети

Чтобы вычислить градиенты параметра относительно этой функции потерь, используйте dlfeval и функция градиентов модели.

function score = calculateSynFlowScore(dlnet,dlX)
dlnet.Learnables = dlupdate(@abs, dlnet.Learnables);
gradients = dlfeval(@modelGradients,dlnet,dlX);
score = dlupdate(@(g,w)g.*w, gradients, dlnet.Learnables);
end

Градиенты модели для счета SynFlow

function gradients = modelGradients(dlNet,inputArray)
% Evaluate the gradients on a given input to the dlnetwork
dlYPred = predict(dlNet,inputArray);
pseudoloss = sum(dlYPred,'all');
gradients = dlgradient(pseudoloss,dlNet.Learnables);
end

Функция счета величины

calculateMagnitudeScore функция возвращает счет величины, заданный как поэлементное абсолютное значение параметров.

function score = calculateMagnitudeScore(dlnet)
score = dlupdate(@abs, dlnet.Learnables);
end

Функция генерации маски

calculateMask функция возвращает бинарную маску для сетевых параметров на основе данных баллов и целевой разреженности.

function mask = calculateMask(scoresMagnitude,sparsity)
% Compute a binary mask based on the parameter-wise scores such that the mask contains a percentage of zeros as specified by sparsity.

% Flatten the cell array of scores into one long score vector
flattenedScores = cell2mat(cellfun(@(S)extractdata(gather(S(:))),scoresMagnitude.Value,'UniformOutput',false));
% Rank the scores and determine the threshold for removing connections for the
% given sparsity
flattenedScores = sort(flattenedScores);
k = round(sparsity*numel(flattenedScores));
if k==0
    thresh = 0;
else
    thresh = flattenedScores(k);
end
% Create a binary mask 
mask = dlupdate( @(S)S>thresh, scoresMagnitude);
end

Функция предсказаний модели

modelPredictions функционируйте берет в качестве входа dlnetwork возражают dlnet, minibatchqueue из входных данных mbq, и сетевые классы, и вычисляют предсказания модели путем итерации по всем данным в объекте minibatchqueue. Функция использует onehotdecode функционируйте, чтобы найти предсказанный класс с самым высоким счетом.

function predictions = modelPredictions(dlnet,mbq,classes)
predictions = [];
while hasdata(mbq)
    dlXTest = next(mbq);
    dlYPred = softmax(predict(dlnet,dlXTest));
    YPred = onehotdecode(dlYPred,classes,1)';
    predictions = [predictions; YPred];
end
reset(mbq)
end

Примените функцию сокращения

createPrunedNet функция возвращает сокращенный dlnetwork для заданного алгоритма сокращения и итерации.

function prunedNet = createPrunedNet(dlnet,selectedIteration,pruningMaskSynFlow,pruningMaskMagnitude,pruningMethod)
switch pruningMethod
    case "Magnitude"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskMagnitude{selectedIteration});
    case "SynFlow"
        prunedNet = dlupdate(@(W,M)W.*M, dlnet, pruningMaskSynFlow{selectedIteration});
end
end

Сокращение функции статистики

pruningStatistics функционируйте извлечения подробная статистика сокращения уровня слоя такой разреженность уровня слоя и количество фильтров или сокращаемых нейронов.

sparsityPerLayer - percentage of parameters pruned in each layer

prunedChannelsPerLayer - количество каналов/нейронов в каждом слое, который может быть удален в результате сокращения

numOutChannelsPerLayer - количество каналов/нейронов в каждом слое

function [sparsityPerLayer,prunedChannelsPerLayer,numOutChannelsPerLayer,layerNames] = pruningStatistics(dlnet)

layerNames = unique(dlnet.Learnables.Layer,'stable');
numLayers = numel(layerNames);
layerIDs = zeros(numLayers,1);
for idx = 1:numel(layerNames)
    layerIDs(idx) = find(layerNames(idx)=={dlnet.Layers.Name});
end

sparsityPerLayer = zeros(numLayers,1);
prunedChannelsPerLayer = zeros(numLayers,1);
numOutChannelsPerLayer = zeros(numLayers,1);

numParams = zeros(numLayers,1);
numPrunedParams = zeros(numLayers,1);
for idx = 1:numLayers
    layer = dlnet.Layers(layerIDs(idx));
    
    % Calculate the sparsity
    paramIDs = strcmp(dlnet.Learnables.Layer,layerNames(idx));
    paramValue = dlnet.Learnables.Value(paramIDs);
    for p = 1:numel(paramValue)
        numParams(idx) = numParams(idx) + numel(paramValue{p});
        numPrunedParams(idx) = numPrunedParams(idx) + nnz(extractdata(paramValue{p})==0);
    end

    % Calculate channel statistics
    sparsityPerLayer(idx) = numPrunedParams(idx)/numParams(idx);
    switch class(layer)
        case "nnet.cnn.layer.FullyConnectedLayer"
            numOutChannelsPerLayer(idx) = layer.OutputSize;
            prunedChannelsPerLayer(idx) = nnz(all(layer.Weights==0,2)&layer.Bias(:)==0);
        case "nnet.cnn.layer.Convolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumFilters;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        case "nnet.cnn.layer.GroupedConvolution2DLayer"
            numOutChannelsPerLayer(idx) = layer.NumGroups*layer.NumFiltersPerGroup;
            prunedChannelsPerLayer(idx) = nnz(reshape(all(layer.Weights==0,[1,2,3]),[],1)&layer.Bias(:)==0);
        otherwise
            error("Unknown layer: "+class(layer))
    end
end
end

Загрузите Функцию Набора данных Цифр

loadDigitDataset функционируйте загружает набор данных Цифр и разделяет данные в данные об обучении и валидации.

function [imdsTrain, imdsValidation] = loadDigitDataset()
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain, imdsValidation] = splitEachLabel(imds,0.75,"randomized");
end

Обучите функцию сети распознавания цифры

trainDigitDataNetwork функция обучает сверточную нейронную сеть классифицировать цифры на полутоновые изображения.

function net = trainDigitDataNetwork(imdsTrain,imdsValidation)
layers = [
    imageInputLayer([28 28 1],"Normalization","rescale-zero-one")
    convolution2dLayer(3,8,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,16,'Padding','same')
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Specify the training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',10, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','none',"ExecutionEnvironment","auto");

% Train network
net = trainNetwork(imdsTrain,layers,options);
end

Ссылки

[1] Ханьцы песни, пул Джеффа, Джон Трэн и Уильям Дж. Развлечься. 2015. "Учась и веса и связи для эффективных нейронных сетей". Усовершенствования в нейронных системах обработки информации 28 (NIPS 2015): 1135–1143.

[2] Иденори Танака, Дэниел Кунин, Дэниел Л. К. Яминс и Сурья Гэнгули 2020. "Сокращая нейронные сети без любых данных путем итеративного сохранения синаптического потока". 34-я конференция по нейронным системам обработки информации (NeurlPS 2020)