Создайте модель регрессионной нейронной сети с прямой связью с полносвязными слоями с помощью fitrnet
. Используйте данные валидации для раннего остановки процесса обучения, чтобы предотвратить сверхподбор кривой модели. Затем используйте функции объекта модели, чтобы оценить ее эффективность по тестовым данным.
Загрузите carbig
набор данных, содержащий измерения автомобилей 1970-х и начала 1980-х годов.
load carbig
Преобразуйте Origin
переменная - категориальная переменная. Затем создайте таблицу, содержащую переменные предиктора Acceleration
, Displacement
, и так далее, а также переменная отклика MPG
. Каждая строка содержит измерения для одного автомобиля.
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);
Разделите данные на наборы для обучения, валидации и тестирования. Во-первых, резервируйте приблизительно одну треть наблюдений для тестового набора. Затем разделите оставшиеся данные пополам, чтобы создать наборы обучения и валидации.
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
Обучите регрессионную модель нейронной сети при помощи набора обучающих данных. Задайте MPG
столбец tblTrain
как переменная отклика и стандартизируйте числовые предикторы. Оцените модель при каждой итерации с помощью набора валидации. Задайте, чтобы отобразить информацию о обучении при каждой итерации с помощью Verbose
аргумент имя-значение. По умолчанию процесс обучения заканчивается раньше, если потеря валидации больше или равна минимальной потере валидации, вычисленной до сих пор, шесть раз подряд. Чтобы изменить количество раз, в течение которого потеря валидации может быть больше или равной минимуму, задайте ValidationPatience
аргумент имя-значение.
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 71.063537| 22.623354| 6.466959| 0.001272| 72.648960| 0| | 2| 48.608700| 22.384995| 1.022929| 0.001808| 43.435698| 0| | 3| 30.584887| 13.433471| 0.537190| 0.000903| 29.134447| 0| | 4| 17.781636| 11.159801| 1.401355| 0.000461| 16.542207| 0| | 5| 13.075804| 4.605991| 0.419875| 0.000387| 12.946670| 0| | 6| 11.697936| 3.197944| 0.226945| 0.000543| 12.025502| 0| | 7| 9.494801| 2.269831| 0.751711| 0.000452| 12.596499| 1| | 8| 8.390979| 1.970589| 0.337301| 0.000398| 11.490990| 0| | 9| 6.853097| 1.029078| 0.866974| 0.000378| 9.449945| 0| | 10| 6.531678| 0.924820| 0.306913| 0.000429| 9.350721| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 6.152995| 1.872684| 0.457744| 0.000403| 9.223829| 0| | 12| 5.924852| 0.718386| 0.447879| 0.000402| 9.656166| 1| | 13| 5.792836| 0.500170| 0.216351| 0.000387| 9.733226| 2| | 14| 5.613473| 1.151197| 0.316828| 0.000531| 9.788646| 3| | 15| 5.415889| 1.513493| 0.327937| 0.000485| 9.607953| 4| | 16| 5.008195| 1.398069| 1.085660| 0.000430| 9.251971| 5| | 17| 5.004176| 2.070041| 0.890201| 0.000383| 8.719334| 0| | 18| 4.738386| 0.483667| 0.338897| 0.000374| 8.523728| 0| | 19| 4.680213| 0.437918| 0.107667| 0.000371| 8.369271| 0| | 20| 4.587350| 0.510639| 0.146276| 0.000385| 8.100236| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 4.479929| 0.565635| 0.228198| 0.000381| 8.062927| 0| | 22| 4.380618| 0.892717| 0.377776| 0.000554| 7.843234| 0| | 23| 4.189344| 0.403227| 0.362307| 0.000434| 7.834582| 0| | 24| 4.182775| 1.150234| 1.908768| 0.000408| 9.436226| 1| | 25| 3.985939| 0.908479| 0.518217| 0.000570| 8.973756| 2| | 26| 3.873835| 0.826655| 0.477740| 0.000505| 8.863599| 3| | 27| 3.830830| 0.331936| 0.220000| 0.000539| 8.574682| 4| | 28| 3.796605| 0.232756| 0.075643| 0.000492| 8.591758| 5| | 29| 3.706326| 0.470116| 0.249292| 0.000396| 8.517317| 6| |==========================================================================================|
Используйте информацию внутри TrainingHistory
свойство объекта Mdl
проверить итерацию, которая соответствует минимальной средней квадратичной невязке валидации (MSE). Конечная возвращенная модель Mdl
- модель, обученная на этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 23
Оцените эффективность обученной модели Mdl
на тестовом аппарате testTbl
при помощи loss
и predict
функции объекта.
Вычислите среднюю квадратичную невязку тестового набора (MSE). Меньшие значения MSE указывают на лучшую эффективность.
mse = loss(Mdl,testTbl,"MPG")
mse = 25.4145
Сравните предсказанные значения отклика набора тестов с истинными значениями отклика. Постройте график прогнозируемых миль на галлон (MPG) вдоль вертикальной оси и истинного MPG вдоль горизонтальной оси. Точки на опорной линии указывают на правильные предсказания. Хорошая модель производит предсказания, которые разбросаны рядом с линией.
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")
Используйте прямоугольные графики, чтобы сравнить распределение предсказанных и истинных значений MPG по странам источника. Создайте прямоугольные графики при помощи boxchart
функция. На каждом прямоугольном графике отображаются медиана, нижний и верхний квартили, любые выбросы (вычисленные с помощью межквартильной области значений) и минимальное и максимальное значения, которые не являются выбросами. В частности, линия внутри каждой коробки является медианой выборки, а круговые маркеры указывают на выбросы.
Для каждой страны источника сравните красный прямоугольный график (показывающий распределение прогнозируемых значений MPG) с синим прямоугольным графиком (показывающим распределение истинных значений MPG). Подобные распределения для предсказанных и истинных значений MPG указывают на хорошие предсказания.
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")
Для большинства стран прогнозируемые и истинные значения MPG имеют сходные распределения. Однако модель нейронной сети имеет тенденцию занижать значения MPG для автомобилей, произведенных во Франции. Это расхождение, возможно, связано с небольшим количеством французских автомобилей в обучающих и тестовых наборах.
Сравните область значений значений MPG для французских автомобилей в обучающих и тестовых наборах.
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
trainSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 16.2 27
Germany Germany 11 20 44.3
Italy Italy 1 37.3 37.3
Japan Japan 24 20 40.8
Sweden Sweden 3 19 21.6
USA USA 94 9 39
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
testSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 28.1 40.9
Germany Germany 12 21.5 44
Italy Italy 3 28 30
Japan Japan 32 18 46.6
Sweden Sweden 3 17 24
USA USA 82 10 36.1
В наборе обучающих данных значения MPG для автомобилей, произведенных во Франции, варьируются от 16,2 до 27. Однако в тестовом наборе значения MPG для автомобилей, произведенных во Франции, варьируются от 28,1 до 40,9.
Постройте график невязок тестового набора. Хорошая модель обычно имеет невязки, рассеянные примерно симметрично около 0. Очистка шаблонов в невязках является признаком того, что вы можете улучшить свою модель.
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")
График предполагает, что некоторые остаточные значения являются выбросами. Узнайте больше информации о наблюдении с наибольшей невязкой.
[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.3 85 NaN 80 France 1835 40.9
Наблюдение соответствует автомобилю, чья Horsepower
значение отсутствует и чья страна источника - Франция, категория с небольшим количеством наблюдений.
boxchart
| fitrnet
| loss
| predict
| RegressionNeuralNetwork