Wat is logistieke regressie?
Logistische regressie wordt gebruikt om een klasse, dwz een waarschijnlijkheid, te voorspellen. Logistische regressie kan een binaire uitkomst nauwkeurig voorspellen.
Stel je voor dat je op basis van veel attributen wilt voorspellen of een lening wordt geweigerd / geaccepteerd. De logistische regressie heeft de vorm 0/1. y = 0 als een lening wordt afgewezen, y = 1 als deze wordt geaccepteerd.
Een logistisch regressiemodel verschilt op twee manieren van een lineair regressiemodel.
- Allereerst accepteert de logistische regressie alleen dichotome (binaire) invoer als een afhankelijke variabele (dwz een vector van 0 en 1).
- Ten tweede wordt de uitkomst gemeten door de volgende probabilistische linkfunctie genaamd sigmoid vanwege zijn S-vorm:
De output van de functie is altijd tussen 0 en 1. Zie afbeelding hieronder
De sigmoïde functie retourneert waarden van 0 tot 1. Voor de classificatietaak hebben we een discrete uitvoer van 0 of 1 nodig.
Om een continue stroom om te zetten in een discrete waarde, kunnen we een beslissingsgrens instellen op 0,5. Alle waarden boven deze drempel worden geclassificeerd als 1
In deze tutorial leer je
- Wat is logistieke regressie?
- Een gegeneraliseerd voeringmodel (GLM) maken
- Stap 1) Controleer continue variabelen
- Stap 2) Controleer factorvariabelen
- Stap 3) Feature engineering
- Stap 4) Samenvattende statistiek
- Stap 5) Train / testset
- Stap 6) Bouw het model
- Stap 7) Beoordeel de prestaties van het model
Een gegeneraliseerd voeringmodel (GLM) maken
Laten we de gegevensset voor volwassenen gebruiken om logistische regressie te illustreren. De "volwassene" is een geweldige dataset voor de classificatietaak. Het doel is om te voorspellen of het jaarinkomen in dollar van een individu hoger zal zijn dan 50.000. De dataset bevat 46.033 observaties en tien features:
- leeftijd: leeftijd van het individu. Numeriek
- opleiding: opleidingsniveau van het individu. Factor.
- marital.status: burgerlijke staat van het individu. Factor dwz nooit getrouwd, gehuwd-burger-echtgenoot, ...
- geslacht: geslacht van het individu. Factor, dwz mannelijk of vrouwelijk
- inkomen: doelvariabele. Inkomen boven of onder 50K. Factor dwz> 50K, <= 50K
onder anderen
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Uitgang:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
We gaan als volgt te werk:
- Stap 1: Controleer continue variabelen
- Stap 2: Controleer factorvariabelen
- Stap 3: Feature engineering
- Stap 4: Samenvattende statistiek
- Stap 5: Train / testset
- Stap 6: Bouw het model
- Stap 7: Beoordeel de prestaties van het model
- stap 8: Verbeter het model
Het is jouw taak om te voorspellen welke persoon een omzet van meer dan 50K zal hebben.
In deze tutorial wordt elke stap gedetailleerd om een analyse uit te voeren op een echte dataset.
Stap 1) Controleer continue variabelen
In de eerste stap zie je de verdeling van de continue variabelen.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Code Verklaring
- continu <- select_if (data_adult, is.numeric): Gebruik de functie select_if () uit de dplyr-bibliotheek om alleen de numerieke kolommen te selecteren
- samenvatting (doorlopend): Druk de samenvattende statistiek af
Uitgang:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Uit de bovenstaande tabel kun je zien dat de gegevens totaal verschillende schalen en uren hebben. Per. Weken heeft grote uitschieters (. Kijk naar het laatste kwartiel en de maximale waarde).
U kunt het afhandelen door twee stappen te volgen:
- 1: Zet de urenverdeling per week uit
- 2: Standaardiseer de continue variabelen
- Teken de verdeling
Laten we eens kijken naar de verdeling van uren.per.week
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Uitgang:
De variabele heeft veel uitschieters en een niet goed gedefinieerde distributie. U kunt dit probleem gedeeltelijk aanpakken door de bovenste 0,01 procent van de uren per week te schrappen.
Basissyntaxis van kwantiel:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
We berekenen het bovenste 2 procent percentiel
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Code Verklaring
- kwantiel (data_adult $ uren.per.week, .99): Bereken de waarde van de 99 procent van de werktijd
Uitgang:
## 99%## 80
98 procent van de bevolking werkt minder dan 80 uur per week.
U kunt de waarnemingen boven deze drempel laten vallen. U gebruikt het filter uit de dplyr-bibliotheek.
data_adult_drop <-data_adult %>%filter(hours.per.weekUitgang:
## [1] 45537 10
- Standaardiseer de continue variabelen
U kunt elke kolom standaardiseren om de prestaties te verbeteren, omdat uw gegevens niet dezelfde schaal hebben. U kunt de functie mutate_if uit de dplyr-bibliotheek gebruiken. De basissyntaxis is:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionU kunt de numerieke kolommen als volgt standaardiseren:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Code Verklaring
- mutate_if (is.numeric, funs (scale)): de voorwaarde is alleen een numerieke kolom en de functie is schaal
Uitgang:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KStap 2) Controleer factorvariabelen
Deze stap heeft twee doelstellingen:
- Controleer het niveau in elke categorische kolom
- Definieer nieuwe niveaus
We verdelen deze stap in drie delen:
- Selecteer de categorische kolommen
- Sla het staafdiagram van elke kolom op in een lijst
- Druk de grafieken af
We kunnen de factorkolommen selecteren met de onderstaande code:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Code Verklaring
- data.frame (select_if (data_adult, is.factor)): We slaan de factorkolommen in factor op in een dataframetype. De bibliotheek ggplot2 vereist een dataframe-object.
Uitgang:
## [1] 6De dataset bevat 6 categorische variabelen
De tweede stap is meer vaardig. U wilt voor elke kolom in de dataframefactor een staafdiagram plotten. Het is handiger om het proces te automatiseren, vooral als er veel kolommen zijn.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Code Verklaring
- lapply (): Gebruik de functie lapply () om een functie door te geven in alle kolommen van de dataset. Je slaat de output op in een lijst
- functie (x): De functie wordt voor elke x verwerkt. Hier zijn x de kolommen
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): maak een staafdiagram voor elk x-element. Let op, om x als een kolom te retourneren, moet u deze opnemen in de get ()
De laatste stap is relatief eenvoudig. U wilt de 6 grafieken afdrukken.
# Print the graphgraphUitgang:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Opmerking: gebruik de knop Volgende om naar de volgende grafiek te navigeren
Stap 3) Feature engineering
Onderwijs herschikken
Uit bovenstaande grafiek kun je zien dat het variabel onderwijs 16 niveaus heeft. Dit is aanzienlijk, en sommige niveaus hebben een relatief laag aantal waarnemingen. Als u de hoeveelheid informatie die u uit deze variabele kunt halen wilt verbeteren, kunt u deze naar een hoger niveau herschikken. Je creëert namelijk grotere groepen met een vergelijkbaar opleidingsniveau. Zo zal een laag opleidingsniveau worden omgezet in uitval. Hogere opleidingsniveaus worden gewijzigd in master.
Hier is het detail:
Oud niveau
Nieuw level
Peuter
afvaller
10e
Afvaller
11e
Afvaller
12e
Afvaller
1e-4e
Afvaller
5e-6e
Afvaller
7e-8e
Afvaller
9e
Afvaller
HS-Grad
HighGrad
Een of andere universiteit
Gemeenschap
Assoc-acdm
Gemeenschap
Assoc-voc
Gemeenschap
Bachelors
Bachelors
Meesters
Meesters
Prof-school
Meesters
Doctoraat
PhD
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Code Verklaring
- We gebruiken het werkwoord muteren uit de dplyr-bibliotheek. We veranderen de waarden van onderwijs met de stelling ifelse
In onderstaande tabel maak je een samenvattende statistiek om te zien hoeveel jaar opleiding (z-waarde) er gemiddeld nodig is om de bachelor, master of PhD te behalen.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Uitgang:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Herschikking burgerlijke staat
Het is ook mogelijk om lagere niveaus voor de burgerlijke staat te creëren. In de volgende code verander je het niveau als volgt:
Oud niveau
Nieuw level
Nooit getrouwd
Niet getrouwd
Gehuwd-echtgenoot-afwezig
Niet getrouwd
Getrouwd-AF-echtgenoot
Getrouwd
Gehuwd-burger-echtgenoot
Uit elkaar gehaald
Uit elkaar gehaald
Gescheiden
Weduwen
Weduwe
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))U kunt het aantal individuen binnen elke groep controleren.table(recast_data$marital.status)Uitgang:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Stap 4) Samenvattende statistiek
Het is tijd om wat statistieken over onze doelvariabelen te bekijken. In de onderstaande grafiek tel je het percentage individuen dat meer dan 50.000 verdient, gegeven hun geslacht.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Uitgang:
Controleer vervolgens of de oorsprong van het individu van invloed is op hun verdiensten.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Uitgang:
Het aantal gewerkte uren naar geslacht.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Uitgang:
De boxplot bevestigt dat de verdeling van de werktijd bij verschillende groepen past. In de boxplot hebben beide geslachten geen homogene waarnemingen.
U kunt de dichtheid van de wekelijkse werktijd per opleiding bekijken. De distributies hebben veel verschillende keuzes. Het kan waarschijnlijk worden verklaard door het type contract in de VS.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Code Verklaring
- ggplot (recast_data, aes (x = uren.per.week)): Een dichtheidsplot vereist slechts één variabele
- geom_density (aes (color = education), alpha = 0.5): het geometrische object om de dichtheid te regelen
Uitgang:
Om uw mening te bevestigen, kunt u een eenmalige ANOVA-test uitvoeren:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Uitgang:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1De ANOVA-test bevestigt het verschil in gemiddelde tussen groepen.
Niet-lineariteit
Voordat u het model uitvoert, kunt u zien of het aantal gewerkte uren gerelateerd is aan de leeftijd.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Code Verklaring
- ggplot (recast_data, aes (x = leeftijd, y = uren.per.week)): Stel de esthetiek van de grafiek in
- geom_point (aes (kleur = inkomen), grootte = 0,5): construeer de puntplot
- stat_smooth (): Voeg de trendlijn toe met de volgende argumenten:
- method = 'lm': Plot de aangepaste waarde als de lineaire regressie
- formule = y ~ poly (x, 2): Fit een polynoomregressie
- se = TRUE: voeg de standaardfout toe
- aes (kleur = inkomen): Breek het model naar inkomen
Uitgang:
Kortom, je kunt interactietermen in het model testen om het niet-lineariteitseffect tussen de wekelijkse werktijd en andere kenmerken op te pikken. Het is belangrijk om te bepalen onder welke omstandigheden de werktijd verschilt.
Correlatie
De volgende controle is om de correlatie tussen de variabelen te visualiseren. U converteert het type factor niveau naar numeriek, zodat u een heatmap kunt plotten met de correlatiecoëfficiënt die is berekend met de Spearman-methode.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Code Verklaring
- data.frame (lapply (recast_data, as.integer)): converteer gegevens naar numeriek
- ggcorr () zet de heatmap uit met de volgende argumenten:
- methode: methode om de correlatie te berekenen
- nbreaks = 6: aantal pauze
- hjust = 0.8: Controlepositie van de variabelenaam in de plot
- label = TRUE: Voeg labels toe in het midden van de vensters
- label_size = 3: Maatlabels
- color = "grey50"): Kleur van het label
Uitgang:
Stap 5) Train / testset
Elke begeleide machine learning-taak vereist het splitsen van de gegevens tussen een treinset en een testset. Je kunt de "functie" die je hebt aangemaakt in de andere begeleide leerlessen gebruiken om een trein / testset te maken.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Uitgang:
## [1] 36429 9dim(data_test)Uitgang:
## [1] 9108 9Stap 6) Bouw het model
Om te zien hoe het algoritme presteert, gebruikt u het pakket glm (). Het gegeneraliseerde lineaire model is een verzameling modellen. De basissyntaxis is:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")U bent klaar om het logistieke model te schatten om het inkomensniveau over een reeks functies te verdelen.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Code Verklaring
- formule <- inkomen ~.: Maak het model dat past
- logit <- glm (formule, data = data_train, family = 'binomial'): Pas een logistiek model (family = 'binomial') aan met de data_train data.
- samenvatting (logit): Druk de samenvatting van het model af
Uitgang:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6De samenvatting van ons model onthult interessante informatie. De prestaties van een logistieke regressie worden geëvalueerd met specifieke belangrijke meetgegevens.
- AIC (Akaike Information Criteria): dit is het equivalent van R2 in logistieke regressie. Het meet de pasvorm wanneer een penalty wordt toegepast op het aantal parameters. Kleinere AIC- waarden geven aan dat het model dichter bij de waarheid is.
- Null deviantie: past alleen in het model met het onderscheppingspunt. De vrijheidsgraad is n-1. We kunnen het interpreteren als een Chi-kwadraatwaarde (aangepaste waarde die verschilt van de werkelijke waardehypothesetest).
- Restafwijking: model met alle variabelen. Het wordt ook geïnterpreteerd als een Chi-kwadraat-hypothesetest.
- Aantal Fisher Scoring-iteraties: aantal iteraties vóór convergentie.
De uitvoer van de functie glm () wordt opgeslagen in een lijst. De onderstaande code toont alle items die beschikbaar zijn in de logit-variabele die we hebben geconstrueerd om de logistische regressie te evalueren.
# De lijst is erg lang, print alleen de eerste drie elementen
lapply(logit, class)[1:3]Uitgang:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Elke waarde kan worden geëxtraheerd met het $ -teken gevolgd door de naam van de metrische gegevens. U hebt het model bijvoorbeeld opgeslagen als logit. Om de AIC-criteria te extraheren, gebruikt u:
logit$aicUitgang:
## [1] 27086.65Stap 7) Beoordeel de prestaties van het model
Verwarring Matrix
De verwarringmatrix is een betere keuze om de classificatieprestaties te evalueren in vergelijking met de verschillende statistieken die u eerder zag. Het algemene idee is om te tellen hoe vaak True-instanties worden geclassificeerd als False.
Om de verwarringmatrix te berekenen, moet u eerst een reeks voorspellingen hebben, zodat ze kunnen worden vergeleken met de werkelijke doelen.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matCode Verklaring
- voorspellen (logit, data_test, type = 'response'): Bereken de voorspelling op de testset. Stel type = 'response' in om de responskans te berekenen.
- tabel (data_test $ inkomen, voorspellen> 0,5): Bereken de verwarringmatrix. voorspellen> 0,5 betekent dat het 1 retourneert als de voorspelde kansen groter zijn dan 0,5, anders 0.
Uitgang:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Elke rij in een verwarringmatrix vertegenwoordigt een werkelijk doel, terwijl elke kolom een voorspeld doel vertegenwoordigt. De eerste rij van deze matrix beschouwt het inkomen lager dan 50k (de False-klasse): 6241 werden correct geclassificeerd als individuen met een inkomen lager dan 50k ( True-negatief ), terwijl de overige ten onrechte werd geclassificeerd als hoger dan 50k ( False-positief ). De tweede rij beschouwt het inkomen boven de 50.000, de positieve klasse was 1229 ( True-positief ), terwijl het True-negatief 1074 was.
U kunt het model te berekenen nauwkeurigheid door het optellen van de ware positieve + true negatief over de totale observatie
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestCode Verklaring
- sum (diag (table_mat)): Som van de diagonaal
- sum (table_mat): Som van de matrix.
Uitgang:
## [1] 0.8277339Het model lijkt te kampen met één probleem, het overschat het aantal fout-negatieven. Dit wordt de nauwkeurigheidstestparadox genoemd . We stelden dat de nauwkeurigheid de verhouding is tussen de juiste voorspellingen en het totale aantal gevallen. We kunnen een relatief hoge nauwkeurigheid hebben, maar een nutteloos model. Het gebeurt wanneer er een dominante klasse is. Als je terugkijkt op de verwarringmatrix, kun je zien dat de meeste gevallen als echt negatief worden geclassificeerd. Stel je voor dat het model alle klassen als negatief classificeerde (dwz lager dan 50k). U zou een nauwkeurigheid van 75 procent hebben (6718/6718 + 2257). Uw model presteert beter, maar heeft moeite om het echte positieve van het echte negatieve te onderscheiden.
In een dergelijke situatie verdient het de voorkeur om een beknoptere metriek te hebben. We kunnen kijken naar:
- Precisie = TP / (TP + FP)
- Recall = TP / (TP + FN)
Precisie versus terugroepen
Precisie kijkt naar de nauwkeurigheid van de positieve voorspelling. Recall is de verhouding van positieve instanties die correct worden gedetecteerd door de classificator;
U kunt twee functies construeren om deze twee metrieken te berekenen
- Construeer precisie
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Code Verklaring
- mat [1,1]: Retourneert de eerste cel van de eerste kolom van het dataframe, dwz het ware positieve
- mat [1,2]; Retourneer de eerste cel van de tweede kolom van het dataframe, dwz het vals-positief
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Code Verklaring
- mat [1,1]: Retourneert de eerste cel van de eerste kolom van het dataframe, dwz het ware positieve
- mat [2,1]; Retourneer de tweede cel van de eerste kolom van het dataframe, dwz het vals negatief
U kunt uw functies testen
prec <- precision(table_mat)precrec <- recall(table_mat)recUitgang:
## [1] 0.712877## [2] 0.5336518Wanneer het model zegt dat het een persoon is van meer dan 50.000, is dit in slechts 54 procent van de gevallen correct en kan in 72 procent van de gevallen personen boven de 50.000 worden geclaimd.
U kunt de is een harmonisch gemiddelde van deze twee statistieken, wat betekent dat het meer gewicht geeft aan de lagere waarden.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Uitgang:
## [1] 0.6103799Precisie versus terugroepactie
Het is onmogelijk om zowel een hoge precisie als een hoge terugroepactie te hebben.
Als we de precisie vergroten, wordt de juiste persoon beter voorspeld, maar missen we er veel (lagere herinnering). In sommige situaties geven we de voorkeur aan hogere precisie dan terugroepen. Er is een concave relatie tussen precisie en herinnering.
- Stel je voor, je moet voorspellen of een patiënt een ziekte heeft. U wilt zo nauwkeurig mogelijk zijn.
- Als u potentiële frauduleuze mensen op straat moet detecteren door middel van gezichtsherkenning, is het beter om veel mensen te betrappen die als frauduleus worden bestempeld, ook al is de precisie laag. De politie kan de niet-frauduleuze persoon vrijlaten.
De ROC-curve
De curve voor bedrijfskarakteristieken van de ontvanger is een ander veelgebruikt hulpmiddel dat wordt gebruikt bij binaire classificatie. Het lijkt erg op de precisie / terugroepcurve, maar in plaats van precisie versus terugroepen uit te zetten, toont de ROC-curve de werkelijke positieve snelheid (dat wil zeggen, terugroepen) tegen de vals-positieve snelheid. Het percentage vals-positieven is de verhouding van negatieve gevallen die ten onrechte als positief worden geclassificeerd. Het is gelijk aan één minus het werkelijke negatieve tarief. Het echte negatieve percentage wordt ook wel specificiteit genoemd . Daarom plot de ROC-curve gevoeligheid (recall) versus 1-specificiteit
Om de ROC-curve te plotten, moeten we een bibliotheek met de naam RORC installeren. We kunnen het vinden in de conda-bibliotheek. U kunt de code typen:
conda install -cr r-rocr --yes
We kunnen de ROC plotten met de functies predict () en performance ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Code Verklaring
- voorspelling (voorspellen, data_test $ inkomen): de ROCR-bibliotheek moet een voorspellingsobject maken om de invoergegevens te transformeren
- performance (ROCRpred, 'tpr', 'fpr'): Retourneer de twee combinaties die in de grafiek moeten worden geproduceerd. Hier worden tpr en fpr geconstrueerd. Gebruik "prec", "rec" om precisie en recall samen te plotten.
Uitgang:
Stap 8) Verbeter het model
U kunt proberen om non-lineariteit aan het model toe te voegen met de interactie tussen
- leeftijd en uren. per. week
- geslacht en uren. per. week.
U moet de scoretest gebruiken om beide modellen te vergelijken
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Uitgang:
## [1] 0.6109181De score is iets hoger dan de vorige. U kunt aan de gegevens blijven werken om de score te verbeteren.
Overzicht
We kunnen de functie om een logistieke regressie te trainen samenvatten in de onderstaande tabel:
Pakket
Objectief
functie
argument
Maak een trein- / testdataset
create_train_set ()
gegevens, grootte, trein
glm
Train een gegeneraliseerd lineair model
glm ()
formule, data, familie *
glm
Vat het model samen
overzicht()
getailleerd model
baseren
Maak voorspelling
voorspellen()
aangepast model, dataset, type = 'respons'
baseren
Creëer een verwarringmatrix
tafel()
y, voorspellen ()
baseren
Maak een nauwkeurigheidsscore
som (diag (table ()) / sum (table ()
ROCR
ROC maken: Stap 1 Voorspelling maken
voorspelling()
voorspellen (), y
ROCR
Creëer ROC: Stap 2 Creëer performance
prestatie()
voorspelling (), 'tpr', 'fpr'
ROCR
ROC maken: Stap 3 Grafiek plotten
verhaal()
prestatie()
De andere modellen van het GLM- type zijn:
- binominaal: (link = "logit")
- gaussian: (link = "identity")
- Gamma: (link = "inverse")
- inverse.gaussian: (link = "1 / mu 2")
- poisson: (link = "log")
- quasi: (link = "identiteit", variantie = "constante")
- quasibinomial: (link = "logit")
- quasipoisson: (link = "log")