exponenta event banner

предсказать

Прогнозирование ответов с использованием регрессионной нейронной сети

    Описание

    пример

    yfit = predict(Mdl,X) возвращает прогнозируемые значения отклика для данных предиктора в таблице или матрице X использование модели обученной регрессионной нейронной сети Mdl.

    yfit возвращается в виде числового вектора, i-я запись которого соответствует i-му наблюдению в X.

    yfit = predict(Mdl,X,'ObservationsIn',dimension) задает измерение наблюдения данных предиктора, либо 'rows' (по умолчанию) или 'columns'. Например, укажите 'ObservationsIn','columns' чтобы указать, что столбцы в данных предиктора соответствуют наблюдениям.

    Примеры

    свернуть все

    Прогнозировать значения ответа тестового набора с использованием модели обученной регрессионной нейронной сети.

    Загрузить patients набор данных. Создайте таблицу из набора данных. Каждая строка соответствует одному пациенту, и каждый столбец соответствует диагностической переменной. Используйте Systolic переменная в качестве ответной переменной, а остальные переменные в качестве предикторов.

    load patients
    tbl = table(Age,Diastolic,Gender,Height,Smoker,Weight,Systolic);

    Разделение данных на набор обучения tblTrain и набор тестов tblTest с помощью нестратифицированного раздела удержания. Программное обеспечение резервирует приблизительно 30% наблюдений для набора тестовых данных и использует остальные наблюдения для набора обучающих данных.

    rng("default") % For reproducibility of the partition
    c = cvpartition(size(tbl,1),"Holdout",0.30);
    trainingIndices = training(c);
    testIndices = test(c);
    tblTrain = tbl(trainingIndices,:);
    tblTest = tbl(testIndices,:);

    Обучение модели регрессионной нейронной сети с использованием обучающего набора. Укажите Systolic столбец tblTrain в качестве переменной ответа. Укажите, чтобы стандартизировать числовые предикторы. По умолчанию модель нейронной сети имеет один полностью подключенный уровень с 10 выходами, исключая конечный полностью подключенный уровень.

    Mdl = fitrnet(tblTrain,"Systolic", ...
        "Standardize",true);

    Предсказать уровни систолического артериального давления для пациентов в наборе тестов.

    predictedY = predict(Mdl,tblTest);

    Визуализация результатов с помощью графика рассеяния с опорной линией. Постройте график прогнозируемых значений вдоль вертикальной оси и истинных значений отклика вдоль горизонтальной оси. Точки на опорной линии указывают правильные прогнозы.

    plot(tblTest.Systolic,predictedY,".")
    hold on
    plot(tblTest.Systolic,tblTest.Systolic)
    hold off
    xlabel("True Systolic Blood Pressure Levels")
    ylabel("Predicted Systolic Blood Pressure Levels")

    Поскольку многие из точек находятся далеко от опорной линии, модель нейронной сети по умолчанию с полностью соединенным слоем размера 10, по-видимому, не является большим предиктором уровней систолического артериального давления.

    Выбор функций выполняется путем сравнения потерь и прогнозов тестового набора. Сравните метрики тестового набора для модели регрессионной нейронной сети, обученной с использованием всех предикторов, с метриками тестового набора для модели, обученной с использованием только подмножества предикторов.

    Загрузить образец файла fisheriris.csv, которая содержит данные по радужке, включая длину чашелистика, ширину чашелистика, длину лепестка, ширину лепестка и видовой тип. Прочтите файл в таблицу.

    fishertable = readtable('fisheriris.csv');

    Разделение данных на набор обучения trainTbl и набор тестов testTbl с помощью нестратифицированного раздела удержания. Программное обеспечение резервирует приблизительно 30% наблюдений для набора тестовых данных и использует остальные наблюдения для набора обучающих данных.

    rng("default")
    c = cvpartition(size(fishertable,1),"Holdout",0.3);
    trainTbl = fishertable(training(c),:);
    testTbl = fishertable(test(c),:);

    Обучить одну модель регрессионной нейронной сети, используя все предикторы в обучающем наборе, и обучить другой классификатор, используя все предикторы, кроме PetalWidth. Для обеих моделей укажите PetalLength в качестве переменной ответа и стандартизировать предикторы.

    allMdl = fitrnet(trainTbl,"PetalLength","Standardize",true);
    subsetMdl = fitrnet(trainTbl,"PetalLength ~ SepalLength + SepalWidth + Species", ...
        "Standardize",true);

    Сравните значения среднеквадратичной ошибки (MSE) для двух моделей. Меньшие значения MSE указывают на лучшую производительность.

    allMSE = loss(allMdl,testTbl)
    allMSE = 0.0834
    
    subsetMSE = loss(subsetMdl,testTbl)
    subsetMSE = 0.0887
    

    Для каждой модели сравните прогнозируемые длины лепестков тестового набора с истинными длинами лепестков. Постройте график прогнозируемых длин лепестков вдоль вертикальной оси и истинных длин лепестков вдоль горизонтальной оси. Точки на опорной линии указывают правильные прогнозы.

    tiledlayout(2,1)
    
    % Top axes
    ax1 = nexttile;
    allPredictedY = predict(allMdl,testTbl);
    plot(ax1,testTbl.PetalLength,allPredictedY,".")
    hold on
    plot(ax1,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax1,"True Petal Length")
    ylabel(ax1,"Predicted Petal Length")
    title(ax1,"All Predictors")
    
    % Bottom axes
    ax2 = nexttile;
    subsetPredictedY = predict(subsetMdl,testTbl);
    plot(ax2,testTbl.PetalLength,subsetPredictedY,".")
    hold on
    plot(ax2,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax2,"True Petal Length")
    ylabel(ax2,"Predicted Petal Length")
    title(ax2,"Subset of Predictors")

    Поскольку обе модели, по-видимому, работают хорошо, с предсказаниями, разбросанными вблизи опорной линии, рассмотрите возможность использования модели, обученной с использованием всех предикторов, кроме PetalWidth.

    Посмотрите, как слои регрессионной модели нейронной сети работают вместе, чтобы предсказать значение отклика для одного наблюдения.

    Загрузить образец файла fisheriris.csv, которая содержит данные по радужке, включая длину чашелистика, ширину чашелистика, длину лепестка, ширину лепестка и видовой тип. Прочтите файл в таблицу и отобразите первые несколько строк таблицы.

    fishertable = readtable('fisheriris.csv');
    head(fishertable)
    ans=8×5 table
        SepalLength    SepalWidth    PetalLength    PetalWidth     Species  
        ___________    __________    ___________    __________    __________
    
            5.1           3.5            1.4           0.2        {'setosa'}
            4.9             3            1.4           0.2        {'setosa'}
            4.7           3.2            1.3           0.2        {'setosa'}
            4.6           3.1            1.5           0.2        {'setosa'}
              5           3.6            1.4           0.2        {'setosa'}
            5.4           3.9            1.7           0.4        {'setosa'}
            4.6           3.4            1.4           0.3        {'setosa'}
              5           3.4            1.5           0.2        {'setosa'}
    
    

    Обучение регрессионной модели нейронной сети с использованием набора данных. Укажите PetalLength переменная в качестве ответа и использовать другие числовые переменные в качестве предикторов.

    Mdl = fitrnet(fishertable,"PetalLength ~ SepalLength + SepalWidth + PetalWidth");

    Выберите пятнадцатое наблюдение из набора данных. Посмотрите, как слои нейронной сети принимают наблюдение и возвращают прогнозируемое значение отклика newPointResponse.

    newPoint = Mdl.X{15,:}
    newPoint = 1×3
    
        5.8000    4.0000    0.2000
    
    
    firstFCStep = (Mdl.LayerWeights{1})*newPoint' + Mdl.LayerBiases{1};
    reluStep = max(firstFCStep,0);
    
    finalFCStep = (Mdl.LayerWeights{end})*reluStep + Mdl.LayerBiases{end};
    
    newPointResponse = finalFCStep
    newPointResponse = 1.6716
    

    Убедитесь, что прогноз соответствует прогнозу, возвращенному predict объектная функция.

    predictedY = predict(Mdl,newPoint)
    predictedY = 1.6716
    
    isequal(newPointResponse,predictedY)
    ans = logical
       1
    
    

    Два результата совпадают.

    Входные аргументы

    свернуть все

    Обученная регрессионная нейронная сеть, указанная как RegressionNeuralNetwork объект модели или CompactRegressionNeuralNetwork объект модели, возвращенный fitrnet или compactсоответственно.

    Данные предиктора, используемые для генерации ответов, заданные как числовая матрица или таблица.

    По умолчанию каждая строка X соответствует одному наблюдению, и каждый столбец соответствует одной переменной.

    • Для числовой матрицы:

      • Переменные в столбцах X должен иметь тот же порядок, что и обучаемые переменные предиктора Mdl.

      • Если вы тренируетесь Mdl использование таблицы (например, Tbl) и Tbl содержит только числовые переменные предиктора, затем X может быть числовой матрицей. Чтобы обработать числовые предикторы в Tbl в качестве категориального во время обучения, определить категориальные предикторы, используя CategoricalPredictors аргумент «имя-значение» fitrnet. Если Tbl содержит разнородные переменные предиктора (например, числовые и категориальные типы данных) и X является числовой матрицей, то predict выдает ошибку.

    • Для таблицы:

      • predict не поддерживает многозначные переменные или массивы ячеек, отличные от массивов ячеек символьных векторов.

      • Если вы тренируетесь Mdl использование таблицы (например, Tbl), затем все переменные предиктора в X должны иметь те же имена переменных и типы данных, что и обучаемые переменные Mdl (хранится в Mdl.PredictorNames). Однако порядок столбцов X не обязательно соответствовать порядку столбцов Tbl. Также, Tbl и X может содержать дополнительные переменные (переменные ответа, веса наблюдения и т.д.), но predict игнорирует их.

      • Если вы тренируетесь Mdl используя числовую матрицу, затем имена предикторов в Mdl.PredictorNames должны совпадать с именами соответствующих переменных предиктора в X. Чтобы указать имена предикторов во время обучения, используйте PredictorNames аргумент «имя-значение» fitrnet. Все переменные предиктора в X должны быть числовыми векторами. X может содержать дополнительные переменные (переменные ответа, веса наблюдения и т.д.), но predict игнорирует их.

    Если установить 'Standardize',true в fitrnet при обучении Mdlзатем программное обеспечение стандартизирует числовые столбцы данных предиктора с использованием соответствующих средств и стандартных отклонений.

    Примечание

    Если вы ориентируете матрицу предиктора так, чтобы наблюдения соответствовали столбцам, и укажите 'ObservationsIn','columns', то вы можете испытать значительное сокращение времени вычислений. Невозможно указать 'ObservationsIn','columns' для данных предиктора в таблице.

    Типы данных: single | double | table

    Измерение наблюдения данных предиктора, указанное как 'rows' или 'columns'.

    Примечание

    Если вы ориентируете матрицу предиктора так, чтобы наблюдения соответствовали столбцам, и укажите 'ObservationsIn','columns', то вы можете испытать значительное сокращение времени вычислений. Невозможно указать 'ObservationsIn','columns' для данных предиктора в таблице.

    Типы данных: char | string

    Представлен в R2021a