Vad är beslutsträd?
Beslutsträd är mångsidig maskininlärningsalgoritm som kan utföra både klassificerings- och regressionsuppgifter. De är mycket kraftfulla algoritmer som kan passa komplexa datamängder. Dessutom är beslutsträd grundläggande komponenter i slumpmässiga skogar, som är bland de mest potenta maskininlärningsalgoritmerna som finns idag.
Utbildning och visualisering av beslutsträd
För att bygga ditt första beslutsträd i R-exemplet fortsätter vi enligt följande i denna beslutsträdhandledning:
- Steg 1: Importera data
- Steg 2: Rengör datasetet
- Steg 3: Skapa tåg / testuppsättning
- Steg 4: Bygg modellen
- Steg 5: Gör förutsägelser
- Steg 6: Mät prestanda
- Steg 7: Ställ in hyperparametrarna
Steg 1) Importera data
Om du är nyfiken på ödet för titanic kan du titta på den här videon på Youtube. Syftet med denna dataset är att förutsäga vilka människor som är mer benägna att överleva efter kollisionen med isberget. Datauppsättningen innehåller 13 variabler och 1309 observationer. Dataset ordnas av variabeln X.
set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)
Produktion:
## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)
Produktion:
## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S
Från huvud- och svansutgången kan du märka att data inte blandas. Det här är en stor fråga! När du delar dina data mellan en tåguppsättning och testuppsättning, väljer du bara passagerare från klass 1 och 2 (Ingen passagerare från klass 3 är i de översta 80 procenten av observationerna), vilket innebär att algoritmen aldrig kommer att se egenskaper hos passagerare i klass 3. Detta misstag leder till dålig förutsägelse.
För att lösa problemet kan du använda funktionsexemplet ().
shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)
Beslutsträd R-kod Förklaring
- sample (1: nrow (titanic)): Skapa en slumpmässig lista med index från 1 till 1309 (dvs. det maximala antalet rader).
Produktion:
## [1] 288 874 1078 633 887 992
Du kommer att använda detta index för att blanda den titaniska datasetet.
titanic <- titanic[shuffle_index, ]head(titanic)
Produktion:
## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C
Steg 2) Rengör datamängden
Datastrukturen visar att vissa variabler har NA: er. Upprensning av data görs enligt följande
- Släpp variabler hem.dest, stuga, namn, X och biljett
- Skapa faktorvariabler för pclass och överlevde
- Släpp NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)
Kodförklaring
- välj (-c (hemdest, stuga, namn, X, biljett)): Släpp onödiga variabler
- pclass = faktor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): Lägg till etikett till variabeln pclass. 1 blir övre, 2 blir mitten och 3 blir lägre
- faktor (överlevd, nivåer = c (0,1), etiketter = c ('Nej', 'Ja')): Lägg till etikett till variabeln överlevd. 1 blir nej och 2 blir ja
- na.omit (): Ta bort NA-observationerna
Produktion:
## Observations: 1,045## Variables: 8## $ pclassUpper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex male, male, female, female, male, male, female, male… ## $ age 61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp 0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch 0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare 32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C…
Steg 3) Skapa tåg / testuppsättning
Innan du tränar din modell måste du utföra två steg:
- Skapa ett tåg och testuppsättning: Du tränar modellen på tågset och testar förutsägelsen på testuppsättningen (dvs. osedda data)
- Installera rpart.plot från konsolen
Vanlig praxis är att dela data 80/20, 80 procent av data tjänar till att träna modellen och 20 procent för att göra förutsägelser. Du måste skapa två separata dataramar. Du vill inte röra testuppsättningen förrän du är klar med att bygga din modell. Du kan skapa ett funktionsnamn create_train_test () som tar tre argument.
create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}
Kodförklaring
- funktion (data, storlek = 0,8, tåg = SANT): Lägg till argumenten i funktionen
- n_row = nrow (data): Räkna antalet rader i datasetet
- total_row = storlek * n_row: Returnera den nionde raden för att konstruera tågset
- train_sample <- 1: total_row: Välj den första raden till den nionde raden
- if (train == TRUE) {} else {}: Om villkoret är sant, returnera tågset, annars testuppsättningen.
Du kan testa din funktion och kontrollera dimensionen.
data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)
Produktion:
## [1] 836 8
dim(data_test)
Produktion:
## [1] 209 8
Tågdatasetet har 1046 rader medan testdataset har 262 rader.
Du använder funktionen prop.table () i kombination med tabell () för att verifiera om randomiseringsprocessen är korrekt.
prop.table(table(data_train$survived))
Produktion:
#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))
Produktion:
#### No Yes## 0.5789474 0.4210526
I båda datauppsättningarna är mängden överlevande densamma, cirka 40 procent.
Installera rpart.plot
rpart.plot är inte tillgängligt från conda-bibliotek. Du kan installera det från konsolen:
install.packages("rpart.plot")
Steg 4) Bygg modellen
Du är redo att bygga modellen. Syntaksen för Rpart-beslutsträdfunktionen är:
rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree
Du använder klassmetoden eftersom du förutsäger en klass.
library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106
Kodförklaring
- rpart (): Funktion för att passa modellen. Argumenten är:
- överlevde ~: Formel för beslutsträd
- data = data_train: Dataset
- method = 'class': Anpassa en binär modell
- rpart.plot (passform, extra = 106): Plotta trädet. De extra funktionerna är inställda på 101 för att visa sannolikheten för andra klass (användbart för binära svar). Du kan hänvisa till vinjetten för mer information om de andra valen.
Produktion:
Du börjar vid rotnoden (djup 0 över 3, toppen av diagrammet):
- Överst är det den totala sannolikheten för överlevnad. Den visar andelen passagerare som överlevde kraschen. 41 procent av passagerarna överlevde.
- Denna nod frågar om passagerarens kön är manligt. Om ja, går du ner till rotens vänstra barnnod (djup 2). 63 procent är män med en överlevnadssannolikhet på 21 procent.
- I den andra noden frågar du om den manliga passageraren är över 3,5 år. Om ja, är chansen att överleva 19 procent.
- Du fortsätter så för att förstå vilka funktioner som påverkar sannolikheten för överlevnad.
Observera att en av de många egenskaperna hos beslutsträd är att de kräver mycket lite dataförberedelse. I synnerhet kräver de inte funktionsskalning eller centrering.
Som standard använder rpart () -funktionen Gini- orenhetsmåttet för att dela anteckningen. Ju högre Gini-koefficienten är, desto mer olika fall inom noden.
Steg 5) Gör en förutsägelse
Du kan förutsäga din testdataset. För att göra en förutsägelse kan du använda funktionen förutsäga (). Den grundläggande syntaxen för förutsägelse för R-beslutsträdet är:
predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level
Du vill förutsäga vilka passagerare som är mer benägna att överleva efter kollisionen från testuppsättningen. Det betyder att du kommer att veta bland de 209 passagerarna, vilken som kommer att överleva eller inte.
predict_unseen <-predict(fit, data_test, type = 'class')
Kodförklaring
- förutsäga (passa, data_test, typ = 'klass'): Förutse klassen (0/1) i testuppsättningen
Testa passageraren som inte lyckades och de som gjorde det.
table_mat <- table(data_test$survived, predict_unseen)table_mat
Kodförklaring
- tabell (data_test $ survived, predict_unseen): Skapa en tabell för att räkna hur många passagerare som klassificeras som överlevande och gått bort jämfört med rätt beslutsträdklassificering i R
Produktion:
## predict_unseen## No Yes## No 106 15## Yes 30 58
Modellen förutspådde korrekt 106 döda passagerare men klassificerade 15 överlevande som döda. I analogi klassificerade modellen 30 passagerare felaktigt som överlevande medan de visade sig vara döda.
Steg 6) Mät prestanda
Du kan beräkna ett noggrannhetsmått för klassificeringsuppgift med förvirringsmatrisen :
Den förvirring matris är ett bättre val för att utvärdera klassificeringsprestanda. Den allmänna idén är att räkna antalet gånger Sanna instanser klassificeras som falska.
Varje rad i en förvirringsmatris representerar ett verkligt mål, medan varje kolumn representerar ett förutsagt mål. Den första raden i denna matris tar hänsyn till döda passagerare (falskklassen): 106 klassificerades korrekt som döda ( sant negativt ), medan den återstående felaktigt klassificerades som en överlevande ( falskt positivt ). Den andra raden tar hänsyn till de överlevande, den positiva klassen var 58 ( sant positiv ), medan den sanna negativen var 30.
Du kan beräkna noggrannhetstestet från förvirringsmatrisen:
Det är andelen sant positivt och sant negativt över summan av matrisen. Med R kan du koda enligt följande:
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)
Kodförklaring
- sum (diag (table_mat)): summan av diagonalen
- sum (table_mat): summan av matrisen.
Du kan skriva ut noggrannheten i testuppsättningen:
print(paste('Accuracy for test', accuracy_Test))
Produktion:
## [1] "Accuracy for test 0.784688995215311"
Du har en poäng på 78 procent för testuppsättningen. Du kan replikera samma övning med träningsdatasetet.
Steg 7) Ställ in hyperparametrarna
Beslutsträd i R har olika parametrar som styr aspekter av passformen. I rpart-beslutsträdbiblioteket kan du styra parametrarna med funktionen rpart.control (). I följande kod introducerar du de parametrar som du kommer att ställa in. Du kan hänvisa till vinjetten för andra parametrar.
rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0
Vi fortsätter enligt följande:
- Konstruera funktion för att återge noggrannhet
- Ställ in maximalt djup
- Ställ in det minsta antalet prov som en nod måste ha innan den kan delas
- Ställ in det minsta antalet prov som en bladnod måste ha
Du kan skriva en funktion för att visa noggrannheten. Du slår helt enkelt in koden du använde tidigare:
- förutsäga: förutsäga osynligt <- förutsäga (passa, data_test, typ = 'klass')
- Producera tabell: table_mat <- tabell (data_test $ survived, predict_unseen)
- Beräkningsnoggrannhet: precision_Test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}
Du kan försöka ställa in parametrarna och se om du kan förbättra modellen jämfört med standardvärdet. Som en påminnelse måste du få en noggrannhet högre än 0,78
control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)
Produktion:
## [1] 0.7990431
Med följande parameter:
minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0
Du får högre prestanda än den tidigare modellen. Grattis!
Sammanfattning
Vi kan sammanfatta funktionerna för att träna en beslutsträdalgoritm i R
Bibliotek |
Mål |
fungera |
klass |
parametrar |
detaljer |
---|---|---|---|---|---|
rpart |
Tågklassificeringsträd i R |
rpart () |
klass |
formel, df, metod | |
rpart |
Trä regressionsträd |
rpart () |
anova |
formel, df, metod | |
rpart |
Plotta träden |
rpart.plot () |
monterad modell | ||
bas |
förutspå |
förutspå() |
klass |
monterad modell, typ | |
bas |
förutspå |
förutspå() |
prob |
monterad modell, typ | |
bas |
förutspå |
förutspå() |
vektor |
monterad modell, typ | |
rpart |
Kontrollparametrar |
rpart.control () |
minsplit |
Ställ in minsta antal observationer i noden innan algoritmen utför en split |
|
minbucket |
Ställ in minsta antal observationer i den sista anteckningen, dvs. bladet |
||||
Max djup |
Ställ in maximalt djup för vilken nod som helst i det sista trädet. Rotnoden behandlas ett djup 0 |
||||
rpart |
Tågmodell med kontrollparameter |
rpart () |
formel, df, metod, kontroll |
Obs: Träna modellen på träningsdata och testa prestanda i en osynlig dataset, dvs. testuppsättning.