MIT:The Analytics Edge 笔记03-指数回归

MIT课程 15.071x The Analytics Edge 第三单元的学习记录。


Logistic Regression

第三单元的主题是指数回归。

1.理论

指数回归

指数回归用于因变量y是二进制的情况,也就是说,y的取值只有1或者0。
y=1的概率:
formula=\frac{1}{1+e^{-{(\beta_0 +\beta_1x_1+\beta_2x_2+\ldots+\beta_nx_n+\epsilon)}}})

y=1的概率与y=0的概率的比值:
formula}{P(y=0)})

formula}{1-P(y=1)})

formula

混淆矩阵(confusion matrix)

有阈值t,
如果P(y=1) >=t,则预测y=1。
如果P(y=1) < t,则预测y=0。

对于预测结果,我们得到矩阵

predict y=0 predict y=1
actual y=0 TN (True Nagative) FP (False Positive)
actual y=1 FN (False Nagative) TP (True Positive)

根据矩阵中的值,我们可以计算指数回归的一些指标:

formula
formula
formula

补充概念:
适合率
formula
再现率
formula
F值(F-measure)
formula
F值越高,性能越好

ROC曲线

ROC曲线 (Receiver Operator Characteristic curve)可以指导我们如何选取阈值t。
y轴的指标是 sensitivity,所以也叫 True positive rate。

formula
x轴的指标是 1-specificity,所以也叫 False positive rate。

formula

每取一个阈值t,则计算相对应的 TPR 和 FPR,在坐标里标出这个点,就形成ROC曲线。
ROC Curve

如图所示,

t=0时,我们预测所有的y=1,即TPR=1,FPR=1,对应的坐标是(1,1)   
t=1时,我们预测所有的y=0,即TPR=0,FPR=0,对应的坐标是(0,0)   

这就是曲线的两个端点。

AUC值

AUC(Area Under Curve)被定义为ROC曲线下的面积,显然这个面积的数值不会大于1。又由于ROC曲线一般都处于y=x这条直线的上方,所以AUC的取值范围在0.5和1之间。

2.建立回归模型

# 建立模型
# Top10作为因变量,其他所有的列都作为自变量
SongsLog1 = glm(Top10 ~ ., data=SongsTrain, family=binomial)

# Top10作为因变量,除了loudness以外的所有列都作为自变量
SongsLog2 = glm(Top10 ~ . - loudness, data=SongsTrain, family=binomial)

3.评估

# 预测
testPredict = predict(SongsLog3, newdata=SongsTest, type="response")

# 生成混淆矩阵
table(SongsTest$Top10, testPredict >= 0.45)

# 生成ROC曲线
library(ROCR)
pred = prediction(testPredict, test$violator)
perf = performance(pred, "tpr", "fpr")
plot(perf)

# 加点颜色和坐标点
plot(perf, colorize=TRUE, print.cutoffs.at=seq(0,1,0.1), text.adj=c(-0.2,1.7))

# 计算AUC值
as.numeric(performance(pred, "auc")@y.values)

附录A 分割train和test的方法一

library(caTools)
set.seed(144)

split = sample.split(parole$violator, SplitRatio = 0.7)
train = subset(parole, split == TRUE)
test = subset(parole, split == FALSE)
# 特别注意:每次运行出来的结果是不一样的

也可以这样做:

library(caTools)
set.seed(144)

split = sample(1:nrow(data), size=0.7 * nrow(data))
train = data[split,]
test = data[-split,]

附录B 补充缺失数据

library(mice)
set.seed(144)
vars.for.imputation = setdiff(names(loans), "not.fully.paid")
imputed = complete(mice(loans[vars.for.imputation]))
loans[vars.for.imputation] = imputed