Безнадзорное медицинское шумоподавление изображений Используя CycleGAN

В этом примере показано, как сгенерировать высококачественные изображения компьютерной томографии (CT) большей дозы от шумных изображений CT низкой дозы с помощью нейронной сети CycleGAN.

Этот пример использует сопоставимую с циклом порождающую соперничающую сеть (CycleGAN), обученный на закрашенных фигурах данных изображения от большой выборки данных. Для аналогичного подхода с помощью МОДУЛЬНОЙ нейронной сети, обученной на полных образах от ограниченной выборки данных, смотрите, что Безнадзорное Медицинское Шумоподавление Изображений Использует CycleGAN (Image Processing Toolbox).

CT рентгена является популярной модальностью обработки изображений, используемой в клиническом и промышленном применении, потому что он производит высококачественные изображения и предлагает превосходящие диагностические возможности. Чтобы защитить безопасность пациентов, клиницисты рекомендуют низкую дозу излучения. Однако низкая доза излучения приводит к более низкому отношению сигнал-шум (SNR) в изображениях, и поэтому уменьшает диагностическую точность.

Методы глубокого обучения предлагают решения улучшить качество изображения для CT низкой дозы (LDCT) изображения. Используя порождающую соперничающую сеть (GAN) для перевода от изображения к изображению, можно преобразовать шумные изображения LDCT в изображения того же качества как изображения CT регулярной дозы. Для этого приложения исходная область состоит из изображений LDCT, и целевая область состоит из изображений регулярной дозы.

Шумоподавление CT изображений требует GAN, который выполняет безнадзорное обучение, потому что клиницисты обычно не получают соответствие с парами низкой дозы и изображениями CT регулярной дозы того же пациента на том же сеансе. Этот пример использует архитектуру CycleGAN, которая поддерживает безнадзорное обучение. Для получения дополнительной информации смотрите Начало работы с GANs для Перевода От изображения к изображению (Image Processing Toolbox).

Загрузите набор данных LDCT

Этот пример использует данные из Низкого CT Дозы Главная проблема [2, 3, 4]. Данные включают пары изображений CT регулярной дозы, и симулированные изображения CT низкой дозы для 99 главных сканов (пометил N для neuro), 100 сканов грудной клетки (пометил C для груди), и 100 сканов живота (пометил L для печени). Размер набора данных составляет 1,2 Тбайта.

Установите dataDir как желаемое местоположение набора данных.

dataDir = fullfile(tempdir,"LDCT","LDCT-and-Projection-data");

Чтобы загрузить данные, перейдите к веб-сайту Архива Обработки изображений Рака. Этот пример использует только изображения из груди. Загрузите файлы грудной клетки с "Изображений (DICOM, 952 Гбайт)" набор данных в директорию, заданную dataDir использование Ретривера Данных NBIA. Когда загрузка успешна, dataDir содержит 50 подпапок с именами, такими как "C002" и "C004", заканчивающийся "C296".

Создайте хранилища данных для обучения, валидации и тестирования

Набор данных LDCT обеспечивает пары низкой дозы и изображений CT большей дозы. Однако архитектура CycleGAN требует непарных данных для безнадзорного изучения. Этот пример симулирует непарные данные об обучении и валидации путем разделения изображений, таким образом, что пациенты раньше получали CT низкой дозы, и изображения CT большей дозы не перекрываются. Пример сохраняет пары низкой дозы и изображений регулярной дозы для тестирования.

Разделите данные в обучение, валидацию и наборы тестовых данных с помощью createLDCTFolderList функция помощника. Эта функция присоединена к примеру как вспомогательный файл. Функция помощника разделяет данные, таким образом, что существует примерно хорошее представление двух типов изображений в каждой группе. Приблизительно 80% данных используются для обучения, 15% используется для тестирования, и 5% используются для валидации.

maxDirsForABodyPart = 25;
[filesTrainLD,filesTrainHD,filesTestLD,filesTestHD,filesValLD,filesValHD] = ...
    createLDCTFolderList(dataDir,maxDirsForABodyPart);

Создайте хранилища данных изображений, которые содержат изображения обучения и валидации для обеих областей, а именно, изображения CT низкой дозы и изображения CT большей дозы. Набор данных состоит из изображений DICOM, так используйте пользовательский ReadFcn аргумент значения имени в imageDatastore позволять считать данные.

exts = {'.dcm'};
readFcn = @(x)dicomread(x);
imdsTrainLD = imageDatastore(filesTrainLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTrainHD = imageDatastore(filesTrainHD,FileExtensions=exts,ReadFcn=readFcn);
imdsValLD = imageDatastore(filesValLD,FileExtensions=exts,ReadFcn=readFcn);
imdsValHD = imageDatastore(filesValHD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestLD = imageDatastore(filesTestLD,FileExtensions=exts,ReadFcn=readFcn);
imdsTestHD = imageDatastore(filesTestHD,FileExtensions=exts,ReadFcn=readFcn);

Количество низкой дозы и изображений большей дозы может отличаться. Выберите подмножество файлов, таким образом, что количество изображений равно.

numTrain = min(numel(imdsTrainLD.Files),numel(imdsTrainHD.Files));
imdsTrainLD = subset(imdsTrainLD,1:numTrain);
imdsTrainHD = subset(imdsTrainHD,1:numTrain);

numVal = min(numel(imdsValLD.Files),numel(imdsValHD.Files));
imdsValLD = subset(imdsValLD,1:numVal);
imdsValHD = subset(imdsValHD,1:numVal);

numTest = min(numel(imdsTestLD.Files),numel(imdsTestHD.Files));
imdsTestLD = subset(imdsTestLD,1:numTest);
imdsTestHD = subset(imdsTestHD,1:numTest);

Предварительно обработайте и увеличьте данные

Предварительно обработайте данные при помощи transform функция с пользовательскими операциями предварительной обработки, заданными normalizeCTImages функция помощника. Эта функция присоединена к примеру как вспомогательный файл. normalizeCTImages функция перемасштабирует данные к области значений [-1, 1].

timdsTrainLD = transform(imdsTrainLD,@(x){normalizeCTImages(x)});
timdsTrainHD = transform(imdsTrainHD,@(x){normalizeCTImages(x)});
timdsValLD = transform(imdsValLD,@(x){normalizeCTImages(x)});
timdsValHD  = transform(imdsValHD,@(x){normalizeCTImages(x)});
timdsTestLD = transform(imdsTestLD,@(x){normalizeCTImages(x)});
timdsTestHD  = transform(imdsTestHD,@(x){normalizeCTImages(x)});

Объедините низкую дозу и обучающие данные большей дозы при помощи randomPatchExtractionDatastore (Image Processing Toolbox). При чтении из этого datastore увеличьте данные с помощью случайного вращения и горизонтального отражения.

inputSize = [128,128,1];
augmenter = imageDataAugmenter(RandRotation=@()90*(randi([0,1],1)),RandXReflection=true);
dsTrain = randomPatchExtractionDatastore(timdsTrainLD,timdsTrainHD, ...
    inputSize(1:2),PatchesPerImage=16,DataAugmentation=augmenter);

Объедините данные о валидации при помощи randomPatchExtractionDatastore. Вы не должны выполнять увеличение при чтении данных о валидации.

dsVal = randomPatchExtractionDatastore(timdsValLD,timdsValHD,inputSize(1:2));

Визуализируйте набор данных

Посмотрите на некоторых низкая доза и пары закрашенной фигуры большей дозы изображений от набора обучающих данных. Заметьте, что пары изображений низкой дозы (слева) и большей дозы (справа) отображают, являются непарными, как они от различных пациентов.

numImagePairs = 6;
imagePairsTrain = [];
for i = 1:numImagePairs
    imLowAndHighDose = read(dsTrain);
    inputImage = imLowAndHighDose.InputImage{1};
    inputImage = rescale(im2single(inputImage));
    responseImage = imLowAndHighDose.ResponseImage{1};
    responseImage = rescale(im2single(responseImage));
    imagePairsTrain = cat(4,imagePairsTrain,inputImage,responseImage);
end
montage(imagePairsTrain,Size=[numImagePairs 2],BorderSize=4,BackgroundColor="w")

Обработайте данные об обучении и валидации в пакетном режиме во время обучения

Этот пример использует пользовательский учебный цикл. minibatchqueue объект полезен для управления мини-пакетная обработка наблюдений в пользовательских учебных циклах. minibatchqueue возразите также бросает данные к dlarray объект, который включает автоматическое дифференцирование в применении глубокого обучения.

Обработайте мини-пакеты путем конкатенации закрашенных фигур изображений по пакетному измерению с помощью функции помощника concatenateMiniBatchLD2HDCTЭта функция присоединена к примеру как вспомогательный файл. Задайте мини-пакетный формат экстракции данных как "SSCB" (пространственный, пространственный, канал, пакет). Отбросьте любые частичные мини-пакеты с меньше, чем miniBatchSize наблюдения.

miniBatchSize = 32;

mbqTrain = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ...
    PartialMiniBatch="discard", ...
    MiniBatchFormat="SSCB");
mbqVal = minibatchqueue(dsVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@concatenateMiniBatchLD2HDCT, ...
    PartialMiniBatch="discard", ...
    MiniBatchFormat="SSCB");

Создайте сети генератора и различителя

CycleGAN состоит из двух генераторов и двух различителей. Генераторы выполняют перевод от изображения к изображению от низкой дозы до большей дозы и наоборот. Различители являются сетями PatchGAN, которые возвращают мудрую закрашенной фигурой вероятность, что входные данные действительны или сгенерированы. Один различитель различает действительные и сгенерированные изображения низкой дозы, и другой различитель различает действительные и сгенерированные изображения большей дозы.

Создайте каждую сеть генератора использование cycleGANGenerator (Image Processing Toolbox) функция. Для входного размера 256 256 пикселей задайте NumResidualBlocks аргумент как 9. По умолчанию функция имеет 3 модуля энкодера и использует 64, просачивается первый сверточный слой.

numResiduals = 6; 
genHD2LD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1);
genLD2HD = cycleGANGenerator(inputSize,NumResidualBlocks=numResiduals,NumOutputChannels=1);

Создайте каждую сеть различителя использование patchGANDiscriminator (Image Processing Toolbox) функция. Используйте настройки по умолчанию для количества субдискретизации блоков, и количество просачивается первый сверточный слой в различителях.

discLD = patchGANDiscriminator(inputSize);
discHD = patchGANDiscriminator(inputSize);

Задайте функции потерь и баллы

modelGradients функция помощника вычисляет градиенты и потери для различителей и генераторов. Эта функция задана в разделе Supporting Functions этого примера.

Цель генератора состоит в том, чтобы сгенерировать переведенные изображения, которые различители классифицируют как действительные. Потеря генератора является взвешенной суммой трех типов потерь: соперничающая потеря, потеря непротиворечивости цикла и потеря точности. Потеря точности основана на структурном подобии (SSIM) потеря.

LTotal=LAdversarial+λ*LCycleconsistency+LFidelity

Задайте фактор взвешивания λ это управляет относительной значимостью потери непротиворечивости цикла с соперничающими потерями и потерями точности.

lambda = 10;

Цель каждого различителя состоит в том, чтобы правильно различать действительные изображения (1) и переведенные изображения (0) для изображений в его области. Каждый различитель имеет одну функцию потерь, которая использует среднеквадратическую ошибку (MSE) между ожидаемым и предсказанным выходом.

Задайте опции обучения

Обучайтесь с мини-пакетным размером 32 в течение 3 эпох.

numEpochs = 3;
miniBatchSize = 32;

Задайте опции для оптимизации Адама. И для генератора и для сетей различителя, используйте:

  • Скорость обучения 0,0002

  • Фактор затухания градиента 0,5

  • Градиент в квадрате затухает фактор 0,999

learnRate = 0.0002;
gradientDecay = 0.5;
squaredGradientDecayFactor = 0.999;

Инициализируйте параметры Адама для генераторов и различителей.

avgGradGenLD2HD = [];
avgSqGradGenLD2HD = [];
avgGradGenHD2LD = [];
avgSqGradGenHD2LD = [];
avgGradDiscLD = [];
avgSqGradDiscLD = [];
avgGradDiscHD = [];
avgSqGradDiscHD = [];

Отобразитесь сгенерированная валидация отображает каждые 100 итераций.

validationFrequency = 100;

Обучите или загрузите модель

По умолчанию пример загружает предварительно обученную версию генератора CycleGAN для низкой дозы к CT большей дозы при помощи функции помощника downloadTrainedLD2HDCTCycleGANNet. Функция помощника присоединена к примеру как к вспомогательному файлу. Предварительно обученная сеть позволяет вам запустить целый пример, не ожидая обучения завершиться.

Чтобы обучить сеть, установите doTraining переменная в следующем коде к true. Обучите модель в пользовательском учебном цикле. Для каждой итерации:

  • Считайте данные для текущего мини-пакета с помощью next функция.

  • Оцените градиенты модели с помощью dlfeval функционируйте и modelGradients функция помощника.

  • Обновите сетевые параметры с помощью adamupdate функция.

  • Отобразите вход и переведенные изображения для обоих входные и выходные области после каждой эпохи.

Обучайтесь на графическом процессоре, если вы доступны. Используя графический процессор требует Parallel Computing Toolbox™, и CUDA® включил NVIDIA® графический процессор. Для получения дополнительной информации смотрите Поддержку графического процессора Релизом (Parallel Computing Toolbox). Обучение занимает приблизительно 30 часов на Титане NVIDIA™ X с 24 Гбайт памяти графического процессора.

doTraining = false;
if doTraining

    iteration = 0;
    start = tic;

    % Create a directory to store checkpoints
    checkpointDir = fullfile(dataDir,"checkpoints");
    if ~exist(checkpointDir,"dir")
        mkdir(checkpointDir);
    end

    % Initialize plots for training progress
    [figureHandle,tileHandle,imageAxes,scoreAxesX,scoreAxesY, ...
        lineScoreGenLD2HD,lineScoreGenD2LD, ...
        lineScoreDiscHD,lineScoreDiscLD] = initializeTrainingPlotLD2HDCT;

    for epoch = 1:numEpochs

        shuffle(mbqTrain);

        % Loop over mini-batches
        while hasdata(mbqTrain)
            iteration = iteration + 1;

            % Read mini-batch of data
            [imageLD,imageHD] = next(mbqTrain);

            % Convert mini-batch of data to dlarray and specify the dimension labels
            % "SSCB" (spatial, spatial, channel, batch)
            imageLD = dlarray(imageLD,"SSCB");
            imageHD = dlarray(imageHD,"SSCB");

            % If training on a GPU, then convert data to gpuArray
            if canUseGPU
                imageLD = gpuArray(imageLD);
                imageHD = gpuArray(imageHD);
            end

            % Calculate the loss and gradients
            [genHD2LDGrad,genLD2HDGrad,discrXGrad,discYGrad, ...
                genHD2LDState,genLD2HDState,scores,imagesOutLD2HD,imagesOutHD2LD] = ...
                dlfeval(@modelGradients,genLD2HD,genHD2LD, ...
                discLD,discHD,imageHD,imageLD,lambda);
            genHD2LD.State = genHD2LDState;
            genLD2HD.State = genLD2HDState;

            % Update parameters of discLD, which distinguishes
            % the generated low-dose CT images from real low-dose CT images
            [discLD.Learnables,avgGradDiscLD,avgSqGradDiscLD] = ...
                adamupdate(discLD.Learnables,discrXGrad,avgGradDiscLD, ...
                avgSqGradDiscLD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);

            % Update parameters of discHD, which distinguishes
            % the generated high-dose CT images from real high-dose CT images
            [discHD.Learnables,avgGradDiscHD,avgSqGradDiscHD] = ...
                adamupdate(discHD.Learnables,discYGrad,avgGradDiscHD, ...
                avgSqGradDiscHD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);

            % Update parameters of genHD2LD, which
            % generates low-dose CT images from high-dose CT images
            [genHD2LD.Learnables,avgGradGenHD2LD,avgSqGradGenHD2LD] = ...
                adamupdate(genHD2LD.Learnables,genHD2LDGrad,avgGradGenHD2LD, ...
                avgSqGradGenHD2LD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);
                        
            % Update parameters of genLD2HD, which
            % generates high-dose CT images from low-dose CT images
            [genLD2HD.Learnables,avgGradGenLD2HD,avgSqGradGenLD2HD] = ...
                adamupdate(genLD2HD.Learnables,genLD2HDGrad,avgGradGenLD2HD, ...
                avgSqGradGenLD2HD,iteration,learnRate,gradientDecay,squaredGradientDecayFactor);
                        
            % Update the plots of network scores
            updateTrainingPlotLD2HDCT(scores,iteration,epoch,start,scoreAxesX,scoreAxesY,...
                lineScoreGenLD2HD,lineScoreGenD2LD, ...
                lineScoreDiscHD,lineScoreDiscLD)

            %  Every validationFrequency iterations, display a batch of
            %  generated images using the held-out generator input
            if mod(iteration,validationFrequency) == 0 || iteration == 1
                 displayGeneratedLD2HDCTImages(mbqVal,imageAxes,genLD2HD,genHD2LD);
            end
        end

        % Save the model after each epoch
        if canUseGPU
            [genLD2HD,genHD2LD,discLD,discHD] = ...
                gather(genLD2HD,genHD2LD,discLD,discHD);
        end
        generatorHighDoseToLowDose = genHD2LD;
        generatorLowDoseToHighDose = genLD2HD;
        discriminatorLowDose = discLD;
        discriminatorHighDose = discHD;    
        modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
        save(checkpointDir+filesep+"LD2HDCTCycleGAN-"+modelDateTime+"-Epoch-"+epoch+".mat", ...
            'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ...
            'discriminatorLowDose','discriminatorHighDose');
    end
    
    % Save the final model
    modelDateTime = string(datetime('now','Format',"yyyy-MM-dd-HH-mm-ss"));
    save(checkpointDir+filesep+"trainedLD2HDCTCycleGANNet-"+modelDateTime+".mat", ...
        'generatorLowDoseToHighDose','generatorHighDoseToLowDose', ...
        'discriminatorLowDose','discriminatorHighDose');

else
    trainedCycleGANNetURL = "https://www.mathworks.com/supportfiles/vision/data/trainedLD2HDCTCycleGANNet.mat";
    netDir = fullfile(tempdir,"LD2HDCT");
    downloadTrainedLD2HDCTCycleGANNet(trainedCycleGANNetURL,netDir);
    load(fullfile(netDir,"trainedLD2HDCTCycleGANNet.mat"));
end

Сгенерируйте новые изображения Используя тестовые данные

Задайте количество тестовых изображений, чтобы использовать для вычисления метрик качества. Случайным образом выберите два тестовых изображения, чтобы отобразиться.

numTest = timdsTestLD.numpartitions;
numImagesToDisplay = 2;
idxImagesToDisplay = randi(numTest,1,numImagesToDisplay);

Инициализируйте переменные, чтобы вычислить PSNR и SSIM.

origPSNR = zeros(numTest,1);
generatedPSNR = zeros(numTest,1);
origSSIM = zeros(numTest,1);
generatedSSIM = zeros(numTest,1);

Чтобы сгенерировать новые переведенные изображения, используйте predict функция. Считайте изображения из набора тестовых данных и используйте обученные генераторы, чтобы сгенерировать новые изображения.

for idx = 1:numTest
    imageTestLD = read(timdsTestLD);
    imageTestHD = read(timdsTestHD);
    
    imageTestLD = cat(4,imageTestLD{1});
    imageTestHD = cat(4,imageTestHD{1});

    % Convert mini-batch of data to dlarray and specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch)
    imageTestLD = dlarray(imageTestLD,'SSCB');
    imageTestHD = dlarray(imageTestHD,'SSCB');

    % If running on a GPU, then convert data to gpuArray
    if canUseGPU
        imageTestLD = gpuArray(imageTestLD);
        imageTestHD = gpuArray(imageTestHD);
    end

    % Generate translated images
    generatedImageHD = predict(generatorLowDoseToHighDose,imageTestLD);
    generatedImageLD = predict(generatorHighDoseToLowDose,imageTestHD);

    % Display a few images to visualize the network responses
     if ismember(idx,idxImagesToDisplay)
         figure
         origImLD = rescale(extractdata(imageTestLD));
         genImHD = rescale(extractdata(generatedImageHD));
         montage({origImLD,genImHD},Size=[1 2],BorderSize=5)
         title("Original LDCT Test Image "+idx+" (Left), Generated HDCT Image (Right)")
    end
    
    origPSNR(idx) = psnr(imageTestLD,imageTestHD);
    generatedPSNR(idx) = psnr(generatedImageHD,imageTestHD);
    
    origSSIM(idx) = multissim(imageTestLD,imageTestHD);
    generatedSSIM(idx) = multissim(generatedImageHD,imageTestHD);
end

Вычислите средний PSNR исходных и сгенерированных изображений. Большее значение PSNR указывает на лучшее качество изображения.

disp("Average PSNR of original images: "+mean(origPSNR,"all"));
Average PSNR of original images: 20.4045
disp("Average PSNR of generated images: "+mean(generatedPSNR,"all"));
Average PSNR of generated images: 27.9155

Вычислите средний SSIM исходных и сгенерированных изображений. Значение SSIM ближе к 1 указывает на лучшее качество изображения.

disp("Average SSIM of original images: "+mean(origSSIM,"all"));
Average SSIM of original images: 0.76651
disp("Average SSIM of generated images: "+mean(generatedSSIM,"all"));
Average SSIM of generated images: 0.90194

Вспомогательные Функции

Функция градиентов модели

Функциональный modelGradients берет в качестве входа два генератора и различитель dlnetwork объекты и мини-пакет входных данных. Функция возвращает градиенты потери относительно настраиваемых параметров в сетях и множестве этих четырех сетей. Поскольку различитель выходные параметры не находится в области значений [0, 1], modelGradients функция применяет сигмоидальную функцию, чтобы преобразовать различитель выходные параметры в баллы вероятности.

function [genHD2LDGrad,genLD2HDGrad,discLDGrad,discHDGrad, ...
    genHD2LDState,genLD2HDState,scores,imagesOutLDAndHDGenerated,imagesOutHDAndLDGenerated] = ...
    modelGradients(genLD2HD,genHD2LD,discLD,discHD,imageHD,imageLD,lambda)

% Translate images from one domain to another: low-dose to high-dose and
% vice versa
[imageLDGenerated,genHD2LDState] = forward(genHD2LD,imageHD);
[imageHDGenerated,genLD2HDState] = forward(genLD2HD,imageLD);

% Calculate predictions for real images in each domain by the corresponding
% discriminator networks
predRealLD = forward(discLD,imageLD);
predRealHD = forward(discHD,imageHD);

% Calculate predictions for generated images in each domain by the
% corresponding discriminator networks
predGeneratedLD = forward(discLD,imageLDGenerated);
predGeneratedHD = forward(discHD,imageHDGenerated);

% Calculate discriminator losses for real images
discLDLossReal = lossReal(predRealLD);
discHDLossReal = lossReal(predRealHD);

% Calculate discriminator losses for generated images
discLDLossGenerated = lossGenerated(predGeneratedLD);
discHDLossGenerated = lossGenerated(predGeneratedHD);

% Calculate total discriminator loss for each discriminator network
discLDLossTotal = 0.5*(discLDLossReal + discLDLossGenerated);
discHDLossTotal = 0.5*(discHDLossReal + discHDLossGenerated);

% Calculate generator loss for generated images
genLossHD2LD = lossReal(predGeneratedLD);
genLossLD2HD = lossReal(predGeneratedHD);

% Complete the round-trip (cycle consistency) outputs by applying the
% generator to each generated image to get the images in the corresponding
% original domains
cycleImageLD2HD2LD = forward(genHD2LD,imageHDGenerated);
cycleImageHD2LD2HD = forward(genLD2HD,imageLDGenerated);

% Calculate cycle consistency loss between real and generated images
cycleLossLD2HD2LD = cycleConsistencyLoss(imageLD,cycleImageLD2HD2LD,lambda);
cycleLossHD2LD2HD = cycleConsistencyLoss(imageHD,cycleImageHD2LD2HD,lambda);

% Calculate identity outputs
identityImageLD = forward(genHD2LD,imageLD);
identityImageHD = forward(genLD2HD,imageHD);
 
% Calculate fidelity loss (SSIM) between the identity outputs
fidelityLossLD = mean(1-multissim(identityImageLD,imageLD),"all");
fidelityLossHD = mean(1-multissim(identityImageHD,imageHD),"all");

% Calculate total generator loss
genLossTotal = genLossHD2LD + cycleLossHD2LD2HD + ...
    genLossLD2HD + cycleLossLD2HD2LD + fidelityLossLD + fidelityLossHD;

% Calculate scores of generators
genHD2LDScore = mean(sigmoid(predGeneratedLD),"all");
genLD2HDScore = mean(sigmoid(predGeneratedHD),"all");

% Calculate scores of discriminators
discLDScore = 0.5*mean(sigmoid(predRealLD),"all") + ...
    0.5*mean(1-sigmoid(predGeneratedLD),"all");
discHDScore = 0.5*mean(sigmoid(predRealHD),"all") + ...
    0.5*mean(1-sigmoid(predGeneratedHD),"all");

% Combine scores into cell array
scores = {genHD2LDScore,genLD2HDScore,discLDScore,discHDScore};

% Calculate gradients of generators
genLD2HDGrad = dlgradient(genLossTotal,genLD2HD.Learnables,'RetainData',true);
genHD2LDGrad = dlgradient(genLossTotal,genHD2LD.Learnables,'RetainData',true);

% Calculate gradients of discriminators
discLDGrad = dlgradient(discLDLossTotal,discLD.Learnables,'RetainData',true);
discHDGrad = dlgradient(discHDLossTotal,discHD.Learnables);

% Return mini-batch of images transforming low-dose CT into high-dose CT
imagesOutLDAndHDGenerated = {imageLD,imageHDGenerated};

% Return mini-batch of images transforming high-dose CT into low-dose CT
imagesOutHDAndLDGenerated = {imageHD,imageLDGenerated};
end

Функции потерь

Задайте функции потерь MSE для действительных и сгенерированных изображений.

function loss = lossReal(predictions)
    loss = mean((1-predictions).^2,"all");
end

function loss = lossGenerated(predictions)
    loss = mean((predictions).^2,"all");
end

Задайте функции потерь непротиворечивости цикла для действительных и сгенерированных изображений.

function loss = cycleConsistencyLoss(imageReal,imageGenerated,lambda)
    loss = mean(abs(imageReal-imageGenerated),"all") * lambda;
end

Ссылки

[1] Чжу, июнь-Yan, парк Taesung, Филип Изола и Алексей А. Эфрос. “Непарный Перевод От изображения к изображению Используя Сопоставимые с циклом Соперничающие Сети”. На 2 017 Международных конференциях IEEE по вопросам Компьютерного зрения (ICCV), 2242–51. Венеция: IEEE, 2017. https://doi.org/10.1109/ICCV.2017.244.

[2] МакКалоу, Синтия, Бэйю Чен, Дэвид Р Холмс III, Синьхой Дуань, Жиконг Ю, Лифэн Юй, Шуай Лэн и Джоэл Флетчер. “Низкие Данные об Изображении и Проекции CT Дозы (LDCT и Данные проекции)”. Архив Обработки изображений Рака, 2020. https://doi.org/10.7937/9NPB-2637.

[3] Предоставления EB017095 и EB017185 (Синтия МакКалоу, PI) от национального института биомедицинской обработки изображений и биоинженерии.

[4] Кларк, Кеннет, Брюс Вендт, Кирк Смит, Джон Фреиманн, Джастин Кирби, Пол Коппель, Стивен Мур, и др. “Архив обработки изображений рака (TCIA): Поддержание и Работа Репозиторием Общедоступной информации”. Журнал Цифровой Обработки изображений 26, № 6 (декабрь 2013): 1045–57. https://doi.org/10.1007/s10278-013-9622-7.

[5] Вы, Chenyu, Кингсонг Янг, Хонгминг Шань, Ларс Гджестеби, Гуан Ли, Шенгонг Джу, Цзхойян Чжан, и др. “Структурно чувствительная Многошкальная Глубокая нейронная сеть для Шумоподавления CT Низкой Дозы”. IEEE доступ 6 (2018): 41839–55. https://doi.org/10.1109/ACCESS.2018.2858196.

Смотрите также

(Image Processing Toolbox) | (Image Processing Toolbox) | | | | | |

Связанные примеры

Больше о