Оцените эффективность нейронной сети регрессии

Создайте модель нейронной сети регрессии прямого распространения с полносвязными слоями с помощью fitrnet. Используйте данные о валидации для ранней остановки учебного процесса, чтобы предотвратить сверхподбор кривой модели. Затем используйте объектные функции модели, чтобы оценить ее эффективность на тестовых данных.

Загрузка демонстрационных данных

Загрузите carbig набор данных, который содержит измерения автомобилей, сделанных в 1970-х и в начале 1980-х.

load carbig

Преобразуйте Origin переменная к категориальной переменной. Затем составьте таблицу, содержащую переменные предикторы Acceleration, Displacement, и так далее, а также переменная отклика MPG. Каждая строка содержит измерения для одного автомобиля.

Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
    Model_Year,Origin,Weight,MPG);

Данные о разделе

Разделите данные в обучение, валидацию и наборы тестов. Во-первых, зарезервируйте приблизительно одну треть наблюдений для набора тестов. Затем разделите остающиеся данные в половине, чтобы создать наборы обучения и валидации.

rng("default") % For reproducibility of the data partitions
cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3);
testTbl = Tbl(test(cvp1),:);
remainingTbl = Tbl(training(cvp1),:);

cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2);
validationTbl = remainingTbl(test(cvp2),:);
trainTbl = remainingTbl(training(cvp2),:);

Обучите нейронную сеть

Обучите модель нейронной сети регрессии при помощи набора обучающих данных. Задайте MPG столбец tblTrain как переменная отклика, и стандартизируют числовые предикторы. Оцените модель в каждой итерации при помощи набора валидации. Задайте, чтобы отобразить учебную информацию в каждой итерации при помощи Verbose аргумент значения имени. По умолчанию учебный процесс заканчивается рано, если потеря валидации больше или равна минимальной потере валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить число раз, потере валидации позволяют быть больше или быть равной минимуму, задать ValidationPatience аргумент значения имени.

Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ...
    "ValidationData",validationTbl, ...
    "Verbose",1);
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|           1|   71.063537|   22.623354|    6.466959|    0.001272|   72.648960|           0|
|           2|   48.608700|   22.384995|    1.022929|    0.001808|   43.435698|           0|
|           3|   30.584887|   13.433471|    0.537190|    0.000903|   29.134447|           0|
|           4|   17.781636|   11.159801|    1.401355|    0.000461|   16.542207|           0|
|           5|   13.075804|    4.605991|    0.419875|    0.000387|   12.946670|           0|
|           6|   11.697936|    3.197944|    0.226945|    0.000543|   12.025502|           0|
|           7|    9.494801|    2.269831|    0.751711|    0.000452|   12.596499|           1|
|           8|    8.390979|    1.970589|    0.337301|    0.000398|   11.490990|           0|
|           9|    6.853097|    1.029078|    0.866974|    0.000378|    9.449945|           0|
|          10|    6.531678|    0.924820|    0.306913|    0.000429|    9.350721|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          11|    6.152995|    1.872684|    0.457744|    0.000403|    9.223829|           0|
|          12|    5.924852|    0.718386|    0.447879|    0.000402|    9.656166|           1|
|          13|    5.792836|    0.500170|    0.216351|    0.000387|    9.733226|           2|
|          14|    5.613473|    1.151197|    0.316828|    0.000531|    9.788646|           3|
|          15|    5.415889|    1.513493|    0.327937|    0.000485|    9.607953|           4|
|          16|    5.008195|    1.398069|    1.085660|    0.000430|    9.251971|           5|
|          17|    5.004176|    2.070041|    0.890201|    0.000383|    8.719334|           0|
|          18|    4.738386|    0.483667|    0.338897|    0.000374|    8.523728|           0|
|          19|    4.680213|    0.437918|    0.107667|    0.000371|    8.369271|           0|
|          20|    4.587350|    0.510639|    0.146276|    0.000385|    8.100236|           0|
|==========================================================================================|
| Iteration  | Train Loss | Gradient   | Step       | Iteration  | Validation | Validation |
|            |            |            |            | Time (sec) | Loss       | Checks     |
|==========================================================================================|
|          21|    4.479929|    0.565635|    0.228198|    0.000381|    8.062927|           0|
|          22|    4.380618|    0.892717|    0.377776|    0.000554|    7.843234|           0|
|          23|    4.189344|    0.403227|    0.362307|    0.000434|    7.834582|           0|
|          24|    4.182775|    1.150234|    1.908768|    0.000408|    9.436226|           1|
|          25|    3.985939|    0.908479|    0.518217|    0.000570|    8.973756|           2|
|          26|    3.873835|    0.826655|    0.477740|    0.000505|    8.863599|           3|
|          27|    3.830830|    0.331936|    0.220000|    0.000539|    8.574682|           4|
|          28|    3.796605|    0.232756|    0.075643|    0.000492|    8.591758|           5|
|          29|    3.706326|    0.470116|    0.249292|    0.000396|    8.517317|           6|
|==========================================================================================|

Используйте информацию в TrainingHistory свойство объекта Mdl проверять итерацию, которая соответствует минимальной среднеквадратической ошибке (MSE) валидации. Финал возвратил модель Mdl модель, обученная в этой итерации.

iteration = Mdl.TrainingHistory.Iteration;
valLosses = Mdl.TrainingHistory.ValidationLoss;
[~,minIdx] = min(valLosses);
iteration(minIdx)
ans = 23

Оцените эффективность набора тестов

Оцените эффективность обученной модели Mdl на наборе тестов testTbl при помощи loss и predict функции объекта.

Вычислите среднеквадратическую ошибку (MSE) набора тестов. Меньшие значения MSE указывают на лучшую эффективность.

mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145

Сравните предсказанные значения отклика набора тестов с истинными значениями отклика. Постройте предсказанные мили на галлон (MPG) вдоль вертикальной оси и истинный MPG вдоль горизонтальной оси. Точки на ссылочной линии указывают на правильные предсказания. Хорошая модель производит предсказания, которые рассеиваются около линии.

predictedY = predict(Mdl,testTbl);

plot(testTbl.MPG,predictedY,".")
hold on
plot(testTbl.MPG,testTbl.MPG)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("Predicted Miles Per Gallon (MPG)")

Используйте диаграммы сравнить распределение предсказанных и истинных значений MPG страной происхождения. Создайте диаграммы при помощи boxchart функция. Каждая диаграмма отображает медиану, более низкие и верхние квартили, любые выбросы (вычисленное использование межквартильного размаха), и минимальные и максимальные значения, которые не являются выбросами. В частности, линия в каждом поле является демонстрационной медианой, и круговые маркеры указывают на выбросы.

Для каждой страны происхождения сравните красную диаграмму (показав распределение предсказанных значений MPG) к синей диаграмме (показав распределение истинных значений MPG). Подобные распределения для предсказанных и истинных значений MPG указывают на хорошие предсказания.

boxchart(testTbl.Origin,testTbl.MPG)
hold on
boxchart(testTbl.Origin,predictedY)
hold off
legend(["True MPG","Predicted MPG"])
xlabel("Country of Origin")
ylabel("Miles Per Gallon (MPG)")

Для большинства стран предсказанные и истинные значения MPG имеют подобные распределения. Однако модель нейронной сети имеет тенденцию недооценивать значения MPG для автомобилей, сделанных во Франции. Это несоответствие происходит возможно из-за небольшого количества французских автомобилей в наборах обучающих данных и наборах тестов.

Сравните область значений значений MPG для французских автомобилей в наборах обучающих данных и наборах тестов.

trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
trainSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         16.2         27  
    Germany    Germany        11           20       44.3  
    Italy      Italy           1         37.3       37.3  
    Japan      Japan          24           20       40.8  
    Sweden     Sweden          3           19       21.6  
    USA        USA            94            9         39  

testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ...
    ["min","max"])
testSummary=6×4 table
               Origin     GroupCount    min_MPG    max_MPG
               _______    __________    _______    _______

    France     France          3         28.1       40.9  
    Germany    Germany        12         21.5         44  
    Italy      Italy           3           28         30  
    Japan      Japan          32           18       46.6  
    Sweden     Sweden          3           17         24  
    USA        USA            82           10       36.1  

В наборе обучающих данных значения MPG для автомобилей, сделанных во Франции, лежат в диапазоне от 16,2 до 27. Однако в наборе тестов, значения MPG для автомобилей, сделанных во Франции, лежат в диапазоне от 28,1 до 40,9.

Постройте остаточные значения набора тестов. Хорошей модели обычно рассеивали остаточные значения примерно симметрично приблизительно 0. Очиститесь шаблоны в остаточных значениях являются знаком, что можно улучшить модель.

residuals = testTbl.MPG - predictedY;
plot(testTbl.MPG,residuals,".")
hold on
yline(0)
hold off
xlabel("True Miles Per Gallon (MPG)")
ylabel("MPG Residuals")

График предполагает, что некоторые остаточные значения являются выбросами. Узнайте больше информации о наблюдении с самой большой невязкой.

[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
    Acceleration    Displacement    Horsepower    Model_Year    Origin    Weight    MPG 
    ____________    ____________    __________    __________    ______    ______    ____

        17.3             85            NaN            80        France     1835     40.9

Наблюдение соответствует автомобилю чей Horsepower значение отсутствует и чьей страной происхождения является Франция, категория с немногими наблюдениями.

Смотрите также

| | | |