Создайте модель нейронной сети регрессии прямого распространения с полносвязными слоями с помощью fitrnet
. Используйте данные о валидации для ранней остановки учебного процесса, чтобы предотвратить сверхподбор кривой модели. Затем используйте объектные функции модели, чтобы оценить ее эффективность на тестовых данных.
Загрузите carbig
набор данных, который содержит измерения автомобилей, сделанных в 1970-х и в начале 1980-х.
load carbig
Преобразуйте Origin
переменная к категориальной переменной. Затем составьте таблицу, содержащую переменные предикторы Acceleration
, Displacement
, и так далее, а также переменная отклика MPG
. Каждая строка содержит измерения для одного автомобиля.
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);
Разделите данные в обучение, валидацию и наборы тестов. Во-первых, зарезервируйте приблизительно одну треть наблюдений для набора тестов. Затем разделите остающиеся данные в половине, чтобы создать наборы обучения и валидации.
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
Обучите модель нейронной сети регрессии при помощи набора обучающих данных. Задайте MPG
столбец tblTrain
как переменная отклика, и стандартизируют числовые предикторы. Оцените модель в каждой итерации при помощи набора валидации. Задайте, чтобы отобразить учебную информацию в каждой итерации при помощи Verbose
аргумент значения имени. По умолчанию учебный процесс заканчивается рано, если потеря валидации больше или равна минимальной потере валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить число раз, потере валидации позволяют быть больше или быть равной минимуму, задать ValidationPatience
аргумент значения имени.
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 71.063537| 22.623354| 6.466959| 0.001272| 72.648960| 0| | 2| 48.608700| 22.384995| 1.022929| 0.001808| 43.435698| 0| | 3| 30.584887| 13.433471| 0.537190| 0.000903| 29.134447| 0| | 4| 17.781636| 11.159801| 1.401355| 0.000461| 16.542207| 0| | 5| 13.075804| 4.605991| 0.419875| 0.000387| 12.946670| 0| | 6| 11.697936| 3.197944| 0.226945| 0.000543| 12.025502| 0| | 7| 9.494801| 2.269831| 0.751711| 0.000452| 12.596499| 1| | 8| 8.390979| 1.970589| 0.337301| 0.000398| 11.490990| 0| | 9| 6.853097| 1.029078| 0.866974| 0.000378| 9.449945| 0| | 10| 6.531678| 0.924820| 0.306913| 0.000429| 9.350721| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 6.152995| 1.872684| 0.457744| 0.000403| 9.223829| 0| | 12| 5.924852| 0.718386| 0.447879| 0.000402| 9.656166| 1| | 13| 5.792836| 0.500170| 0.216351| 0.000387| 9.733226| 2| | 14| 5.613473| 1.151197| 0.316828| 0.000531| 9.788646| 3| | 15| 5.415889| 1.513493| 0.327937| 0.000485| 9.607953| 4| | 16| 5.008195| 1.398069| 1.085660| 0.000430| 9.251971| 5| | 17| 5.004176| 2.070041| 0.890201| 0.000383| 8.719334| 0| | 18| 4.738386| 0.483667| 0.338897| 0.000374| 8.523728| 0| | 19| 4.680213| 0.437918| 0.107667| 0.000371| 8.369271| 0| | 20| 4.587350| 0.510639| 0.146276| 0.000385| 8.100236| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 4.479929| 0.565635| 0.228198| 0.000381| 8.062927| 0| | 22| 4.380618| 0.892717| 0.377776| 0.000554| 7.843234| 0| | 23| 4.189344| 0.403227| 0.362307| 0.000434| 7.834582| 0| | 24| 4.182775| 1.150234| 1.908768| 0.000408| 9.436226| 1| | 25| 3.985939| 0.908479| 0.518217| 0.000570| 8.973756| 2| | 26| 3.873835| 0.826655| 0.477740| 0.000505| 8.863599| 3| | 27| 3.830830| 0.331936| 0.220000| 0.000539| 8.574682| 4| | 28| 3.796605| 0.232756| 0.075643| 0.000492| 8.591758| 5| | 29| 3.706326| 0.470116| 0.249292| 0.000396| 8.517317| 6| |==========================================================================================|
Используйте информацию в TrainingHistory
свойство объекта Mdl
проверять итерацию, которая соответствует минимальной среднеквадратической ошибке (MSE) валидации. Финал возвратил модель Mdl
модель, обученная в этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 23
Оцените эффективность обученной модели Mdl
на наборе тестов testTbl
при помощи loss
и predict
функции объекта.
Вычислите среднеквадратическую ошибку (MSE) набора тестов. Меньшие значения MSE указывают на лучшую эффективность.
mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145
Сравните предсказанные значения отклика набора тестов с истинными значениями отклика. Постройте предсказанные мили на галлон (MPG) вдоль вертикальной оси и истинный MPG вдоль горизонтальной оси. Точки на ссылочной линии указывают на правильные предсказания. Хорошая модель производит предсказания, которые рассеиваются около линии.
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")
Используйте диаграммы сравнить распределение предсказанных и истинных значений MPG страной происхождения. Создайте диаграммы при помощи boxchart
функция. Каждая диаграмма отображает медиану, более низкие и верхние квартили, любые выбросы (вычисленное использование межквартильного размаха), и минимальные и максимальные значения, которые не являются выбросами. В частности, линия в каждом поле является демонстрационной медианой, и круговые маркеры указывают на выбросы.
Для каждой страны происхождения сравните красную диаграмму (показав распределение предсказанных значений MPG) к синей диаграмме (показав распределение истинных значений MPG). Подобные распределения для предсказанных и истинных значений MPG указывают на хорошие предсказания.
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")
Для большинства стран предсказанные и истинные значения MPG имеют подобные распределения. Однако модель нейронной сети имеет тенденцию недооценивать значения MPG для автомобилей, сделанных во Франции. Это несоответствие происходит возможно из-за небольшого количества французских автомобилей в наборах обучающих данных и наборах тестов.
Сравните область значений значений MPG для французских автомобилей в наборах обучающих данных и наборах тестов.
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
trainSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 16.2 27
Germany Germany 11 20 44.3
Italy Italy 1 37.3 37.3
Japan Japan 24 20 40.8
Sweden Sweden 3 19 21.6
USA USA 94 9 39
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
testSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 28.1 40.9
Germany Germany 12 21.5 44
Italy Italy 3 28 30
Japan Japan 32 18 46.6
Sweden Sweden 3 17 24
USA USA 82 10 36.1
В наборе обучающих данных значения MPG для автомобилей, сделанных во Франции, лежат в диапазоне от 16,2 до 27. Однако в наборе тестов, значения MPG для автомобилей, сделанных во Франции, лежат в диапазоне от 28,1 до 40,9.
Постройте остаточные значения набора тестов. Хорошей модели обычно рассеивали остаточные значения примерно симметрично приблизительно 0. Очиститесь шаблоны в остаточных значениях являются знаком, что можно улучшить модель.
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")
График предполагает, что некоторые остаточные значения являются выбросами. Узнайте больше информации о наблюдении с самой большой невязкой.
[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.3 85 NaN 80 France 1835 40.9
Наблюдение соответствует автомобилю чей Horsepower
значение отсутствует и чьей страной происхождения является Франция, категория с немногими наблюдениями.
boxchart
| fitrnet
| loss
| predict
| RegressionNeuralNetwork