Wat zijn beslissingsbomen?
Decision Trees zijn een veelzijdig Machine Learning-algoritme dat zowel classificatie- als regressietaken kan uitvoeren. Het zijn zeer krachtige algoritmen die in staat zijn om complexe datasets aan te passen. Bovendien zijn beslissingsbomen fundamentele componenten van willekeurige bossen, die tot de krachtigste Machine Learning-algoritmen behoren die momenteel beschikbaar zijn.
Trainen en visualiseren van beslissingsbomen
Om uw eerste beslissingsboom in het R-voorbeeld te bouwen, gaan we als volgt te werk in deze zelfstudie over de beslissingsboom:
- Stap 1: Importeer de gegevens
- Stap 2: Maak de dataset schoon
- Stap 3: Maak een trein- / testset
- Stap 4: Bouw het model
- Stap 5: maak een voorspelling
- Stap 6: meet de prestaties
- Stap 7: Stem de hyperparameters af
Stap 1) Importeer de gegevens
Als je nieuwsgierig bent naar het lot van de titanic, kun je deze video op Youtube bekijken. Het doel van deze dataset is om te voorspellen welke mensen meer kans hebben om te overleven na de botsing met de ijsberg. De dataset bevat 13 variabelen en 1309 waarnemingen. De dataset is geordend op de variabele X.
set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)
Uitgang:
## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)
Uitgang:
## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S
Aan de hand van de kop- en staartuitvoer kunt u zien dat de gegevens niet door elkaar worden gehaald. Dit is een groot probleem! Wanneer u uw gegevens splitst tussen een treinset en een testset, selecteert u alleen de passagier uit klasse 1 en 2 (geen passagier uit klasse 3 staat in de top 80 procent van de waarnemingen), wat betekent dat het algoritme nooit de kenmerken van passagier van klasse 3. Deze fout leidt tot een slechte voorspelling.
Om dit probleem op te lossen, kunt u de functie sample () gebruiken.
shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)
Beslisboom R-code Toelichting
- sample (1: nrow (titanic)): Genereer een willekeurige indexlijst van 1 tot 1309 (dwz het maximale aantal rijen).
Uitgang:
## [1] 288 874 1078 633 887 992
U zult deze index gebruiken om de titanic dataset door elkaar te halen.
titanic <- titanic[shuffle_index, ]head(titanic)
Uitgang:
## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C
Stap 2) Maak de dataset schoon
De structuur van de gegevens laat zien dat sommige variabelen NA's hebben. Het opschonen van gegevens moet als volgt worden gedaan
- Drop variabelen home.dest, hut, naam, X en ticket
- Maak factorvariabelen voor pclass en overleefde
- Laat de NA vallen
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)
Code Verklaring
- select (-c (home.dest, cabin, name, X, ticket)): laat onnodige variabelen vallen
- pclass = factor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): Voeg label toe aan de variabele pclass. 1 wordt Upper, 2 wordt MIddle en 3 wordt lager
- factor (overleefd, niveaus = c (0,1), labels = c ('Nee', 'Ja')): Voeg een label toe aan de variabele overleefd. 1 wordt Nee en 2 wordt Ja
- na.omit (): Verwijder de NA-waarnemingen
Uitgang:
## Observations: 1,045## Variables: 8## $ pclassUpper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex male, male, female, female, male, male, female, male… ## $ age 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C…
Stap 3) Maak een trein- / testset
Voordat u uw model gaat trainen, moet u twee stappen uitvoeren:
- Maak een trein en testset: je traint het model op de treinset en test de voorspelling op de testset (dwz ongeziene gegevens)
- Installeer rpart.plot vanaf de console
Het is gebruikelijk om de gegevens 80/20 te splitsen, 80 procent van de gegevens dient om het model te trainen en 20 procent om voorspellingen te doen. U moet twee afzonderlijke dataframes maken. U wilt de testset pas aanraken als u klaar bent met het bouwen van uw model. U kunt een functienaam create_train_test () maken waaraan drie argumenten moeten doorgegeven worden.
create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}
Code Verklaring
- function (data, size = 0.8, train = TRUE): Voeg de argumenten in de functie toe
- n_row = nrow (data): tel het aantal rijen in de dataset
- total_row = size * n_row: Retourneer de nde rij om de treinset te construeren
- train_sample <- 1: total_row: Selecteer de eerste rij tot de n-de rijen
- if (train == TRUE) {} else {}: Als voorwaarde op true wordt ingesteld, retourneert u de treinset, anders de testset.
U kunt uw functie testen en de afmeting controleren.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)
Uitgang:
## [1] 836 8
dim(data_test)
Uitgang:
## [1] 209 8
De treindataset heeft 1046 rijen, terwijl de testdataset 262 rijen heeft.
U gebruikt de functie prop.table () in combinatie met table () om te controleren of het randomiseringsproces correct is.
prop.table(table(data_train$survived))
Uitgang:
#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Uitgang:
#### No Yes## 0.5789474 0.4210526
In beide datasets is het aantal overlevenden hetzelfde, ongeveer 40 procent.
Installeer rpart.plot
rpart.plot is niet beschikbaar in conda-bibliotheken. U kunt het vanaf de console installeren:
install.packages("rpart.plot")
Stap 4) Bouw het model
U bent klaar om het model te bouwen. De syntaxis voor de Rpart-beslissingsboomfunctie is:
rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree
U gebruikt de klassemethode omdat u een klas voorspelt.
library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106
Code Verklaring
- rpart (): functie die bij het model past. De argumenten zijn:
- overleefde ~ .: Formule van de beslissingsbomen
- data = data_train: Dataset
- method = 'class': Fit een binair model
- rpart.plot (fit, extra = 106): plot de boom. De extra functies zijn ingesteld op 101 om de waarschijnlijkheid van de 2e klas weer te geven (handig voor binaire antwoorden). U kunt het vignet raadplegen voor meer informatie over de andere keuzes.
Uitgang:
Je begint bij het root-knooppunt (diepte 0 over 3, de bovenkant van de grafiek):
- Bovenaan is het de algemene overlevingskans. Het toont het aandeel passagiers dat de crash heeft overleefd. 41 procent van de passagiers heeft het overleefd.
- Dit knooppunt vraagt of het geslacht van de passagier mannelijk is. Zo ja, dan ga je naar het linker onderliggende knooppunt van de root (diepte 2). 63 procent zijn mannen met een overlevingskans van 21 procent.
- In het tweede knooppunt vraag je of de mannelijke passagier ouder is dan 3,5 jaar. Zo ja, dan is de overlevingskans 19 procent.
- Je blijft zo doorgaan om te begrijpen welke kenmerken van invloed zijn op de overlevingskans.
Merk op dat een van de vele kwaliteiten van Decision Trees is dat ze zeer weinig gegevensvoorbereiding vereisen. Ze hebben met name geen schaalvergroting of centrering van functies nodig.
Standaard gebruikt de functie rpart () de Gini- onzuiverheidsmaat om de noot te splitsen. Hoe hoger de Gini-coëfficiënt, hoe meer verschillende instanties binnen het knooppunt.
Stap 5) Maak een voorspelling
U kunt uw testdataset voorspellen. Om een voorspelling te doen, kunt u de predict () functie gebruiken. De basissyntaxis van voorspelling voor R-beslissingsboom is:
predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level
Met de testset wil je voorspellen welke passagiers het waarschijnlijkst zullen overleven na de botsing. Het betekent dat u onder die 209 passagiers weet welke het zal overleven of niet.
predict_unseen <-predict(fit, data_test, type = 'class')
Code Verklaring
- voorspellen (fit, data_test, type = 'class'): Voorspel de klasse (0/1) van de testset
Testen van de passagier die het niet heeft gehaald en degenen die het wel hebben gehaald.
table_mat <- table(data_test$survived, predict_unseen)table_mat
Code Verklaring
- table (data_test $ survived, predict_unseen): maak een tabel om te tellen hoeveel passagiers zijn geclassificeerd als overlevenden en overleden, vergelijk met de juiste beslissingsboomclassificatie in R
Uitgang:
## predict_unseen## No Yes## No 106 15## Yes 30 58
Het model voorspelde correct 106 dode passagiers, maar classificeerde 15 overlevenden als dood. Naar analogie classificeerde het model 30 passagiers ten onrechte als overlevenden terwijl ze dood bleken te zijn.
Stap 6) Meet de prestaties
U kunt een nauwkeurigheidsmeting voor een classificatietaak berekenen met de verwarringmatrix :
De verwarringmatrix is een betere keuze om de classificatieprestaties te evalueren. Het algemene idee is om te tellen hoe vaak True-instanties worden geclassificeerd als False.
Elke rij in een verwarringmatrix vertegenwoordigt een werkelijk doel, terwijl elke kolom een voorspeld doel vertegenwoordigt. De eerste rij van deze matrix houdt rekening met dode passagiers (de False-klasse): 106 werden correct geclassificeerd als dood ( True negatief ), terwijl de overgebleven ten onrechte werd geclassificeerd als een overlevende ( False-positief ). De tweede rij beschouwt de overlevenden, de positieve klasse was 58 ( True-positief ), terwijl het True-negatief 30 was.
U kunt de nauwkeurigheidstest uit de verwarringmatrix berekenen :
Het is de verhouding van echt positief en echt negatief over de som van de matrix. Met R kunt u als volgt coderen:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Code Verklaring
- sum (diag (table_mat)): Som van de diagonaal
- sum (table_mat): Som van de matrix.
U kunt de nauwkeurigheid van de testset afdrukken:
print(paste('Accuracy for test', accuracy_Test))
Uitgang:
## [1] "Accuracy for test 0.784688995215311"
Je hebt een score van 78 procent voor de testset. U kunt dezelfde oefening repliceren met de trainingsgegevensset.
Stap 7) Stem de hyperparameters af
De beslissingsboom in R heeft verschillende parameters die aspecten van de pasvorm bepalen. In de rpart-beslissingsboombibliotheek kunt u de parameters besturen met de functie rpart.control (). In de volgende code introduceert u de parameters die u gaat stemmen. U kunt naar het vignet verwijzen voor andere parameters.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
We gaan als volgt te werk:
- Construct functie om nauwkeurigheid te retourneren
- Stem de maximale diepte af
- Stem het minimumaantal monsters af dat een knooppunt moet hebben voordat het kan worden gesplitst
- Stem het minimumaantal monsters af dat een bladknooppunt moet hebben
U kunt een functie schrijven om de nauwkeurigheid weer te geven. U verpakt gewoon de code die u eerder gebruikte:
- voorspellen: predict_unseen <- voorspellen (fit, data_test, type = 'class')
- Produce table: table_mat <- table (data_test $ survived, predict_unseen)
- Bereken nauwkeurigheid: nauwkeurigheid_Test <- som (diag (tabel_mat)) / som (tabel_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}
U kunt proberen de parameters af te stemmen en kijken of u het model kunt verbeteren ten opzichte van de standaardwaarde. Ter herinnering: u moet een nauwkeurigheid krijgen die hoger is dan 0,78
control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)
Uitgang:
## [1] 0.7990431
Met de volgende parameter:
minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0
Je krijgt betere prestaties dan het vorige model. Gefeliciteerd!
Overzicht
We kunnen de functies samenvatten om een beslissingsboomalgoritme te trainen in R
Bibliotheek |
Objectief |
functie |
klasse |
parameters |
details |
---|---|---|---|---|---|
rpart |
Treinclassificatieboom in R |
rpart () |
klasse |
formule, df, methode | |
rpart |
Train de regressieboom |
rpart () |
anova |
formule, df, methode | |
rpart |
Zet de bomen uit |
rpart.plot () |
getailleerd model | ||
baseren |
voorspellen |
voorspellen() |
klasse |
getailleerd model, type | |
baseren |
voorspellen |
voorspellen() |
waarschijnlijk |
getailleerd model, type | |
baseren |
voorspellen |
voorspellen() |
vector |
getailleerd model, type | |
rpart |
Controleparameters |
rpart.control () |
min. split |
Stel het minimum aantal waarnemingen in het knooppunt in voordat het algoritme een splitsing uitvoert |
|
minbucket |
Stel het minimum aantal waarnemingen in de laatste noot, dwz het blad |
||||
maximale diepte |
Stel de maximale diepte in van elk knooppunt van de uiteindelijke boom. De root node wordt behandeld op een diepte van 0 |
||||
rpart |
Train model met besturingsparameter |
rpart () |
formule, df, methode, controle |
Opmerking: train het model op basis van trainingsgegevens en test de prestaties op een ongeziene dataset, dat wil zeggen een testset.