В этом примере показано, как сгенерировать высококачественные изображения компьютерной томографии (CT) большей дозы от шумных изображений CT низкой дозы с помощью нейронной сети CycleGAN.
Этот пример использует сопоставимую с циклом порождающую соперничающую сеть (CycleGAN), обученный на закрашенных фигурах данных изображения от большой выборки данных. Для аналогичного подхода с помощью МОДУЛЬНОЙ нейронной сети, обученной на полных образах от ограниченной выборки данных, смотрите, что Безнадзорное Медицинское Шумоподавление Изображений Использует CycleGAN.
CT рентгена является популярной модальностью обработки изображений, используемой в клиническом и промышленном применении, потому что он производит высококачественные изображения и предлагает превосходящие диагностические возможности. Чтобы защитить безопасность пациентов, клиницисты рекомендуют низкую дозу излучения. Однако низкая доза излучения приводит к более низкому отношению сигнал-шум (SNR) в изображениях, и поэтому уменьшает диагностическую точность.
Методы глубокого обучения предлагают решения улучшить качество изображения для CT низкой дозы (LDCT) изображения. Используя порождающую соперничающую сеть (GAN) для перевода от изображения к изображению, можно преобразовать шумные изображения LDCT в изображения того же качества как изображения CT регулярной дозы. Для этого приложения исходная область состоит из изображений LDCT, и целевая область состоит из изображений регулярной дозы.
Шумоподавление CT изображений требует GAN, который выполняет безнадзорное обучение, потому что клиницисты обычно не получают соответствие с парами низкой дозы и изображениями CT регулярной дозы того же пациента на том же сеансе. Этот пример использует архитектуру CycleGAN, которая поддерживает безнадзорное обучение. Для получения дополнительной информации смотрите Начало работы с GANs для Перевода От изображения к изображению.
Этот пример использует данные из Низкого CT Дозы Главная проблема [2, 3, 4]. Данные включают пары изображений CT регулярной дозы, и симулированные изображения CT низкой дозы для 99 главных сканов (пометил N для neuro), 100 сканов грудной клетки (пометил C для груди), и 100 сканов живота (пометил L для печени). Размер набора данных составляет 1,2 Тбайта.
Установите dataDir
как желаемое местоположение набора данных.
dataDir = fullfile(tempdir,"LDCT","LDCT-and-Projection-data");
Чтобы загрузить данные, перейдите к веб-сайту Архива Обработки изображений Рака. Этот пример использует только изображения из груди. Загрузите файлы грудной клетки с "Изображений (DICOM, 952 Гбайт)" набор данных в директорию, заданную dataDir
использование Ретривера Данных NBIA. Когда загрузка успешна, dataDir
содержит 50 подпапок с именами, такими как "C002" и "C004", заканчивающийся "C296".
Набор данных LDCT обеспечивает пары низкой дозы и изображений CT большей дозы. Однако архитектура CycleGAN требует непарных данных для безнадзорного изучения. Этот пример симулирует непарные данные об обучении и валидации путем разделения изображений, таким образом, что пациенты раньше получали CT низкой дозы, и изображения CT большей дозы не перекрываются. Пример сохраняет пары низкой дозы и изображений регулярной дозы для тестирования.
Разделите данные в обучение, валидацию и наборы тестовых данных с помощью createLDCTFolderList
функция помощника. Эта функция присоединена к примеру как вспомогательный файл. Функция помощника разделяет данные, таким образом, что существует примерно хорошее представление двух типов изображений в каждой группе. Приблизительно 80% данных используются для обучения, 15% используется для тестирования, и 5% используются для валидации.
maxDirsForABodyPart = 25;
[filesTrainLD,filesTrainHD,filesTestLD,filesTestHD,filesValLD,filesValHD] = ...
createLDCTFolderList(dataDir,maxDirsForABodyPart);
Создайте хранилища данных изображений, которые содержат изображения обучения и валидации для обеих областей, а именно, изображения CT низкой дозы и изображения CT большей дозы. Набор данных состоит из изображений DICOM, так используйте пользовательский ReadFcn
аргумент значения имени в imageDatastore
позволять считать данные.
exts = {'.dcm'};
readFcn = @(x)dicomread(x);
imdsTrainLD = imageDatastore(filesTrainLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTrainHD = imageDatastore(filesTrainHD,FileExtensions=exts,ReadFcn=readFcn);
imdsValLD = imageDatastore(filesValLD,FileExtensions=exts,ReadFcn=readFcn);
imdsValHD = imageDatastore(filesValHD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestLD = imageDatastore(filesTestLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestHD = imageDatastore(filesTestHD,FileExtensions=exts,ReadFcn=readFcn);
Количество низкой дозы и изображений большей дозы может отличаться. Выберите подмножество файлов, таким образом, что количество изображений равно.
numTrain = min(numel(imdsTrainLD.Files),numel(imdsTrainHD.Files)); imdsTrainLD = subset(imdsTrainLD,1:numTrain); imdsTrainHD = subset(imdsTrainHD,1:numTrain); numVal = min(numel(imdsValLD.Files),numel(imdsValHD.Files)); imdsValLD = subset(imdsValLD,1:numVal); imdsValHD = subset(imdsValHD,1:numVal); numTest = min(numel(imdsTestLD.Files),numel(imdsTestHD.Files)); imdsTestLD = subset(imdsTestLD,1:numTest); imdsTestHD = subset(imdsTestHD,1:numTest);
Предварительно обработайте данные при помощи transform
функция с пользовательскими операциями предварительной обработки, заданными normalizeCTImages
функция помощника. Эта функция присоединена к примеру как вспомогательный файл. normalizeCTImages
функция перемасштабирует данные к области значений [-1, 1].
timdsTrainLD = transform(imdsTrainLD,@(x){normalizeCTImages(x)}); timdsTrainHD = transform(imdsTrainHD,@(x){normalizeCTImages(x)}); timdsValLD = transform(imdsValLD,@(x){normalizeCTImages(x)}); timdsValHD = transform(imdsValHD,@(x){normalizeCTImages(x)}); timdsTestLD = transform(imdsTestLD,@(x){normalizeCTImages(x)}); timdsTestHD = transform(imdsTestHD,@(x){normalizeCTImages(x)});
Объедините низкую дозу и обучающие данные большей дозы при помощи randomPatchExtractionDatastore
. При чтении из этого datastore увеличьте данные с помощью случайного вращения и горизонтального отражения.
inputSize = [128,128,1];
augmenter = imageDataAugmenter(RandRotation=@()90*(randi([0,1],1)),RandXReflection=true);
dsTrain = randomPatchExtractionDatastore(timdsTrainLD,timdsTrainHD, ...
inputSize(1:2),PatchesPerImage=16,DataAugmentation=augmenter);
Объедините данные о валидации при помощи randomPatchExtractionDatastore
. Вы не должны выполнять увеличение при чтении данных о валидации.
dsVal = randomPatchExtractionDatastore(timdsValLD,timdsValHD,inputSize(1:2));
Посмотрите на некоторых низкая доза и пары закрашенной фигуры большей дозы изображений от набора обучающих данных. Заметьте, что пары изображений низкой дозы (слева) и большей дозы (справа) отображают, являются непарными, как они от различных пациентов.
numImagePairs = 6; imagePairsTrain = []; for i = 1:numImagePairs imLowAndHighDose = read(dsTrain); inputImage = imLowAndHighDose.InputImage{1}; inputImage = rescale(im2single(inputImage)); responseImage = imLowAndHighDose.ResponseImage{1}; responseImage = rescale(im2single(responseImage)); imagePairsTrain = cat(4,imagePairsTrain,inputImage,responseImage); end montage(imagePairsTrain,Size=[numImagePairs 2],BorderSize=4,BackgroundColor="w")
Этот пример использует пользовательский учебный цикл. minibatchqueue
Объект (Deep Learning Toolbox) полезен для управления мини-пакетная обработка наблюдений в пользовательских учебных циклах. minibatchqueue
возразите также бросает данные к dlarray
объект, который включает автоматическое дифференцирование в применении глубокого обучения.
Обработайте мини-пакеты путем конкатенации закрашенных фигур изображений по пакетному измерению с помощью функции помощника concatenateMiniBatchLD2HDCT
Эта функция присоединена к примеру как вспомогательный файл. Задайте мини-пакетный формат экстракции данных как "SSCB"
(пространственный, пространственный, канал, пакет). Отбросьте любые частичные мини-пакеты с меньше, чем miniBatchSize
наблюдения.
miniBatchSize = 32; mbqTrain = minibatchqueue(dsTrain, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ... PartialMiniBatch="discard", ... MiniBatchFormat="SSCB"); mbqVal = minibatchqueue(dsVal, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ... PartialMiniBatch="discard", ... MiniBatchFormat="SSCB");
CycleGAN состоит из двух генераторов и двух различителей. Генераторы выполняют перевод от изображения к изображению от низкой дозы до большей дозы и наоборот. Различители являются сетями PatchGAN, которые возвращают мудрую закрашенной фигурой вероятность, что входные данные действительны или сгенерированы. Один различитель различает действительные и сгенерированные изображения низкой дозы, и другой различитель различает действительные и сгенерированные изображения большей дозы.
Создайте каждую сеть генератора использование cycleGANGenerator
функция. Для входного размера 256 256 пикселей задайте NumResidualBlocks
аргумент как 9
. По умолчанию функция имеет 3 модуля энкодера и использует 64, просачивается первый сверточный слой.
numResiduals = 6; genHD2LD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1); genLD2HD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1);
Создайте каждую сеть различителя использование patchGANDiscriminator
функция. Используйте настройки по умолчанию для количества субдискретизации блоков, и количество просачивается первый сверточный слой в различителях.
discLD = patchGANDiscriminator(inputSize); discHD = patchGANDiscriminator(inputSize);
modelGradients
функция помощника вычисляет градиенты и потери для различителей и генераторов. Эта функция задана в разделе Supporting Functions этого примера.
Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой трех типов потерь: соперничающая потеря, потеря непротиворечивости цикла и потеря точности. Потеря точности основана на структурном подобии (SSIM) потеря.
Задайте фактор взвешивания это управляет относительной значимостью потери непротиворечивости цикла с соперничающими потерями и потерями точности.
lambda = 10;
Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь, которая использует среднеквадратическую ошибку (MSE) между ожидаемым и предсказанным выходом.
Обучайтесь с мини-пакетным размером 32 в течение 3 эпох.
numEpochs = 3; miniBatchSize = 32;
Задайте опции для оптимизации Адама. И для генератора и для сетей различителя, используйте:
Скорость обучения 0,0002
Фактор затухания градиента 0,5
Градиент в квадрате затухает фактор 0,999
learnRate = 0.0002; gradientDecay = 0.5; squaredGradientDecayFactor = 0.999;
Инициализируйте параметры Адама для генераторов и различителей.
avgGradGenLD2HD = []; avgSqGradGenLD2HD = []; avgGradGenHD2LD = []; avgSqGradGenHD2LD = []; avgGradDiscLD = []; avgSqGradDiscLD = []; avgGradDiscHD = []; avgSqGradDiscHD = [];
Отобразитесь сгенерированная валидация отображает каждые 100 итераций.
validationFrequency = 100;
По умолчанию пример загружает предварительно обученную версию генератора CycleGAN для низкой дозы к CT большей дозы при помощи функции помощника downloadTrainedLD2HDCTCycleGANNet
. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.
Чтобы обучить сеть, установите doTraining
переменная в следующем коде к true
. Обучите модель в пользовательском учебном цикле. Для каждой итерации:
Считайте данные для текущего мини-пакета с помощью next
(Deep Learning Toolbox) функция.
Оцените градиенты модели с помощью dlfeval
(Deep Learning Toolbox) функция и modelGradients
функция помощника.
Обновите сетевые параметры с помощью adamupdate
(Deep Learning Toolbox) функция.
Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.
Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 30 часов на Титане NVIDIA™ X с 24 Гбайт памяти графического процессора.
doTraining = false; if doTraining iteration = 0; start = tic; % Create a directory to store checkpoints checkpointDir = fullfile(dataDir,"checkpoints"); if ~exist(checkpointDir,"dir") mkdir(checkpointDir); end % Initialize plots for training progress [figureHandle,tileHandle,imageAxes,scoreAxesX,scoreAxesY, ... lineScoreGenLD2HD,lineScoreGenD2LD, ... lineScoreDiscHD,lineScoreDiscLD] = initializeTrainingPlotLD2HDCT; for epoch = 1:numEpochs shuffle(mbqTrain); % Loop over mini-batches while hasdata(mbqTrain) iteration = iteration + 1; % Read mini-batch of data [imageLD,imageHD] = next(mbqTrain); % Convert mini-batch of data to dlarray and specify the dimension labels % "SSCB" (spatial, spatial, channel, batch) imageLD = dlarray(imageLD,"SSCB"); imageHD = dlarray(imageHD,"SSCB"); % If training on a GPU, then convert data to gpuArray if canUseGPU imageLD = gpuArray(imageLD); imageHD = gpuArray(imageHD); end % Calculate the loss and gradients [genHD2LDGrad,genLD2HDGrad,discrXGrad,discYGrad, ... genHD2LDState,genLD2HDState,scores,imagesOutLD2HD,imagesOutHD2LD] = ... dlfeval(@modelGradients,genLD2HD,genHD2LD, ... discLD,discHD,imageHD,imageLD,lambda); genHD2LD.State = genHD2LDState; genLD2HD.State = genLD2HDState; % Update parameters of discLD, which distinguishes % the generated low-dose CT images from real low-dose CT images [discLD.Learnables,avgGradDiscLD,avgSqGradDiscLD] = ... adamupdate(discLD.Learnables,discrXGrad,avgGradDiscLD, ... avgSqGradDiscLD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor); % Update parameters of discHD, which distinguishes % the generated high-dose CT images from real high-dose CT images [discHD.Learnables,avgGradDiscHD,avgSqGradDiscHD] = ... adamupdate(discHD.Learnables,discYGrad,avgGradDiscHD, ... avgSqGradDiscHD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor); % Update parameters of genHD2LD, which % generates low-dose CT images from high-dose CT images [genHD2LD.Learnables,avgGradGenHD2LD,avgSqGradGenHD2LD] = ... adamupdate(genHD2LD.Learnables,genHD2LDGrad,avgGradGenHD2LD, ... avgSqGradGenHD2LD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor); % Update parameters of genLD2HD, which % generates high-dose CT images from low-dose CT images [genLD2HD.Learnables,avgGradGenLD2HD,avgSqGradGenLD2HD] = ... adamupdate(genLD2HD.Learnables,genLD2HDGrad,avgGradGenLD2HD, ... avgSqGradGenLD2HD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor); % Update the plots of network scores updateTrainingPlotLD2HDCT(scores,iteration,epoch,start,scoreAxesX,scoreAxesY,... lineScoreGenLD2HD,lineScoreGenD2LD, ... lineScoreDiscHD,lineScoreDiscLD) % Every validationFrequency iterations, display a batch of % generated images using the held-out generator input if mod(iteration,validationFrequency) == 0 || iteration == 1 displayGeneratedLD2HDCTImages(mbqVal,imageAxes,genLD2HD,genHD2LD); end end % Save the model after each epoch if canUseGPU [genLD2HD,genHD2LD,discLD,discHD] = ... gather(genLD2HD,genHD2LD,discLD,discHD); end generatorHighDoseToLowDose = genHD2LD; generatorLowDoseToHighDose = genLD2HD; discriminatorLowDose = discLD; discriminatorHighDose = discHD; modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss")); save(checkpointDir+filesep+"LD2HDCTCycleGAN-"+modelDateTime+"-Epoch-"+epoch+".mat", ... 'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ... 'discriminatorLowDose','discriminatorHighDose'); end % Save the final model modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss")); save(checkpointDir+filesep+"trainedLD2HDCTCycleGANNet-"+modelDateTime+".mat", ... 'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ... 'discriminatorLowDose','discriminatorHighDose'); else trainedCycleGANNetURL = "https://www.mathworks.com/supportfiles/vision/data/trainedLD2HDCTCycleGANNet.mat"; netDir = fullfile(tempdir,"LD2HDCT"); downloadTrainedLD2HDCTCycleGANNet(trainedCycleGANNetURL,netDir); load(fullfile(netDir,"trainedLD2HDCTCycleGANNet.mat")); end
Задайте количество тестовых изображений, чтобы использовать для вычисления метрик качества. Случайным образом выберите два тестовых изображения, чтобы отобразиться.
numTest = timdsTestLD.numpartitions; numImagesToDisplay = 2; idxImagesToDisplay = randi(numTest,1,numImagesToDisplay);
Инициализируйте переменные, чтобы вычислить PSNR и SSIM.
origPSNR = zeros(numTest,1); generatedPSNR = zeros(numTest,1); origSSIM = zeros(numTest,1); generatedSSIM = zeros(numTest,1);
Чтобы сгенерировать новые переведенные изображения, используйте predict
(Deep Learning Toolbox) функция. Считайте изображения из набора тестовых данных и используйте обученные генераторы, чтобы сгенерировать новые изображения.
for idx = 1:numTest imageTestLD = read(timdsTestLD); imageTestHD = read(timdsTestHD); imageTestLD = cat(4,imageTestLD{1}); imageTestHD = cat(4,imageTestHD{1}); % Convert mini-batch of data to dlarray and specify the dimension labels % 'SSCB' (spatial, spatial, channel, batch) imageTestLD = dlarray(imageTestLD,'SSCB'); imageTestHD = dlarray(imageTestHD,'SSCB'); % If running on a GPU, then convert data to gpuArray if canUseGPU imageTestLD = gpuArray(imageTestLD); imageTestHD = gpuArray(imageTestHD); end % Generate translated images generatedImageHD = predict(generatorLowDoseToHighDose,imageTestLD); generatedImageLD = predict(generatorHighDoseToLowDose,imageTestHD); % Display a few images to visualize the network responses if ismember(idx,idxImagesToDisplay) figure origImLD = rescale(extractdata(imageTestLD)); genImHD = rescale(extractdata(generatedImageHD)); montage({origImLD,genImHD},Size=[1 2],BorderSize=5) title("Original LDCT Test Image "+idx+" (Left), Generated HDCT Image (Right)") end origPSNR(idx) = psnr(imageTestLD,imageTestHD); generatedPSNR(idx) = psnr(generatedImageHD,imageTestHD); origSSIM(idx) = multissim(imageTestLD,imageTestHD); generatedSSIM(idx) = multissim(generatedImageHD,imageTestHD); end
Вычислите средний PSNR исходных и сгенерированных изображений. Большее значение PSNR указывает на лучшее качество изображения.
disp("Average PSNR of original images: "+mean(origPSNR,"all"));
Average PSNR of original images: 20.4045
disp("Average PSNR of generated images: "+mean(generatedPSNR,"all"));
Average PSNR of generated images: 27.9155
Вычислите средний SSIM исходных и сгенерированных изображений. Значение SSIM ближе к 1 указывает на лучшее качество изображения.
disp("Average SSIM of original images: "+mean(origSSIM,"all"));
Average SSIM of original images: 0.76651
disp("Average SSIM of generated images: "+mean(generatedSSIM,"all"));
Average SSIM of generated images: 0.90194
Функциональный modelGradients
берет в качестве входа два генератора и различитель dlnetwork
объекты и мини-пакет входных данных. Функция возвращает градиенты потери относительно настраиваемых параметров в сетях и множестве этих четырех сетей. Поскольку различитель выходные параметры не находится в области значений [0, 1], modelGradients
функция применяет сигмоидальную функцию, чтобы преобразовать различитель выходные параметры в баллы вероятности.
function [genHD2LDGrad,genLD2HDGrad,discLDGrad,discHDGrad, ... genHD2LDState,genLD2HDState,scores,imagesOutLDAndHDGenerated,imagesOutHDAndLDGenerated] = ... modelGradients(genLD2HD,genHD2LD,discLD,discHD,imageHD,imageLD,lambda) % Translate images from one domain to another: low-dose to high-dose and % vice versa [imageLDGenerated,genHD2LDState] = forward(genHD2LD,imageHD); [imageHDGenerated,genLD2HDState] = forward(genLD2HD,imageLD); % Calculate predictions for real images in each domain by the corresponding % discriminator networks predRealLD = forward(discLD,imageLD); predRealHD = forward(discHD,imageHD); % Calculate predictions for generated images in each domain by the % corresponding discriminator networks predGeneratedLD = forward(discLD,imageLDGenerated); predGeneratedHD = forward(discHD,imageHDGenerated); % Calculate discriminator losses for real images discLDLossReal = lossReal(predRealLD); discHDLossReal = lossReal(predRealHD); % Calculate discriminator losses for generated images discLDLossGenerated = lossGenerated(predGeneratedLD); discHDLossGenerated = lossGenerated(predGeneratedHD); % Calculate total discriminator loss for each discriminator network discLDLossTotal = 0.5*(discLDLossReal + discLDLossGenerated); discHDLossTotal = 0.5*(discHDLossReal + discHDLossGenerated); % Calculate generator loss for generated images genLossHD2LD = lossReal(predGeneratedLD); genLossLD2HD = lossReal(predGeneratedHD); % Complete the round-trip (cycle consistency) outputs by applying the % generator to each generated image to get the images in the corresponding % original domains cycleImageLD2HD2LD = forward(genHD2LD,imageHDGenerated); cycleImageHD2LD2HD = forward(genLD2HD,imageLDGenerated); % Calculate cycle consistency loss between real and generated images cycleLossLD2HD2LD = cycleConsistencyLoss(imageLD,cycleImageLD2HD2LD,lambda); cycleLossHD2LD2HD = cycleConsistencyLoss(imageHD,cycleImageHD2LD2HD,lambda); % Calculate identity outputs identityImageLD = forward(genHD2LD,imageLD); identityImageHD = forward(genLD2HD,imageHD); % Calculate fidelity loss (SSIM) between the identity outputs fidelityLossLD = mean(1-multissim(identityImageLD,imageLD),"all"); fidelityLossHD = mean(1-multissim(identityImageHD,imageHD),"all"); % Calculate total generator loss genLossTotal = genLossHD2LD + cycleLossHD2LD2HD + ... genLossLD2HD + cycleLossLD2HD2LD + fidelityLossLD + fidelityLossHD; % Calculate scores of generators genHD2LDScore = mean(sigmoid(predGeneratedLD),"all"); genLD2HDScore = mean(sigmoid(predGeneratedHD),"all"); % Calculate scores of discriminators discLDScore = 0.5*mean(sigmoid(predRealLD),"all") + ... 0.5*mean(1-sigmoid(predGeneratedLD),"all"); discHDScore = 0.5*mean(sigmoid(predRealHD),"all") + ... 0.5*mean(1-sigmoid(predGeneratedHD),"all"); % Combine scores into cell array scores = {genHD2LDScore,genLD2HDScore,discLDScore,discHDScore}; % Calculate gradients of generators genLD2HDGrad = dlgradient(genLossTotal,genLD2HD.Learnables,'RetainData',true); genHD2LDGrad = dlgradient(genLossTotal,genHD2LD.Learnables,'RetainData',true); % Calculate gradients of discriminators discLDGrad = dlgradient(discLDLossTotal,discLD.Learnables,'RetainData',true); discHDGrad = dlgradient(discHDLossTotal,discHD.Learnables); % Return mini-batch of images transforming low-dose CT into high-dose CT imagesOutLDAndHDGenerated = {imageLD,imageHDGenerated}; % Return mini-batch of images transforming high-dose CT into low-dose CT imagesOutHDAndLDGenerated = {imageHD,imageLDGenerated}; end
Задайте функции потерь MSE для действительных и сгенерированных изображений.
function loss = lossReal(predictions) loss = mean((1-predictions).^2,"all"); end function loss = lossGenerated(predictions) loss = mean((predictions).^2,"all"); end
Задайте функции потерь непротиворечивости цикла для действительных и сгенерированных изображений.
function loss = cycleConsistencyLoss(imageReal,imageGenerated,lambda) loss = mean(abs(imageReal-imageGenerated),"all") * lambda; end
[1] Чжу, июнь-Yan, парк Taesung, Филип Изола и Алексей А. Эфрос. “Непарный Перевод От изображения к изображению Используя Сопоставимые с циклом Соперничающие Сети”. На 2 017 Международных конференциях IEEE по вопросам Компьютерного зрения (ICCV), 2242–51. Венеция: IEEE, 2017. https://doi.org/10.1109/ICCV.2017.244.
[2] МакКалоу, Синтия, Бэйю Чен, Дэвид Р Холмс III, Синьхой Дуань, Жиконг Ю, Лифэн Юй, Шуай Лэн и Джоэл Флетчер. “Низкие Данные об Изображении и Проекции CT Дозы (LDCT и Данные проекции)”. Архив Обработки изображений Рака, 2020. https://doi.org/10.7937/9NPB-2637.
[3] Предоставления EB017095 и EB017185 (Синтия МакКалоу, PI) от национального института биомедицинской обработки изображений и биоинженерии.
[4] Кларк, Кеннет, Брюс Вендт, Кирк Смит, Джон Фреиманн, Джастин Кирби, Пол Коппель, Стивен Мур, и др. “Архив обработки изображений рака (TCIA): Поддержание и Работа Репозиторием Общедоступной информации”. Журнал Цифровой Обработки изображений 26, № 6 (декабрь 2013): 1045–57. https://doi.org/10.1007/s10278-013-9622-7.
[5] Вы, Chenyu, Кингсонг Янг, Хонгминг Шань, Ларс Гджестеби, Гуан Ли, Шенгонг Джу, Цзхойян Чжан, и др. “Структурно чувствительная Многошкальная Глубокая нейронная сеть для Шумоподавления CT Низкой Дозы”. IEEE доступ 6 (2018): 41839–55. https://doi.org/10.1109/ACCESS.2018.2858196.
cycleGANGenerator
| patchGANDiscriminator
| transform
| combine
| minibatchqueue
(Deep Learning Toolbox) | dlarray
(Deep Learning Toolbox) | dlfeval
(Deep Learning Toolbox) | adamupdate
(Deep Learning Toolbox)