Создание модели нейронной сети прямой регрессии с полностью связанными слоями с помощью fitrnet. Используйте данные проверки для ранней остановки процесса обучения, чтобы предотвратить переоснащение модели. Затем используйте объектные функции модели для оценки ее производительности на тестовых данных.
Загрузить carbig набор данных, содержащий замеры автомобилей, сделанные в 1970-х и начале 1980-х годов.
load carbigПреобразовать Origin переменной к категориальной переменной. Затем создайте таблицу, содержащую переменные предиктора Acceleration, Displacementи так далее, а также переменная ответа MPG. Каждая строка содержит измерения для одного автомобиля.
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);Разбейте данные на обучающие, проверочные и тестовые наборы. Во-первых, зарезервировать примерно одну треть наблюдений для тестового набора. Затем разделите оставшиеся данные пополам, чтобы создать наборы обучения и проверки.
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
Обучение модели регрессионной нейронной сети с использованием обучающего набора. Укажите MPG столбец tblTrain в качестве переменной ответа и стандартизировать числовые предикторы. Оцените модель в каждой итерации с помощью набора проверки. Укажите отображение информации об обучении в каждой итерации с помощью Verbose аргумент «имя-значение». По умолчанию учебный процесс завершается раньше, если потеря проверки больше или равна минимальной потере проверки, рассчитанной на данный момент, шесть раз подряд. Чтобы изменить количество раз, когда допустимо, чтобы потеря проверки была больше или равна минимуму, укажите ValidationPatience аргумент «имя-значение».
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 71.063537| 22.623354| 6.466959| 0.001272| 72.648960| 0| | 2| 48.608700| 22.384995| 1.022929| 0.001808| 43.435698| 0| | 3| 30.584887| 13.433471| 0.537190| 0.000903| 29.134447| 0| | 4| 17.781636| 11.159801| 1.401355| 0.000461| 16.542207| 0| | 5| 13.075804| 4.605991| 0.419875| 0.000387| 12.946670| 0| | 6| 11.697936| 3.197944| 0.226945| 0.000543| 12.025502| 0| | 7| 9.494801| 2.269831| 0.751711| 0.000452| 12.596499| 1| | 8| 8.390979| 1.970589| 0.337301| 0.000398| 11.490990| 0| | 9| 6.853097| 1.029078| 0.866974| 0.000378| 9.449945| 0| | 10| 6.531678| 0.924820| 0.306913| 0.000429| 9.350721| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 6.152995| 1.872684| 0.457744| 0.000403| 9.223829| 0| | 12| 5.924852| 0.718386| 0.447879| 0.000402| 9.656166| 1| | 13| 5.792836| 0.500170| 0.216351| 0.000387| 9.733226| 2| | 14| 5.613473| 1.151197| 0.316828| 0.000531| 9.788646| 3| | 15| 5.415889| 1.513493| 0.327937| 0.000485| 9.607953| 4| | 16| 5.008195| 1.398069| 1.085660| 0.000430| 9.251971| 5| | 17| 5.004176| 2.070041| 0.890201| 0.000383| 8.719334| 0| | 18| 4.738386| 0.483667| 0.338897| 0.000374| 8.523728| 0| | 19| 4.680213| 0.437918| 0.107667| 0.000371| 8.369271| 0| | 20| 4.587350| 0.510639| 0.146276| 0.000385| 8.100236| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 4.479929| 0.565635| 0.228198| 0.000381| 8.062927| 0| | 22| 4.380618| 0.892717| 0.377776| 0.000554| 7.843234| 0| | 23| 4.189344| 0.403227| 0.362307| 0.000434| 7.834582| 0| | 24| 4.182775| 1.150234| 1.908768| 0.000408| 9.436226| 1| | 25| 3.985939| 0.908479| 0.518217| 0.000570| 8.973756| 2| | 26| 3.873835| 0.826655| 0.477740| 0.000505| 8.863599| 3| | 27| 3.830830| 0.331936| 0.220000| 0.000539| 8.574682| 4| | 28| 3.796605| 0.232756| 0.075643| 0.000492| 8.591758| 5| | 29| 3.706326| 0.470116| 0.249292| 0.000396| 8.517317| 6| |==========================================================================================|
Используйте информацию внутри TrainingHistory свойство объекта Mdl для проверки итерации, которая соответствует минимальному среднеквадратичному значению ошибки (MSE). Последняя возвращенная модель Mdl является моделью, обученной этой итерации.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 23
Оценка эффективности обучаемой модели Mdl на контрольном аппарате testTbl с помощью loss и predict функции объекта.
Вычислите среднеквадратичную ошибку (MSE) тестового набора. Меньшие значения MSE указывают на лучшую производительность.
mse = loss(Mdl,testTbl,"MPG")mse = 25.4145
Сравните прогнозируемые значения отклика тестового набора с истинными значениями отклика. Постройте график прогнозируемых миль на галлон (MPG) вдоль вертикальной оси и истинного MPG вдоль горизонтальной оси. Точки на опорной линии указывают правильные прогнозы. Хорошая модель производит прогнозы, которые разбросаны вблизи линии.
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")

Покадровые графики используются для сравнения распределения прогнозируемых и истинных значений MPG по странам происхождения. Создайте прямоугольные графики с помощью boxchart функция. Каждый прямоугольный график отображает медиану, нижний и верхний квартили, любые отклонения (вычисленные с использованием межквартильного диапазона), а также минимальное и максимальное значения, которые не являются отклонениями. В частности, линия внутри каждого бокса представляет собой медиану образца, а круглые маркеры указывают на отклонения.
Для каждой страны происхождения сравните график красного ящика (показывающий распределение прогнозируемых значений MPG) с графиком синего ящика (показывающим распределение истинных значений MPG). Аналогичные распределения для прогнозируемых и истинных значений MPG указывают на хорошие прогнозы.
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")

Для большинства стран прогнозируемые и истинные значения MPG имеют сходные распределения. Однако модель нейронной сети имеет тенденцию недооценивать значения MPG для автомобилей, изготовленных во Франции. Это несоответствие, возможно, связано с небольшим количеством французских автомобилей в учебно-испытательных комплектах.
Сравните диапазон значений MPG для французских автомобилей в учебных и тестовых наборах.
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
trainSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 16.2 27
Germany Germany 11 20 44.3
Italy Italy 1 37.3 37.3
Japan Japan 24 20 40.8
Sweden Sweden 3 19 21.6
USA USA 94 9 39
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... ["min","max"])
testSummary=6×4 table
Origin GroupCount min_MPG max_MPG
_______ __________ _______ _______
France France 3 28.1 40.9
Germany Germany 12 21.5 44
Italy Italy 3 28 30
Japan Japan 32 18 46.6
Sweden Sweden 3 17 24
USA USA 82 10 36.1
В обучающем наборе значения MPG для автомобилей, изготовленных во Франции, варьируются от 16,2 до 27. Однако в испытательном наборе значения MPG для автомобилей, изготовленных во Франции, варьируются от 28,1 до 40,9.
Постройте график остатков тестового набора. Хорошая модель обычно имеет остатки, разбросанные приблизительно симметрично вокруг 0. Четкие узоры в остатках являются признаком того, что модель можно улучшить.
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")

Сюжет предполагает, что некоторые остаточные значения являются отклонениями. Узнайте больше о наблюдении с наибольшим остатком.
[outlierResidual,outlierIdx] = max(residuals)
outlierResidual = 37.8727
outlierIdx = 113
testTbl(outlierIdx,:)
ans=1×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.3 85 NaN 80 France 1835 40.9
Наблюдение соответствует автомобилю, чей Horsepower значение отсутствует, и страной происхождения является Франция, категория с небольшим количеством наблюдений.
boxchart | fitrnet | loss | predict | RegressionNeuralNetwork