exponenta event banner

Оценка эффективности регрессионной нейронной сети

Создание модели нейронной сети прямой регрессии с полностью связанными слоями с помощью fitrnet. Используйте данные проверки для ранней остановки процесса обучения, чтобы предотвратить переоснащение модели. Затем используйте объектные функции модели для оценки ее производительности на тестовых данных.

Загрузить данные образца

Загрузить carbig набор данных, содержащий замеры автомобилей, сделанные в 1970-х и начале 1980-х годов.

load carbig

Преобразовать Origin переменной к категориальной переменной. Затем создайте таблицу, содержащую переменные предиктора Acceleration, Displacementи так далее, а также переменная ответа MPG. Каждая строка содержит измерения для одного автомобиля.

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);

Данные раздела

Разбейте данные на обучающие, проверочные и тестовые наборы. Во-первых, зарезервировать примерно одну треть наблюдений для тестового набора. Затем разделите оставшиеся данные пополам, чтобы создать наборы обучения и проверки.

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

Нейронная сеть поезда

Обучение модели регрессионной нейронной сети с использованием обучающего набора. Укажите MPG столбец tblTrain в качестве переменной ответа и стандартизировать числовые предикторы. Оцените модель в каждой итерации с помощью набора проверки. Укажите отображение информации об обучении в каждой итерации с помощью Verbose аргумент «имя-значение». По умолчанию учебный процесс завершается раньше, если потеря проверки больше или равна минимальной потере проверки, рассчитанной на данный момент, шесть раз подряд. Чтобы изменить количество раз, когда допустимо, чтобы потеря проверки была больше или равна минимуму, укажите ValidationPatience аргумент «имя-значение».

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|   71.063537|   22.623354|    6.466959|    0.001272|   72.648960|           0|
|           2|   48.608700|   22.384995|    1.022929|    0.001808|   43.435698|           0|
|           3|   30.584887|   13.433471|    0.537190|    0.000903|   29.134447|           0|
|           4|   17.781636|   11.159801|    1.401355|    0.000461|   16.542207|           0|
|           5|   13.075804|    4.605991|    0.419875|    0.000387|   12.946670|           0|
|           6|   11.697936|    3.197944|    0.226945|    0.000543|   12.025502|           0|
|           7|    9.494801|    2.269831|    0.751711|    0.000452|   12.596499|           1|
|           8|    8.390979|    1.970589|    0.337301|    0.000398|   11.490990|           0|
|           9|    6.853097|    1.029078|    0.866974|    0.000378|    9.449945|           0|
|          10|    6.531678|    0.924820|    0.306913|    0.000429|    9.350721|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    6.152995|    1.872684|    0.457744|    0.000403|    9.223829|           0|
|          12|    5.924852|    0.718386|    0.447879|    0.000402|    9.656166|           1|
|          13|    5.792836|    0.500170|    0.216351|    0.000387|    9.733226|           2|
|          14|    5.613473|    1.151197|    0.316828|    0.000531|    9.788646|           3|
|          15|    5.415889|    1.513493|    0.327937|    0.000485|    9.607953|           4|
|          16|    5.008195|    1.398069|    1.085660|    0.000430|    9.251971|           5|
|          17|    5.004176|    2.070041|    0.890201|    0.000383|    8.719334|           0|
|          18|    4.738386|    0.483667|    0.338897|    0.000374|    8.523728|           0|
|          19|    4.680213|    0.437918|    0.107667|    0.000371|    8.369271|           0|
|          20|    4.587350|    0.510639|    0.146276|    0.000385|    8.100236|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    4.479929|    0.565635|    0.228198|    0.000381|    8.062927|           0|
|          22|    4.380618|    0.892717|    0.377776|    0.000554|    7.843234|           0|
|          23|    4.189344|    0.403227|    0.362307|    0.000434|    7.834582|           0|
|          24|    4.182775|    1.150234|    1.908768|    0.000408|    9.436226|           1|
|          25|    3.985939|    0.908479|    0.518217|    0.000570|    8.973756|           2|
|          26|    3.873835|    0.826655|    0.477740|    0.000505|    8.863599|           3|
|          27|    3.830830|    0.331936|    0.220000|    0.000539|    8.574682|           4|
|          28|    3.796605|    0.232756|    0.075643|    0.000492|    8.591758|           5|
|          29|    3.706326|    0.470116|    0.249292|    0.000396|    8.517317|           6|
|==========================================================================================|

Используйте информацию внутри TrainingHistory свойство объекта Mdl для проверки итерации, которая соответствует минимальному среднеквадратичному значению ошибки (MSE). Последняя возвращенная модель Mdl является моделью, обученной этой итерации.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 23

Оценка производительности тестового набора

Оценка эффективности обучаемой модели Mdl на контрольном аппарате testTbl с помощью loss и predict функции объекта.

Вычислите среднеквадратичную ошибку (MSE) тестового набора. Меньшие значения MSE указывают на лучшую производительность.

mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145

Сравните прогнозируемые значения отклика тестового набора с истинными значениями отклика. Постройте график прогнозируемых миль на галлон (MPG) вдоль вертикальной оси и истинного MPG вдоль горизонтальной оси. Точки на опорной линии указывают правильные прогнозы. Хорошая модель производит прогнозы, которые разбросаны вблизи линии.

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

Покадровые графики используются для сравнения распределения прогнозируемых и истинных значений MPG по странам происхождения. Создайте прямоугольные графики с помощью boxchart функция. Каждый прямоугольный график отображает медиану, нижний и верхний квартили, любые отклонения (вычисленные с использованием межквартильного диапазона), а также минимальное и максимальное значения, которые не являются отклонениями. В частности, линия внутри каждого бокса представляет собой медиану образца, а круглые маркеры указывают на отклонения.

Для каждой страны происхождения сравните график красного ящика (показывающий распределение прогнозируемых значений MPG) с графиком синего ящика (показывающим распределение истинных значений MPG). Аналогичные распределения для прогнозируемых и истинных значений MPG указывают на хорошие прогнозы.

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

Для большинства стран прогнозируемые и истинные значения MPG имеют сходные распределения. Однако модель нейронной сети имеет тенденцию недооценивать значения MPG для автомобилей, изготовленных во Франции. Это несоответствие, возможно, связано с небольшим количеством французских автомобилей в учебно-испытательных комплектах.

Сравните диапазон значений MPG для французских автомобилей в учебных и тестовых наборах.

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
trainSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         16.2         27  
    Germany    Germany        11           20       44.3  
    Italy      Italy           1         37.3       37.3  
    Japan      Japan          24           20       40.8  
    Sweden     Sweden          3           19       21.6  
    USA        USA            94            9         39  

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
testSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         28.1       40.9  
    Germany    Germany        12         21.5         44  
    Italy      Italy           3           28         30  
    Japan      Japan          32           18       46.6  
    Sweden     Sweden          3           17         24  
    USA        USA            82           10       36.1  

В обучающем наборе значения MPG для автомобилей, изготовленных во Франции, варьируются от 16,2 до 27. Однако в испытательном наборе значения MPG для автомобилей, изготовленных во Франции, варьируются от 28,1 до 40,9.

Постройте график остатков тестового набора. Хорошая модель обычно имеет остатки, разбросанные приблизительно симметрично вокруг 0. Четкие узоры в остатках являются признаком того, что модель можно улучшить.

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

Сюжет предполагает, что некоторые остаточные значения являются отклонениями. Узнайте больше о наблюдении с наибольшим остатком.

[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.3             85            NaN            80        France     1835     40.9

Наблюдение соответствует автомобилю, чей Horsepower значение отсутствует, и страной происхождения является Франция, категория с небольшим количеством наблюдений.

См. также

| | | |