predict

Спрогнозируйте ответы, используя регрессионую нейронную сеть

    Описание

    пример

    yfit = predict(Mdl,X) возвращает предсказанные значения отклика для данных предиктора в таблице или матрице X использование модели обученной регрессионной нейронной сети Mdl.

    yfit возвращается как числовой вектор, i-я запись которого соответствует i-му наблюдению в X.

    yfit = predict(Mdl,X,'ObservationsIn',dimension) задает размерность наблюдения данных предиктора, либо 'rows' (по умолчанию) или 'columns'. Для примера задайте 'ObservationsIn','columns' чтобы указать, что столбцы в данных предиктора соответствуют наблюдениям.

    Примеры

    свернуть все

    Предсказать значения отклика тестового набора при помощи обученной модели регрессионной нейронной сети.

    Загрузите patients набор данных. Составьте таблицу из набора данных. Каждая строка соответствует одному пациенту, и каждый столбец соответствует диагностической переменной. Используйте Systolic переменная как переменная отклика, а остальная часть переменных как предикторы.

    load patients
    tbl = table(Age,Diastolic,Gender,Height,Smoker,Weight,Systolic);

    Разделите данные на набор обучающих данных tblTrain и тестовый набор tblTest при помощи нертифицированного разбиения с ограничением. Программа резервирует приблизительно 30% наблюдений для тестовых данных набора и использует остальную часть наблюдений для обучающего набора данных.

    rng("default") % For reproducibility of the partition
    c = cvpartition(size(tbl,1),"Holdout",0.30);
    trainingIndices = training(c);
    testIndices = test(c);
    tblTrain = tbl(trainingIndices,:);
    tblTest = tbl(testIndices,:);

    Обучите регрессионную модель нейронной сети с помощью набора обучающих данных. Задайте Systolic столбец tblTrain как переменная отклика. Задайте, чтобы стандартизировать числовые предикторы. По умолчанию модель нейронной сети имеет один полностью соединенный слой с 10 выходами, исключая конечный полностью соединенный слой.

    Mdl = fitrnet(tblTrain,"Systolic", ...
        "Standardize",true);

    Спрогнозируйте шишкообразные уровни артериального давления для пациентов в тестовом наборе.

    predictedY = predict(Mdl,tblTest);

    Визуализируйте результаты с помощью графика поля точек с ссылкой линией. Постройте график предсказанных значений вдоль вертикальной оси и истинных значений отклика вдоль горизонтальной оси. Точки на опорной линии указывают на правильные предсказания.

    plot(tblTest.Systolic,predictedY,".")
    hold on
    plot(tblTest.Systolic,tblTest.Systolic)
    hold off
    xlabel("True Systolic Blood Pressure Levels")
    ylabel("Predicted Systolic Blood Pressure Levels")

    Поскольку многие точки находятся далеко от ссылки, модель нейронной сети по умолчанию с полносвязным слоем размера 10, по-видимому, не является большим предиктором сестрических уровней артериального давления.

    Выполните выбор признаков путем сравнения потерь и предсказаний тестового набора. Сравните метрики тестового набора для регрессионной модели нейронной сети, обученной с использованием всех предикторов, с метриками тестового набора для модели, обученной с использованием только подмножества предикторов.

    Загрузите образец файла fisheriris.csv, который содержит данные по радужке, включая длину чашелистика, ширину чашелистика, длину лепестка, ширину лепестка и видовой тип. Считайте файл в таблицу.

    fishertable = readtable('fisheriris.csv');

    Разделите данные на набор обучающих данных trainTbl и тестовый набор testTbl при помощи нертифицированного разбиения с ограничением. Программа резервирует приблизительно 30% наблюдений для тестовых данных набора и использует остальную часть наблюдений для обучающего набора данных.

    rng("default")
    c = cvpartition(size(fishertable,1),"Holdout",0.3);
    trainTbl = fishertable(training(c),:);
    testTbl = fishertable(test(c),:);

    Обучите одну регрессионную модель нейронной сети, используя все предикторы в наборе обучающих данных, и обучите другой классификатор, используя все предикторы, кроме PetalWidth. Для обеих моделей задайте PetalLength как переменная отклика и стандартизируйте предикторы.

    allMdl = fitrnet(trainTbl,"PetalLength","Standardize",true);
    subsetMdl = fitrnet(trainTbl,"PetalLength ~ SepalLength + SepalWidth + Species", ...
        "Standardize",true);

    Сравните среднюю квадратичную невязку (MSE) тестового набора двух моделей. Меньшие значения MSE указывают на лучшую эффективность.

    allMSE = loss(allMdl,testTbl)
    allMSE = 0.0834
    
    subsetMSE = loss(subsetMdl,testTbl)
    subsetMSE = 0.0887
    

    Для каждой модели сравните предсказанные длины лепестков тестового набора с истинными длинами лепестков. Постройте график прогнозируемых длин лепестков вдоль вертикальной оси и истинных длин лепестков вдоль горизонтальной оси. Точки на опорной линии указывают на правильные предсказания.

    tiledlayout(2,1)
    
    % Top axes
    ax1 = nexttile;
    allPredictedY = predict(allMdl,testTbl);
    plot(ax1,testTbl.PetalLength,allPredictedY,".")
    hold on
    plot(ax1,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax1,"True Petal Length")
    ylabel(ax1,"Predicted Petal Length")
    title(ax1,"All Predictors")
    
    % Bottom axes
    ax2 = nexttile;
    subsetPredictedY = predict(subsetMdl,testTbl);
    plot(ax2,testTbl.PetalLength,subsetPredictedY,".")
    hold on
    plot(ax2,testTbl.PetalLength,testTbl.PetalLength)
    hold off
    xlabel(ax2,"True Petal Length")
    ylabel(ax2,"Predicted Petal Length")
    title(ax2,"Subset of Predictors")

    Поскольку обе модели, по-видимому, работают хорошо, с предсказаниями, разбросанными около ссылки линии, рассмотрите использование модели, обученной с использованием всех предикторов, кроме PetalWidth.

    Посмотрите, как слои регрессионной модели нейронной сети работают вместе, чтобы предсказать значение отклика для одного наблюдения.

    Загрузите образец файла fisheriris.csv, который содержит данные по радужке, включая длину чашелистика, ширину чашелистика, длину лепестка, ширину лепестка и видовой тип. Прочтите файл в таблицу и отобразите первые несколько строк таблицы.

    fishertable = readtable('fisheriris.csv');
    head(fishertable)
    ans=8×5 table
        SepalLength    SepalWidth    PetalLength    PetalWidth     Species  
        ___________    __________    ___________    __________    __________
    
            5.1           3.5            1.4           0.2        {'setosa'}
            4.9             3            1.4           0.2        {'setosa'}
            4.7           3.2            1.3           0.2        {'setosa'}
            4.6           3.1            1.5           0.2        {'setosa'}
              5           3.6            1.4           0.2        {'setosa'}
            5.4           3.9            1.7           0.4        {'setosa'}
            4.6           3.4            1.4           0.3        {'setosa'}
              5           3.4            1.5           0.2        {'setosa'}
    
    

    Обучите регрессионную модель нейронной сети с помощью набора данных. Задайте PetalLength переменная в качестве отклика и используйте другие числовые переменные в качестве предикторов.

    Mdl = fitrnet(fishertable,"PetalLength ~ SepalLength + SepalWidth + PetalWidth");

    Выберите пятнадцатое наблюдение из набора данных. Посмотрите, как слои нейронной сети берут наблюдение и возвращают предсказанное значение отклика newPointResponse.

    newPoint = Mdl.X{15,:}
    newPoint = 1×3
    
        5.8000    4.0000    0.2000
    
    
    firstFCStep = (Mdl.LayerWeights{1})*newPoint' + Mdl.LayerBiases{1};
    reluStep = max(firstFCStep,0);
    
    finalFCStep = (Mdl.LayerWeights{end})*reluStep + Mdl.LayerBiases{end};
    
    newPointResponse = finalFCStep
    newPointResponse = 1.6716
    

    Проверяйте, что предсказание соответствует предсказанию, возвращаемому predict функция объекта.

    predictedY = predict(Mdl,newPoint)
    predictedY = 1.6716
    
    isequal(newPointResponse,predictedY)
    ans = logical
       1
    
    

    Эти два результата совпадают.

    Входные параметры

    свернуть все

    Обученная регрессионная нейронная сеть, заданная как RegressionNeuralNetworkобъект модели объект модели, возвращенный fitrnet или compact, соответственно.

    Данные предиктора, используемые для генерации откликов, заданные как числовая матрица или таблица.

    По умолчанию каждая строка X соответствует одному наблюдению, и каждый столбец соответствует одной переменной.

    • Для числовой матрицы:

      • Переменные в столбцах X должен иметь тот же порядок, что и переменные предиктора, которые обучали Mdl.

      • Если вы обучаете Mdl использование таблицы (для примера, Tbl) и Tbl содержит только числовые переменные предиктора, тогда X может быть числовой матрицей. Для лечения числовых предикторов в Tbl как категориальный во время обучения, идентифицируйте категориальные предикторы с помощью CategoricalPredictors аргумент имя-значение fitrnet. Если Tbl содержит неоднородные переменные предиктора (для примера, числовых и категориальных типов данных) и X является числовой матрицей, тогда predict выдает ошибку.

    • Для таблицы:

      • predict не поддерживает многополюсные переменные или массивы ячеек, отличные от массивов ячеек векторов символов.

      • Если вы обучаете Mdl использование таблицы (для примера, Tbl), затем все переменные предиктора в X должны иметь те же имена переменных и типы данных, что и обученные переменные Mdl (хранится в Mdl.PredictorNames). Однако порядок столбцов X не должен соответствовать порядку столбцов Tbl. Кроме того, Tbl и X может содержать дополнительные переменные (переменные отклика, веса наблюдений и так далее), но predict игнорирует их.

      • Если вы обучаете Mdl используя числовую матрицу, затем имена предикторов в Mdl.PredictorNames должен быть таким же, как и соответствующий предиктор, имена переменных в X. Чтобы задать имена предикторов во время обучения, используйте PredictorNames аргумент имя-значение fitrnet. Все переменные предиктора в X должны быть числовыми векторами. X может содержать дополнительные переменные (переменные отклика, веса наблюдений и так далее), но predict игнорирует их.

    Если вы задаете 'Standardize',true в fitrnet при обучении Mdlзатем программное обеспечение стандартизирует числовые столбцы данных предиктора с помощью соответствующих средств и стандартных отклонений.

    Примечание

    Если вы ориентируете матрицу предиктора так, чтобы наблюдения соответствовали столбцам и задавали 'ObservationsIn','columns', тогда вы можете испытать значительное сокращение времени расчета. Вы не можете задать 'ObservationsIn','columns' для данных предиктора в таблице.

    Типы данных: single | double | table

    Размерность наблюдения данных предиктора, заданная как 'rows' или 'columns'.

    Примечание

    Если вы ориентируете матрицу предиктора так, чтобы наблюдения соответствовали столбцам и задавали 'ObservationsIn','columns', тогда вы можете испытать значительное сокращение времени расчета. Вы не можете задать 'ObservationsIn','columns' для данных предиктора в таблице.

    Типы данных: char | string

    Введенный в R2021a