Безнадзорное медицинское шумоподавление изображений Используя CycleGAN

В этом примере показано, как сгенерировать высококачественные изображения компьютерной томографии (CT) большей дозы от шумных изображений CT низкой дозы с помощью нейронной сети CycleGAN.

Этот пример использует сопоставимую с циклом порождающую соперничающую сеть (CycleGAN), обученный на закрашенных фигурах данных изображения от большой выборки данных. Для аналогичного подхода с помощью МОДУЛЬНОЙ нейронной сети, обученной на полных образах от ограниченной выборки данных, смотрите, что Безнадзорное Медицинское Шумоподавление Изображений Использует CycleGAN.

CT рентгена является популярной модальностью обработки изображений, используемой в клиническом и промышленном применении, потому что он производит высококачественные изображения и предлагает превосходящие диагностические возможности. Чтобы защитить безопасность пациентов, клиницисты рекомендуют низкую дозу излучения. Однако низкая доза излучения приводит к более низкому отношению сигнал-шум (SNR) в изображениях, и поэтому уменьшает диагностическую точность.

Методы глубокого обучения предлагают решения улучшить качество изображения для CT низкой дозы (LDCT) изображения. Используя порождающую соперничающую сеть (GAN) для перевода от изображения к изображению, можно преобразовать шумные изображения LDCT в изображения того же качества как изображения CT регулярной дозы. Для этого приложения исходная область состоит из изображений LDCT, и целевая область состоит из изображений регулярной дозы.

Шумоподавление CT изображений требует GAN, который выполняет безнадзорное обучение, потому что клиницисты обычно не получают соответствие с парами низкой дозы и изображениями CT регулярной дозы того же пациента на том же сеансе. Этот пример использует архитектуру CycleGAN, которая поддерживает безнадзорное обучение. Для получения дополнительной информации смотрите Начало работы с GANs для Перевода От изображения к изображению.

Загрузите набор данных LDCT

Этот пример использует данные из Низкого CT Дозы Главная проблема [2, 3, 4]. Данные включают пары изображений CT регулярной дозы, и симулированные изображения CT низкой дозы для 99 главных сканов (пометил N для neuro), 100 сканов грудной клетки (пометил C для груди), и 100 сканов живота (пометил L для печени). Размер набора данных составляет 1,2 Тбайта.

Установите dataDir как желаемое местоположение набора данных.

dataDir = fullfile(tempdir,"LDCT","LDCT-and-Projection-data");

Чтобы загрузить данные, перейдите к веб-сайту Архива Обработки изображений Рака. Этот пример использует только изображения из груди. Загрузите файлы грудной клетки с "Изображений (DICOM, 952 Гбайт)" набор данных в директорию, заданную dataDir использование Ретривера Данных NBIA. Когда загрузка успешна, dataDir содержит 50 подпапок с именами, такими как "C002" и "C004", заканчивающийся "C296".

Создайте хранилища данных для обучения, валидации и тестирования

Набор данных LDCT обеспечивает пары низкой дозы и изображений CT большей дозы. Однако архитектура CycleGAN требует непарных данных для безнадзорного изучения. Этот пример симулирует непарные данные об обучении и валидации путем разделения изображений, таким образом, что пациенты раньше получали CT низкой дозы, и изображения CT большей дозы не перекрываются. Пример сохраняет пары низкой дозы и изображений регулярной дозы для тестирования.

Разделите данные в обучение, валидацию и наборы тестовых данных с помощью createLDCTFolderList функция помощника. Эта функция присоединена к примеру как вспомогательный файл. Функция помощника разделяет данные, таким образом, что существует примерно хорошее представление двух типов изображений в каждой группе. Приблизительно 80% данных используются для обучения, 15% используется для тестирования, и 5% используются для валидации.

maxDirsForABodyPart = 25;
[filesTrainLD,filesTrainHD,filesTestLD,filesTestHD,filesValLD,filesValHD] = ...
    createLDCTFolderList(dataDir,maxDirsForABodyPart);

Создайте хранилища данных изображений, которые содержат изображения обучения и валидации для обеих областей, а именно, изображения CT низкой дозы и изображения CT большей дозы. Набор данных состоит из изображений DICOM, так используйте пользовательский ReadFcn аргумент значения имени в imageDatastore позволять считать данные.

exts = {'.dcm'};
readFcn = @(x)dicomread(x);
imdsTrainLD = imageDatastore(filesTrainLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTrainHD = imageDatastore(filesTrainHD,FileExtensions=exts,ReadFcn=readFcn);
imdsValLD = imageDatastore(filesValLD,FileExtensions=exts,ReadFcn=readFcn);
imdsValHD = imageDatastore(filesValHD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestLD = imageDatastore(filesTestLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestHD = imageDatastore(filesTestHD,FileExtensions=exts,ReadFcn=readFcn);

Количество низкой дозы и изображений большей дозы может отличаться. Выберите подмножество файлов, таким образом, что количество изображений равно.

numTrain = min(numel(imdsTrainLD.Files),numel(imdsTrainHD.Files));
imdsTrainLD = subset(imdsTrainLD,1:numTrain);
imdsTrainHD = subset(imdsTrainHD,1:numTrain);

numVal = min(numel(imdsValLD.Files),numel(imdsValHD.Files));
imdsValLD = subset(imdsValLD,1:numVal);
imdsValHD = subset(imdsValHD,1:numVal);

numTest = min(numel(imdsTestLD.Files),numel(imdsTestHD.Files));
imdsTestLD = subset(imdsTestLD,1:numTest);
imdsTestHD = subset(imdsTestHD,1:numTest);

Предварительно обработайте и увеличьте данные

Предварительно обработайте данные при помощи transform функция с пользовательскими операциями предварительной обработки, заданными normalizeCTImages функция помощника. Эта функция присоединена к примеру как вспомогательный файл. normalizeCTImages функция перемасштабирует данные к области значений [-1, 1].

timdsTrainLD = transform(imdsTrainLD,@(x){normalizeCTImages(x)});
timdsTrainHD = transform(imdsTrainHD,@(x){normalizeCTImages(x)});
timdsValLD = transform(imdsValLD,@(x){normalizeCTImages(x)});
timdsValHD  = transform(imdsValHD,@(x){normalizeCTImages(x)});
timdsTestLD = transform(imdsTestLD,@(x){normalizeCTImages(x)});
timdsTestHD  = transform(imdsTestHD,@(x){normalizeCTImages(x)});

Объедините низкую дозу и обучающие данные большей дозы при помощи randomPatchExtractionDatastore. При чтении из этого datastore увеличьте данные с помощью случайного вращения и горизонтального отражения.

inputSize = [128,128,1];
augmenter = imageDataAugmenter(RandRotation=@()90*(randi([0,1],1)),RandXReflection=true);
dsTrain = randomPatchExtractionDatastore(timdsTrainLD,timdsTrainHD, ...
    inputSize(1:2),PatchesPerImage=16,DataAugmentation=augmenter);

Объедините данные о валидации при помощи randomPatchExtractionDatastore. Вы не должны выполнять увеличение при чтении данных о валидации.

dsVal = randomPatchExtractionDatastore(timdsValLD,timdsValHD,inputSize(1:2));

Визуализируйте набор данных

Посмотрите на некоторых низкая доза и пары закрашенной фигуры большей дозы изображений от набора обучающих данных. Заметьте, что пары изображений низкой дозы (слева) и большей дозы (справа) отображают, являются непарными, как они от различных пациентов.

numImagePairs = 6;
imagePairsTrain = [];
for i = 1:numImagePairs
    imLowAndHighDose = read(dsTrain);
    inputImage = imLowAndHighDose.InputImage{1};
    inputImage = rescale(im2single(inputImage));
    responseImage = imLowAndHighDose.ResponseImage{1};
    responseImage = rescale(im2single(responseImage));
    imagePairsTrain = cat(4,imagePairsTrain,inputImage,responseImage);
end
montage(imagePairsTrain,Size=[numImagePairs 2],BorderSize=4,BackgroundColor="w")

Обработайте данные об обучении и валидации в пакетном режиме во время обучения

Этот пример использует пользовательский учебный цикл. minibatchqueue Объект (Deep Learning Toolbox) полезен для управления мини-пакетная обработка наблюдений в пользовательских учебных циклах. minibatchqueue возразите также бросает данные к dlarray объект, который включает автоматическое дифференцирование в применении глубокого обучения.

Обработайте мини-пакеты путем конкатенации закрашенных фигур изображений по пакетному измерению с помощью функции помощника concatenateMiniBatchLD2HDCTЭта функция присоединена к примеру как вспомогательный файл. Задайте мини-пакетный формат экстракции данных как "SSCB" (пространственный, пространственный, канал, пакет). Отбросьте любые частичные мини-пакеты с меньше, чем miniBatchSize наблюдения.

miniBatchSize = 32;

mbqTrain = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ...
    PartialMiniBatch="discard", ...
    MiniBatchFormat="SSCB");
mbqVal = minibatchqueue(dsVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ...
    PartialMiniBatch="discard", ...
    MiniBatchFormat="SSCB");

Создайте сети генератора и различителя

CycleGAN состоит из двух генераторов и двух различителей. Генераторы выполняют перевод от изображения к изображению от низкой дозы до большей дозы и наоборот. Различители являются сетями PatchGAN, которые возвращают мудрую закрашенной фигурой вероятность, что входные данные действительны или сгенерированы. Один различитель различает действительные и сгенерированные изображения низкой дозы, и другой различитель различает действительные и сгенерированные изображения большей дозы.

Создайте каждую сеть генератора использование cycleGANGenerator функция. Для входного размера 256 256 пикселей задайте NumResidualBlocks аргумент как 9. По умолчанию функция имеет 3 модуля энкодера и использует 64, просачивается первый сверточный слой.

numResiduals = 6; 
genHD2LD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1);
genLD2HD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1);

Создайте каждую сеть различителя использование patchGANDiscriminator функция. Используйте настройки по умолчанию для количества субдискретизации блоков, и количество просачивается первый сверточный слой в различителях.

discLD = patchGANDiscriminator(inputSize);
discHD = patchGANDiscriminator(inputSize);

Задайте функции потерь и баллы

modelGradients функция помощника вычисляет градиенты и потери для различителей и генераторов. Эта функция задана в разделе Supporting Functions этого примера.

Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой трех типов потерь: соперничающая потеря, потеря непротиворечивости цикла и потеря точности. Потеря точности основана на структурном подобии (SSIM) потеря.

LTotal=LAdversarial+λ*LCycleconsistency+LFidelity

Задайте фактор взвешивания λ это управляет относительной значимостью потери непротиворечивости цикла с соперничающими потерями и потерями точности.

lambda = 10;

Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь, которая использует среднеквадратическую ошибку (MSE) между ожидаемым и предсказанным выходом.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 32 в течение 3 эпох.

numEpochs = 3;
miniBatchSize = 32;

Задайте опции для оптимизации Адама. И для генератора и для сетей различителя, используйте:

  • Скорость обучения 0,0002

  • Фактор затухания градиента 0,5

  • Градиент в квадрате затухает фактор 0,999

learnRate = 0.0002;
gradientDecay = 0.5;
squaredGradientDecayFactor = 0.999;

Инициализируйте параметры Адама для генераторов и различителей.

avgGradGenLD2HD = [];
avgSqGradGenLD2HD = [];
avgGradGenHD2LD = [];
avgSqGradGenHD2LD = [];
avgGradDiscLD = [];
avgSqGradDiscLD = [];
avgGradDiscHD = [];
avgSqGradDiscHD = [];

Отобразитесь сгенерированная валидация отображает каждые 100 итераций.

validationFrequency = 100;

Обучите или загрузите модель

По умолчанию пример загружает предварительно обученную версию генератора CycleGAN для низкой дозы к CT большей дозы при помощи функции помощника downloadTrainedLD2HDCTCycleGANNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.

Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next (Deep Learning Toolbox) функция.

  • Оцените градиенты модели с помощью dlfeval (Deep Learning Toolbox) функция и modelGradients функция помощника.

  • Обновите сетевые параметры с помощью adamupdate (Deep Learning Toolbox) функция.

  • Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 30 часов на Титане NVIDIA™ X с 24 Гбайт памяти графического процессора.

doTraining = false;
if doTraining

    iteration = 0;
    start = tic;

    % Create a directory to store checkpoints
    checkpointDir = fullfile(dataDir,"checkpoints");
    if ~exist(checkpointDir,"dir")
        mkdir(checkpointDir);
    end

    % Initialize plots for training progress
    [figureHandle,tileHandle,imageAxes,scoreAxesX,scoreAxesY, ...
        lineScoreGenLD2HD,lineScoreGenD2LD, ...
        lineScoreDiscHD,lineScoreDiscLD] = initializeTrainingPlotLD2HDCT;

    for epoch = 1:numEpochs

        shuffle(mbqTrain);

        % Loop over mini-batches
        while hasdata(mbqTrain)
            iteration = iteration + 1;

            % Read mini-batch of data
            [imageLD,imageHD] = next(mbqTrain);

            % Convert mini-batch of data to dlarray and specify the dimension labels
            % "SSCB" (spatial, spatial, channel, batch)
            imageLD = dlarray(imageLD,"SSCB");
            imageHD = dlarray(imageHD,"SSCB");

            % If training on a GPU, then convert data to gpuArray
            if canUseGPU
                imageLD = gpuArray(imageLD);
                imageHD = gpuArray(imageHD);
            end

            % Calculate the loss and gradients
            [genHD2LDGrad,genLD2HDGrad,discrXGrad,discYGrad, ...
                genHD2LDState,genLD2HDState,scores,imagesOutLD2HD,imagesOutHD2LD] = ...
                dlfeval(@modelGradients,genLD2HD,genHD2LD, ...
                discLD,discHD,imageHD,imageLD,lambda);
            genHD2LD.State = genHD2LDState;
            genLD2HD.State = genLD2HDState;

            % Update parameters of discLD, which distinguishes
            % the generated low-dose CT images from real low-dose CT images
            [discLD.Learnables,avgGradDiscLD,avgSqGradDiscLD] = ...
                adamupdate(discLD.Learnables,discrXGrad,avgGradDiscLD, ...
                avgSqGradDiscLD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);

            % Update parameters of discHD, which distinguishes
            % the generated high-dose CT images from real high-dose CT images
            [discHD.Learnables,avgGradDiscHD,avgSqGradDiscHD] = ...
                adamupdate(discHD.Learnables,discYGrad,avgGradDiscHD, ...
                avgSqGradDiscHD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);

            % Update parameters of genHD2LD, which
            % generates low-dose CT images from high-dose CT images
            [genHD2LD.Learnables,avgGradGenHD2LD,avgSqGradGenHD2LD] = ...
                adamupdate(genHD2LD.Learnables,genHD2LDGrad,avgGradGenHD2LD, ...
                avgSqGradGenHD2LD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);
                        
            % Update parameters of genLD2HD, which
            % generates high-dose CT images from low-dose CT images
            [genLD2HD.Learnables,avgGradGenLD2HD,avgSqGradGenLD2HD] = ...
                adamupdate(genLD2HD.Learnables,genLD2HDGrad,avgGradGenLD2HD, ...
                avgSqGradGenLD2HD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);
                        
            % Update the plots of network scores
            updateTrainingPlotLD2HDCT(scores,iteration,epoch,start,scoreAxesX,scoreAxesY,...
                lineScoreGenLD2HD,lineScoreGenD2LD, ...
                lineScoreDiscHD,lineScoreDiscLD)

            %  Every validationFrequency iterations, display a batch of
            %  generated images using the held-out generator input
            if mod(iteration,validationFrequency) == 0 || iteration == 1
                 displayGeneratedLD2HDCTImages(mbqVal,imageAxes,genLD2HD,genHD2LD);
            end
        end

        % Save the model after each epoch
        if canUseGPU
            [genLD2HD,genHD2LD,discLD,discHD] = ...
                gather(genLD2HD,genHD2LD,discLD,discHD);
        end
        generatorHighDoseToLowDose = genHD2LD;
        generatorLowDoseToHighDose = genLD2HD;
        discriminatorLowDose = discLD;
        discriminatorHighDose = discHD;    
        modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
        save(checkpointDir+filesep+"LD2HDCTCycleGAN-"+modelDateTime+"-Epoch-"+epoch+".mat", ...
            'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ...
            'discriminatorLowDose','discriminatorHighDose');
    end
    
    % Save the final model
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(checkpointDir+filesep+"trainedLD2HDCTCycleGANNet-"+modelDateTime+".mat", ...
        'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ...
        'discriminatorLowDose','discriminatorHighDose');

else
    trainedCycleGANNetURL = "https://www.mathworks.com/supportfiles/vision/data/trainedLD2HDCTCycleGANNet.mat";
    netDir = fullfile(tempdir,"LD2HDCT");
    downloadTrainedLD2HDCTCycleGANNet(trainedCycleGANNetURL,netDir);
    load(fullfile(netDir,"trainedLD2HDCTCycleGANNet.mat"));
end

Сгенерируйте новые изображения Используя тестовые данные

Задайте количество тестовых изображений, чтобы использовать для вычисления метрик качества. Случайным образом выберите два тестовых изображения, чтобы отобразиться.

numTest = timdsTestLD.numpartitions;
numImagesToDisplay = 2;
idxImagesToDisplay = randi(numTest,1,numImagesToDisplay);

Инициализируйте переменные, чтобы вычислить PSNR и SSIM.

origPSNR = zeros(numTest,1);
generatedPSNR = zeros(numTest,1);
origSSIM = zeros(numTest,1);
generatedSSIM = zeros(numTest,1);

Чтобы сгенерировать новые переведенные изображения, используйте predict (Deep Learning Toolbox) функция. Считайте изображения из набора тестовых данных и используйте обученные генераторы, чтобы сгенерировать новые изображения.

for idx = 1:numTest
    imageTestLD = read(timdsTestLD);
    imageTestHD = read(timdsTestHD);
    
    imageTestLD = cat(4,imageTestLD{1});
    imageTestHD = cat(4,imageTestHD{1});

    % Convert mini-batch of data to dlarray and specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch)
    imageTestLD = dlarray(imageTestLD,'SSCB');
    imageTestHD = dlarray(imageTestHD,'SSCB');

    % If running on a GPU, then convert data to gpuArray
    if canUseGPU
        imageTestLD = gpuArray(imageTestLD);
        imageTestHD = gpuArray(imageTestHD);
    end

    % Generate translated images
    generatedImageHD = predict(generatorLowDoseToHighDose,imageTestLD);
    generatedImageLD = predict(generatorHighDoseToLowDose,imageTestHD);

    % Display a few images to visualize the network responses
     if ismember(idx,idxImagesToDisplay)
         figure
         origImLD = rescale(extractdata(imageTestLD));
         genImHD = rescale(extractdata(generatedImageHD));
         montage({origImLD,genImHD},Size=[1 2],BorderSize=5)
         title("Original LDCT Test Image "+idx+" (Left), Generated HDCT Image (Right)")
    end
    
    origPSNR(idx) = psnr(imageTestLD,imageTestHD);
    generatedPSNR(idx) = psnr(generatedImageHD,imageTestHD);
    
    origSSIM(idx) = multissim(imageTestLD,imageTestHD);
    generatedSSIM(idx) = multissim(generatedImageHD,imageTestHD);
end

Вычислите средний PSNR исходных и сгенерированных изображений. Большее значение PSNR указывает на лучшее качество изображения.

disp("Average PSNR of original images: "+mean(origPSNR,"all"));
Average PSNR of original images: 20.4045
disp("Average PSNR of generated images: "+mean(generatedPSNR,"all"));
Average PSNR of generated images: 27.9155

Вычислите средний SSIM исходных и сгенерированных изображений. Значение SSIM ближе к 1 указывает на лучшее качество изображения.

disp("Average SSIM of original images: "+mean(origSSIM,"all"));
Average SSIM of original images: 0.76651
disp("Average SSIM of generated images: "+mean(generatedSSIM,"all"));
Average SSIM of generated images: 0.90194

Вспомогательные Функции

Функция градиентов модели

Функциональный modelGradients берет в качестве входа два генератора и различитель dlnetwork объекты и мини-пакет входных данных. Функция возвращает градиенты потери относительно настраиваемых параметров в сетях и множестве этих четырех сетей. Поскольку различитель выходные параметры не находится в области значений [0, 1], modelGradients функция применяет сигмоидальную функцию, чтобы преобразовать различитель выходные параметры в баллы вероятности.

function [genHD2LDGrad,genLD2HDGrad,discLDGrad,discHDGrad, ...
    genHD2LDState,genLD2HDState,scores,imagesOutLDAndHDGenerated,imagesOutHDAndLDGenerated] = ...
    modelGradients(genLD2HD,genHD2LD,discLD,discHD,imageHD,imageLD,lambda)

% Translate images from one domain to another: low-dose to high-dose and
% vice versa
[imageLDGenerated,genHD2LDState] = forward(genHD2LD,imageHD);
[imageHDGenerated,genLD2HDState] = forward(genLD2HD,imageLD);

% Calculate predictions for real images in each domain by the corresponding
% discriminator networks
predRealLD = forward(discLD,imageLD);
predRealHD = forward(discHD,imageHD);

% Calculate predictions for generated images in each domain by the
% corresponding discriminator networks
predGeneratedLD = forward(discLD,imageLDGenerated);
predGeneratedHD = forward(discHD,imageHDGenerated);

% Calculate discriminator losses for real images
discLDLossReal = lossReal(predRealLD);
discHDLossReal = lossReal(predRealHD);

% Calculate discriminator losses for generated images
discLDLossGenerated = lossGenerated(predGeneratedLD);
discHDLossGenerated = lossGenerated(predGeneratedHD);

% Calculate total discriminator loss for each discriminator network
discLDLossTotal = 0.5*(discLDLossReal + discLDLossGenerated);
discHDLossTotal = 0.5*(discHDLossReal + discHDLossGenerated);

% Calculate generator loss for generated images
genLossHD2LD = lossReal(predGeneratedLD);
genLossLD2HD = lossReal(predGeneratedHD);

% Complete the round-trip (cycle consistency) outputs by applying the
% generator to each generated image to get the images in the corresponding
% original domains
cycleImageLD2HD2LD = forward(genHD2LD,imageHDGenerated);
cycleImageHD2LD2HD = forward(genLD2HD,imageLDGenerated);

% Calculate cycle consistency loss between real and generated images
cycleLossLD2HD2LD = cycleConsistencyLoss(imageLD,cycleImageLD2HD2LD,lambda);
cycleLossHD2LD2HD = cycleConsistencyLoss(imageHD,cycleImageHD2LD2HD,lambda);

% Calculate identity outputs
identityImageLD = forward(genHD2LD,imageLD);
identityImageHD = forward(genLD2HD,imageHD);
 
% Calculate fidelity loss (SSIM) between the identity outputs
fidelityLossLD = mean(1-multissim(identityImageLD,imageLD),"all");
fidelityLossHD = mean(1-multissim(identityImageHD,imageHD),"all");

% Calculate total generator loss
genLossTotal = genLossHD2LD + cycleLossHD2LD2HD + ...
    genLossLD2HD + cycleLossLD2HD2LD + fidelityLossLD + fidelityLossHD;

% Calculate scores of generators
genHD2LDScore = mean(sigmoid(predGeneratedLD),"all");
genLD2HDScore = mean(sigmoid(predGeneratedHD),"all");

% Calculate scores of discriminators
discLDScore = 0.5*mean(sigmoid(predRealLD),"all") + ...
    0.5*mean(1-sigmoid(predGeneratedLD),"all");
discHDScore = 0.5*mean(sigmoid(predRealHD),"all") + ...
    0.5*mean(1-sigmoid(predGeneratedHD),"all");

% Combine scores into cell array
scores = {genHD2LDScore,genLD2HDScore,discLDScore,discHDScore};

% Calculate gradients of generators
genLD2HDGrad = dlgradient(genLossTotal,genLD2HD.Learnables,'RetainData',true);
genHD2LDGrad = dlgradient(genLossTotal,genHD2LD.Learnables,'RetainData',true);

% Calculate gradients of discriminators
discLDGrad = dlgradient(discLDLossTotal,discLD.Learnables,'RetainData',true);
discHDGrad = dlgradient(discHDLossTotal,discHD.Learnables);

% Return mini-batch of images transforming low-dose CT into high-dose CT
imagesOutLDAndHDGenerated = {imageLD,imageHDGenerated};

% Return mini-batch of images transforming high-dose CT into low-dose CT
imagesOutHDAndLDGenerated = {imageHD,imageLDGenerated};
end

Функции потерь

Задайте функции потерь MSE для действительных и сгенерированных изображений.

function loss = lossReal(predictions)
    loss = mean((1-predictions).^2,"all");
end

function loss = lossGenerated(predictions)
    loss = mean((predictions).^2,"all");
end

Задайте функции потерь непротиворечивости цикла для действительных и сгенерированных изображений.

function loss = cycleConsistencyLoss(imageReal,imageGenerated,lambda)
    loss = mean(abs(imageReal-imageGenerated),"all") * lambda;
end

Ссылки

[1] Чжу, июнь-Yan, парк Taesung, Филип Изола и Алексей А. Эфрос. “Непарный Перевод От изображения к изображению Используя Сопоставимые с циклом Соперничающие Сети”. На 2 017 Международных конференциях IEEE по вопросам Компьютерного зрения (ICCV), 2242–51. Венеция: IEEE, 2017. https://doi.org/10.1109/ICCV.2017.244.

[2] МакКалоу, Синтия, Бэйю Чен, Дэвид Р Холмс III, Синьхой Дуань, Жиконг Ю, Лифэн Юй, Шуай Лэн и Джоэл Флетчер. “Низкие Данные об Изображении и Проекции CT Дозы (LDCT и Данные проекции)”. Архив Обработки изображений Рака, 2020. https://doi.org/10.7937/9NPB-2637.

[3] Предоставления EB017095 и EB017185 (Синтия МакКалоу, PI) от национального института биомедицинской обработки изображений и биоинженерии.

[4] Кларк, Кеннет, Брюс Вендт, Кирк Смит, Джон Фреиманн, Джастин Кирби, Пол Коппель, Стивен Мур, и др. “Архив обработки изображений рака (TCIA): Поддержание и Работа Репозиторием Общедоступной информации”. Журнал Цифровой Обработки изображений 26, № 6 (декабрь 2013): 1045–57. https://doi.org/10.1007/s10278-013-9622-7.

[5] Вы, Chenyu, Кингсонг Янг, Хонгминг Шань, Ларс Гджестеби, Гуан Ли, Шенгонг Джу, Цзхойян Чжан, и др. “Структурно чувствительная Многошкальная Глубокая нейронная сеть для Шумоподавления CT Низкой Дозы”. IEEE доступ 6 (2018): 41839–55. https://doi.org/10.1109/ACCESS.2018.2858196.

Смотрите также

| | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

Связанные примеры

Больше о