Оценка эффективности регрессионной нейронной сети

Создайте модель регрессионной нейронной сети с прямой связью с полносвязными слоями с помощью fitrnet. Используйте данные валидации для раннего остановки процесса обучения, чтобы предотвратить сверхподбор кривой модели. Затем используйте функции объекта модели, чтобы оценить ее эффективность по тестовым данным.

Загрузка выборочных данных

Загрузите carbig набор данных, содержащий измерения автомобилей 1970-х и начала 1980-х годов.

load carbig

Преобразуйте Origin переменная - категориальная переменная. Затем создайте таблицу, содержащую переменные предиктора Acceleration, Displacement, и так далее, а также переменная отклика MPG. Каждая строка содержит измерения для одного автомобиля.

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);

Данные о разделах

Разделите данные на наборы для обучения, валидации и тестирования. Во-первых, резервируйте приблизительно одну треть наблюдений для тестового набора. Затем разделите оставшиеся данные пополам, чтобы создать наборы обучения и валидации.

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

Обучите нейронную сеть

Обучите регрессионную модель нейронной сети при помощи набора обучающих данных. Задайте MPG столбец tblTrain как переменная отклика и стандартизируйте числовые предикторы. Оцените модель при каждой итерации с помощью набора валидации. Задайте, чтобы отобразить информацию о обучении при каждой итерации с помощью Verbose аргумент имя-значение. По умолчанию процесс обучения заканчивается раньше, если потеря валидации больше или равна минимальной потере валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить количество раз, в течение которого потеря валидации может быть больше или равной минимуму, задайте ValidationPatience аргумент имя-значение.

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|   71.063537|   22.623354|    6.466959|    0.001272|   72.648960|           0|
|           2|   48.608700|   22.384995|    1.022929|    0.001808|   43.435698|           0|
|           3|   30.584887|   13.433471|    0.537190|    0.000903|   29.134447|           0|
|           4|   17.781636|   11.159801|    1.401355|    0.000461|   16.542207|           0|
|           5|   13.075804|    4.605991|    0.419875|    0.000387|   12.946670|           0|
|           6|   11.697936|    3.197944|    0.226945|    0.000543|   12.025502|           0|
|           7|    9.494801|    2.269831|    0.751711|    0.000452|   12.596499|           1|
|           8|    8.390979|    1.970589|    0.337301|    0.000398|   11.490990|           0|
|           9|    6.853097|    1.029078|    0.866974|    0.000378|    9.449945|           0|
|          10|    6.531678|    0.924820|    0.306913|    0.000429|    9.350721|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    6.152995|    1.872684|    0.457744|    0.000403|    9.223829|           0|
|          12|    5.924852|    0.718386|    0.447879|    0.000402|    9.656166|           1|
|          13|    5.792836|    0.500170|    0.216351|    0.000387|    9.733226|           2|
|          14|    5.613473|    1.151197|    0.316828|    0.000531|    9.788646|           3|
|          15|    5.415889|    1.513493|    0.327937|    0.000485|    9.607953|           4|
|          16|    5.008195|    1.398069|    1.085660|    0.000430|    9.251971|           5|
|          17|    5.004176|    2.070041|    0.890201|    0.000383|    8.719334|           0|
|          18|    4.738386|    0.483667|    0.338897|    0.000374|    8.523728|           0|
|          19|    4.680213|    0.437918|    0.107667|    0.000371|    8.369271|           0|
|          20|    4.587350|    0.510639|    0.146276|    0.000385|    8.100236|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    4.479929|    0.565635|    0.228198|    0.000381|    8.062927|           0|
|          22|    4.380618|    0.892717|    0.377776|    0.000554|    7.843234|           0|
|          23|    4.189344|    0.403227|    0.362307|    0.000434|    7.834582|           0|
|          24|    4.182775|    1.150234|    1.908768|    0.000408|    9.436226|           1|
|          25|    3.985939|    0.908479|    0.518217|    0.000570|    8.973756|           2|
|          26|    3.873835|    0.826655|    0.477740|    0.000505|    8.863599|           3|
|          27|    3.830830|    0.331936|    0.220000|    0.000539|    8.574682|           4|
|          28|    3.796605|    0.232756|    0.075643|    0.000492|    8.591758|           5|
|          29|    3.706326|    0.470116|    0.249292|    0.000396|    8.517317|           6|
|==========================================================================================|

Используйте информацию внутри TrainingHistory свойство объекта Mdl проверить итерацию, которая соответствует минимальной средней квадратичной невязке валидации (MSE). Конечная возвращенная модель Mdl - модель, обученная на этой итерации.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 23

Оценка эффективности тестового набора

Оцените эффективность обученной модели Mdl на тестовом аппарате testTbl при помощи loss и predict функции объекта.

Вычислите среднюю квадратичную невязку тестового набора (MSE). Меньшие значения MSE указывают на лучшую эффективность.

mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145

Сравните предсказанные значения отклика набора тестов с истинными значениями отклика. Постройте график прогнозируемых миль на галлон (MPG) вдоль вертикальной оси и истинного MPG вдоль горизонтальной оси. Точки на опорной линии указывают на правильные предсказания. Хорошая модель производит предсказания, которые разбросаны рядом с линией.

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

Используйте прямоугольные графики, чтобы сравнить распределение предсказанных и истинных значений MPG по странам источника. Создайте прямоугольные графики при помощи boxchart функция. На каждом прямоугольном графике отображаются медиана, нижний и верхний квартили, любые выбросы (вычисленные с помощью межквартильной области значений) и минимальное и максимальное значения, которые не являются выбросами. В частности, линия внутри каждой коробки является медианой выборки, а круговые маркеры указывают на выбросы.

Для каждой страны источника сравните красный прямоугольный график (показывающий распределение прогнозируемых значений MPG) с синим прямоугольным графиком (показывающим распределение истинных значений MPG). Подобные распределения для предсказанных и истинных значений MPG указывают на хорошие предсказания.

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

Для большинства стран прогнозируемые и истинные значения MPG имеют сходные распределения. Однако модель нейронной сети имеет тенденцию занижать значения MPG для автомобилей, произведенных во Франции. Это расхождение, возможно, связано с небольшим количеством французских автомобилей в обучающих и тестовых наборах.

Сравните область значений значений MPG для французских автомобилей в обучающих и тестовых наборах.

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
trainSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         16.2         27  
    Germany    Germany        11           20       44.3  
    Italy      Italy           1         37.3       37.3  
    Japan      Japan          24           20       40.8  
    Sweden     Sweden          3           19       21.6  
    USA        USA            94            9         39  

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
testSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         28.1       40.9  
    Germany    Germany        12         21.5         44  
    Italy      Italy           3           28         30  
    Japan      Japan          32           18       46.6  
    Sweden     Sweden          3           17         24  
    USA        USA            82           10       36.1  

В наборе обучающих данных значения MPG для автомобилей, произведенных во Франции, варьируются от 16,2 до 27. Однако в тестовом наборе значения MPG для автомобилей, произведенных во Франции, варьируются от 28,1 до 40,9.

Постройте график невязок тестового набора. Хорошая модель обычно имеет невязки, рассеянные примерно симметрично около 0. Очистка шаблонов в невязках является признаком того, что вы можете улучшить свою модель.

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

График предполагает, что некоторые остаточные значения являются выбросами. Узнайте больше информации о наблюдении с наибольшей невязкой.

[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.3             85            NaN            80        France     1835     40.9

Наблюдение соответствует автомобилю, чья Horsepower значение отсутствует и чья страна источника - Франция, категория с небольшим количеством наблюдений.

См. также

| | | |