Advertisement
mgordon

A bugfixed version of the predict.lm function

Feb 14th, 2012
465
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 7.59 KB | None | 0 0
  1. predict.lm2 <- function (object, newdata, se.fit = FALSE, scale = NULL, df = Inf,
  2.             interval = c("none", "confidence", "prediction"), level = 0.95,
  3.             type = c("response", "terms"), terms = NULL, na.action = na.pass,
  4.             pred.var = res.var/weights, weights = 1, ...)
  5. {
  6.     tt <- terms(object)
  7.     if (!inherits(object, "lm"))
  8.         warning("calling predict.lm(<fake-lm-object>) ...")
  9.     if (missing(newdata) || is.null(newdata)) {
  10.         mm <- X <- model.matrix(object)
  11.         mmDone <- TRUE
  12.         offset <- object$offset
  13.     }
  14.     else {
  15.         Terms <- delete.response(tt)
  16.         m <- model.frame(Terms, newdata, na.action = na.action,
  17.                             xlev = object$xlevels)
  18.         if (!is.null(cl <- attr(Terms, "dataClasses")))
  19.             .checkMFClasses(cl, m)
  20.         X <- model.matrix(Terms, m, contrasts.arg = object$contrasts)
  21.         offset <- rep(0, nrow(X))
  22.         if (!is.null(off.num <- attr(tt, "offset")))
  23.             for (i in off.num) offset <- offset + eval(attr(tt,
  24.                                                 "variables")[[i + 1]], newdata)
  25.         if (!is.null(object$call$offset))
  26.             offset <- offset + eval(object$call$offset, newdata)
  27.         mmDone <- FALSE
  28.     }
  29.     n <- length(object$residuals)
  30.     p <- object$rank
  31.     p1 <- seq_len(p)
  32.     piv <- if (p)
  33.         stats:::qr.lm(object)$pivot[p1]
  34.     if (p < ncol(X) && !(missing(newdata) || is.null(newdata)))
  35.         warning("prediction from a rank-deficient fit may be misleading")
  36.     beta <- object$coefficients
  37.     predictor <- drop(X[, piv, drop = FALSE] %*% beta[piv])
  38.     if (!is.null(offset))
  39.         predictor <- predictor + offset
  40.     interval <- match.arg(interval)
  41.     if (interval == "prediction") {
  42.         if (missing(newdata))
  43.             warning("Predictions on current data refer to _future_ responses\n")
  44.         if (missing(newdata) && missing(weights)) {
  45.             w <- weights.default(object)
  46.             if (!is.null(w)) {
  47.                 weights <- w
  48.                 warning("Assuming prediction variance inversely proportional to weights used for fitting\n")
  49.             }
  50.         }
  51.         if (!missing(newdata) && missing(weights) && !is.null(object$weights) &&
  52.                             missing(pred.var))
  53.             warning("Assuming constant prediction variance even though model fit is weighted\n")
  54.         if (inherits(weights, "formula")) {
  55.             if (length(weights) != 2L)
  56.                 stop("'weights' as formula should be one-sided")
  57.             d <- if (missing(newdata) || is.null(newdata))
  58.                                         model.frame(object)
  59.                                 else newdata
  60.             weights <- eval(weights[[2L]], d, environment(weights))
  61.         }
  62.     }
  63.     type <- match.arg(type)
  64.     if (se.fit || interval != "none") {
  65.         res.var <- if (is.null(scale)) {
  66.                                 r <- object$residuals
  67.                                 w <- object$weights
  68.                                 rss <- sum(if (is.null(w)) r^2 else r^2 * w)
  69.                                 df <- object$df.residual
  70.                                 rss/df
  71.                         }
  72.                         else scale^2
  73.         if (type != "terms") {
  74.             if (p > 0) {
  75.                 XRinv <- if (missing(newdata) && is.null(w))
  76.                                               qr.Q(stats:::qr.lm(object))[, p1, drop = FALSE]
  77.                                         else X[, piv] %*% qr.solve(qr.R(stats:::qr.lm(object))[p1,
  78.                                                               p1])
  79.                 ip <- drop(XRinv^2 %*% rep(res.var, p))
  80.             }
  81.             else ip <- rep(0, n)
  82.         }
  83.     }
  84.     if (type == "terms") {
  85.         if (!mmDone) {
  86.             mm <- model.matrix(object)
  87.             mmDone <- TRUE
  88.         }
  89.         aa <- attr(mm, "assign")
  90.         ll <- attr(tt, "term.labels")
  91.         hasintercept <- attr(tt, "intercept") > 0L
  92.         if (hasintercept)
  93.             ll <- c("(Intercept)", ll)
  94.         aaa <- factor(aa, labels = ll)
  95.         asgn <- split(order(aa), aaa)
  96.         if (hasintercept) {
  97.             asgn$"(Intercept)" <- NULL
  98.             if (!mmDone) {
  99.                 mm <- model.matrix(object)
  100.                 mmDone <- TRUE
  101.             }
  102.             avx <- colMeans(mm)
  103.             termsconst <- sum(avx[piv] * beta[piv])
  104.         }
  105.         nterms <- length(asgn)
  106.         if (nterms > 0) {
  107.             predictor <- matrix(ncol = nterms, nrow = NROW(X))
  108.             dimnames(predictor) <- list(rownames(X), names(asgn))
  109.             if (se.fit || interval != "none") {
  110.                 ip <- matrix(ncol = nterms, nrow = NROW(X))
  111.                 dimnames(ip) <- list(rownames(X), names(asgn))
  112.                 Rinv <- qr.solve(qr.R(stats:::qr.lm(object))[p1, p1])
  113.             }
  114.             if (hasintercept)
  115.                 X <- sweep(X, 2L, avx, check.margin = FALSE)
  116.             unpiv <- rep.int(0L, NCOL(X))
  117.             unpiv[piv] <- p1
  118.             for (i in seq.int(1L, nterms, length.out = nterms)) {
  119.                 iipiv <- asgn[[i]]
  120.                 ii <- unpiv[iipiv]
  121.                 iipiv[ii == 0L] <- 0L
  122.                 predictor[, i] <- if (any(iipiv > 0L))
  123.                                               X[, iipiv, drop = FALSE] %*% beta[iipiv]
  124.                                         else 0
  125.                 if (se.fit || interval != "none")
  126.                                       ip[, i] <- if (any(iipiv > 0L))
  127.                                                     as.matrix(X[, iipiv, drop = FALSE] %*% Rinv[ii,
  128.                                                                               , drop = FALSE])^2 %*% rep.int(res.var,
  129.                                                                       p)
  130.                                               else 0
  131.             }
  132.             if (!is.null(terms)) {
  133.                 predictor <- predictor[, terms, drop = FALSE]
  134.                 if (se.fit)
  135.                     ip <- ip[, terms, drop = FALSE]
  136.             }
  137.         }
  138.         else {
  139.             predictor <- ip <- matrix(0, n, 0L)
  140.         }
  141.         attr(predictor, "constant") <- if (hasintercept)
  142.                                 termsconst
  143.                         else 0
  144.     }
  145.     if (interval != "none") {
  146.         tfrac <- qt((1 - level)/2, df)
  147.         hwid <- tfrac * switch(interval, confidence = sqrt(ip),
  148.                             prediction = sqrt(ip + pred.var))
  149.         if (type != "terms") {
  150.             predictor <- cbind(predictor, predictor + hwid %o%
  151.                                             c(1, -1))
  152.             colnames(predictor) <- c("fit", "lwr", "upr")
  153.         }
  154.         else {
  155.             if (!is.null(terms))
  156.                 hwid <- hwid[, terms, drop = FALSE]
  157.             lwr <- predictor + hwid
  158.             upr <- predictor - hwid
  159.         }
  160.     }
  161.     if (se.fit || interval != "none") {
  162.         se <- sqrt(ip)
  163.     }
  164.     if (missing(newdata) && !is.null(na.act <- object$na.action)) {
  165.         predictor <- napredict(na.act, predictor)
  166.         if (se.fit)
  167.             se <- napredict(na.act, se)
  168.     }
  169.     if (type == "terms" && interval != "none") {
  170.         if (missing(newdata) && !is.null(na.act)) {
  171.             lwr <- napredict(na.act, lwr)
  172.             upr <- napredict(na.act, upr)
  173.         }
  174.         list(fit = predictor, se.fit = se, lwr = lwr, upr = upr,
  175.                             df = df, residual.scale = sqrt(res.var))
  176.     }
  177.     else if (se.fit)
  178.         list(fit = predictor, se.fit = se, df = df, residual.scale = sqrt(res.var))
  179.     else predictor
  180. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement