Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Fit a logistic regression, optionally with a lasso (L1) penalty on the slopes,
# by iteratively reweighted least squares (Newton's method on the weighted
# log-likelihood).
#
# Arguments:
#   x          numeric matrix (or anything as.matrix() accepts) with n rows and
#              d predictor columns. Columns are centred and scaled internally;
#              the fitted slopes refer to the standardised columns.
#   y          response of length n; presumably 0/1 or proportions in [0, 1]
#              -- TODO confirm against callers.
#   w          non-negative observation weights of length n; defaults to 1 each.
#   iters      maximum number of outer (Newton / IRLS) iterations; 0 returns
#              the starting values (all slopes 0, intercept-only fit).
#   l1_penalty lasso strength per unit of total weight: the fit minimises
#                -sum(w * loglik_i) + l1_penalty * sum(w) * sum(abs(beta)).
#              The intercept is never penalised; 0 gives plain maximum likelihood.
#
# Returns a list of class "simple.logistic" with
#   beta    d x 1 matrix of slopes on the standardised predictors (exactly 0 for
#           constant columns and for slopes removed by the lasso),
#   beta_0  intercept,
#   x_c     column means subtracted from x,
#   x_s     column scales x was divided by (root mean square after centring;
#           1 for constant columns).
#
# Iteration stops early once the mean squared change of (beta_0, beta) between
# two iterations drops below 1e-6. Without a penalty, perfectly separable data
# has no finite MLE; the slopes may then grow without bound or make qr.solve()
# fail on a singular system.
simple.logistic = function(x, y, w = rep(1, nrow(x)), iters = 30, l1_penalty = 0) {
  x = as.matrix(x)
  d = ncol(x)

  # Standardise the predictors; x_c and x_s are returned so new data can be
  # transformed identically before applying beta.
  x_c = colMeans(x)
  x = sweep(x, 2, x_c)
  x_s = sqrt(colMeans(x^2))
  # Columns with no spread (allowing for rounding in colMeans) carry no
  # information and would make the Newton system singular: give them scale 1
  # so nothing is divided by ~0, and keep them out of the fit (slope stays 0).
  constant = x_s <= 1e-10 * pmax(1, abs(x_c))
  x_s[constant] = 1
  x = sweep(x, 2, x_s, "/")
  keep = which(!constant)
  xk = x[, keep, drop = FALSE]
  b = rep(0, length(keep))  # slopes of the non-constant columns

  # Start from the intercept-only model: logit of the weighted base rate.
  p = weighted.mean(y, w = w)
  beta_0 = log(p) - log(1 - p)
  lambda = sum(w) * l1_penalty

  # A base rate of exactly 0 or 1 puts the intercept at -Inf/+Inf (slopes 0);
  # iterating from there would only produce NaNs.
  if (is.finite(beta_0)) {
    for (iter_ in seq_len(iters)) {  # seq_len: iters = 0 runs no iterations
      pred = as.vector(1 / (1 + exp(-(beta_0 + xk %*% b))))
      v = w * pred * (1 - pred)  # IRLS weights (diagonal Hessian weights)
      # v times the working residual (y - pred) / (pred * (1 - pred)); written
      # this way so a pred of exactly 0 or 1 cannot cause 0/0.
      s = w * (y - pred)
      old = c(beta_0, b)

      if (lambda == 0) {
        # Unpenalised: one Newton step on intercept and slopes jointly.
        x1 = cbind(1, xk)
        step = qr.solve(crossprod(x1, v * x1), crossprod(x1, s))
        beta_0 = beta_0 + step[1]
        b = b + step[-1]
      } else {
        # Penalised: minimise the same quadratic model plus the L1 term by
        # cyclic coordinate descent with soft-thresholding, which sets slopes
        # exactly to 0. s is kept equal to v times the current model residual.
        sv = sum(v)
        a = colSums(v * xk^2)
        for (sweep_ in 1:100) {
          moved = 0
          if (sv > 0) {
            d0 = sum(s) / sv
            beta_0 = beta_0 + d0
            s = s - v * d0
            moved = abs(d0)
          }
          for (j in seq_along(b)) {
            if (a[j] == 0) next
            z = sum(xk[, j] * s) + a[j] * b[j]
            b_j = sign(z) * max(abs(z) - lambda, 0) / a[j]
            delta = b_j - b[j]
            if (delta != 0) {
              s = s - v * xk[, j] * delta
              b[j] = b_j
              moved = max(moved, abs(delta))
            }
          }
          if (moved < 1e-8) break
        }
      }

      if (mean((c(beta_0, b) - old)^2) < 1e-6) {
        break
      }
    }
  }

  beta = matrix(0, d, 1, dimnames = list(colnames(x), NULL))
  beta[keep] = b
  obj = list(beta = beta, beta_0 = beta_0, x_c = x_c, x_s = x_s)
  class(obj) = "simple.logistic"
  return(obj)
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement