Advertisement
karstenw

AMR - adaptive mesh refinement

Sep 12th, 2016
320
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
R 9.23 KB | None | 0 0
  1. amr <- function(x, ...) UseMethod("amr", x)
  2. is.amr <- function(x) inherits(x, "amr")
  3.  
  4. #temp
  5. amr_vals <- function(node) if(length(node[["refinement"]])>0) {
  6.     # browser()
  7.     unlist(lapply(node[["refinement"]], amr_vals))
  8.     } else node[["value"]]
  9.  
  10. # x must be a matrix with two columns
  11. amr.matrix <- function(x, weights=NULL, xlim=NULL, ylim=NULL, ...) {
  12.     if(is.null(xlim)) xlim <- c(min(x[,1]), max(x[,1]))
  13.     if(is.null(ylim)) ylim <- c(min(x[,2]), max(x[,2]))
  14.     idx <- which(!is.na(x[,1]) & !is.na(x[,2]))
  15.     amr.default(x[idx,], weights=weights, type="data", xlim=xlim, ylim=ylim, ...)
  16. }
  17.  
  18. # x must be a function that supports vector arguments
  19. amr.function <- function(x, ...) amr.default(x, type="function", ...)
  20.  
  21. # the workhorse, since the difference between function and dataset is small
  22. amr.default <- function(x, xlim, ylim, weights=NULL, minlevel=4, maxlevel=8, reltol=0.05, log="", type, ...) {
  23.     pts <- matrix(c(c(xlim,rev(xlim)), rep(ylim, each=2), rep(NA, 4)), byrow=FALSE, ncol=3)
  24.     if(type=="function") pts[1:4, 3] <- x(pts[1:4,1], pts[1:4, 2])
  25.     loglist <- strsplit(log,"")[[1]]
  26.     logx <- "x" %in% loglist
  27.     logy <- "y" %in% loglist
  28.     if(type=="data") zlim <- c(0,1)
  29.    
  30.     pgrad.function <- function(node) {
  31.         vals <- pts[node[["pts_idx"]],3]
  32.         vals <- vals[is.finite(vals)]
  33.         ret <- sd(vals)/(mean(vals)+1)
  34.         # browser()
  35.         return(ret)
  36.     }
  37.     pgrad.data <- function(node) {
  38.         # prob. mass of this rectangle
  39.         ll <- pts[node[["pts_idx"]][1],1:2]
  40.         ur <- pts[node[["pts_idx"]][3],1:2]
  41.         ret <- node[["value"]] * (ur[[1]]-ll[[1]]) * (ur[[2]]-ll[[2]])
  42.         return(ret)
  43.     }
  44.     pgrad <- switch(type, "function"=pgrad.function, "data"=pgrad.data)
  45.    
  46.     rectval.function <- function(node) {
  47.         # mean of the function value at the corners
  48.         vals <- pts[node[["pts_idx"]],3]
  49.         return(mean(vals[is.finite(vals)]))
  50.     }
  51.     rectval.data <- function(node) {
  52.         # probability mass per area unit. should integrate to 1.
  53.         ll <- pts[node[["pts_idx"]][1],1:2]
  54.         ur <- pts[node[["pts_idx"]][3],1:2]
  55.         idx <- (x[,1]>=ll[[1]]) & (x[,1]<ur[[1]]) & (x[,2]>=ll[[2]]) & (x[,2]<ur[[2]])
  56.         xlen <- (ur[[1]]-ll[[1]])
  57.         ylen <- (ur[[2]]-ll[[2]])
  58.         # if(logx) xlen <- ur[[1]]/ll[[1]] else xlen <- (ur[[1]]-ll[[1]])
  59.         # if(logy) ylen <- ur[[2]]/ll[[2]] else ylen <- (ur[[2]]-ll[[2]])
  60.        
  61.         ret <- if(is.null(weights)) sum(idx)/(nrow(x)*xlen*ylen) else sum(weights[idx]) / (sum(weights)*xlen*ylen)
  62.         # if(ret>10000) browser()
  63.         # ret <- (sum(idx)*total_area) / (nrow(x) * (ur[[1]]-ll[[1]]) * (ur[[2]]-ll[[2]]))
  64.         zlim[[2]] <<- max(c(zlim[[2]], ret))
  65.         return(ret)
  66.     }      
  67.     rectval <- switch(type, "function"=rectval.function, "data"=rectval.data)
  68.        
  69.     refinement <- function(node) {
  70.         ll <- pts[node[["pts_idx"]][1],1:2]
  71.         ur <- pts[node[["pts_idx"]][3],1:2]
  72.         lvl <- node[["level"]]
  73.         midx <- if(logx) sqrt(ll[1]*ur[1]) else (ll[1]+ur[1])/2
  74.         midy <- if(logy) sqrt(ll[2]*ur[2]) else (ll[2]+ur[2])/2
  75.        
  76.         add_pts <- matrix(c(
  77.             midx, ll[2], NA,  # lower middle
  78.             ll[1], midy, NA,  # left middle
  79.             midx, midy, NA, # middle
  80.             ur[1], midy, NA,  # right middle
  81.             midx, ur[2], NA), # upper middle
  82.             byrow=TRUE, ncol=3
  83.         )
  84.         if(type=="function") add_pts[,3] <- x(add_pts[,1], add_pts[,2])
  85.         old_nrow <- nrow(pts)
  86.         pts <<- rbind(pts, add_pts)
  87.        
  88.         node1 <- list(pts_idx=c(node[["pts_idx"]][1], old_nrow+c(1,3,2)), refinement=list(), level=lvl+1)
  89.         node1[["value"]] <- rectval(node1)
  90.         node1[["pgrad"]] <- pgrad(node1)
  91.         node2 <- list(pts_idx=c(old_nrow+1, node[["pts_idx"]][2], old_nrow+c(4,3)), refinement=list(), level=lvl+1)
  92.         node2[["value"]] <- rectval(node2)
  93.         node2[["pgrad"]] <- pgrad(node2)
  94.         node3 <- list(pts_idx=c(old_nrow+c(3,4), node[["pts_idx"]][3], old_nrow+5), refinement=list(), level=lvl+1)
  95.         node3[["value"]] <- rectval(node3)
  96.         node3[["pgrad"]] <- pgrad(node3)
  97.         node4 <- list(pts_idx=c(old_nrow+c(2,3,5), node[["pts_idx"]][4]), refinement=list(), level=lvl+1)
  98.         node4[["value"]] <- rectval(node4)
  99.         node4[["pgrad"]] <- pgrad(node4)
  100.        
  101.         return(list(node1, node2, node3, node4))
  102.     }
  103.    
  104.     needs_refinement <- function(node) {
  105.         if(node[["level"]]<minlevel) return(TRUE)
  106.         if(node[["level"]]>=maxlevel) return(FALSE)
  107.         if(!is.numeric(node[["pgrad"]]) | is.nan(node[["pgrad"]])) return(TRUE)
  108.         if(abs(node[["pgrad"]]) > reltol) return(TRUE)
  109.         # if(is.numeric(node[["pgrad"]])) browser()
  110.         # ans <- try(if(node[["pgrad"]] > reltol) return(TRUE))
  111.         # if (inherits(ans, "try-error")) browser()
  112.        
  113.         return(FALSE)
  114.     }
  115.    
  116.     refine_recursively <- function(node) {
  117.         if(needs_refinement(node)) {
  118.             node[["refinement"]] <- refinement(node)
  119.             # browser()
  120.             for (i in seq_len(length(node[["refinement"]])))
  121.                 node[["refinement"]][[i]] <- refine_recursively(node[["refinement"]][[i]])
  122.         }
  123.         return(node)
  124.     }
  125.    
  126.     root <- list(pts_idx=seq(1,4), refinement=list(), level=0)
  127.     root[["value"]] <- rectval(root)
  128.     root[["pgrad"]] <- pgrad(root)
  129.     root <- refine_recursively(root)
  130.     if(type=="function")  {
  131.         zlim <- c(min(pts[,3]), max(pts[,3]))
  132.         # browser()
  133.     } else if(type=="data") {
  134.         # browser()
  135.         vals <- amr_vals(root)
  136.         zlim <- c(min(vals), max(vals))
  137.     }
  138.    
  139.     res <- list(pts=pts, node=root, zlim=zlim)
  140.     class(res) <- "amr"
  141.     return(res)
  142. }
  143.  
  144. # use classInt? quantile method not helpful due to irregular sampling
  145. # image has a breaks argument that looks cool
  146. plot.amr <- function(x, col=heat.colors(50), border=NA, xlab="x", ylab="y", zmag=1, ...) {
  147.     zmin <- if(is.finite(x[["zlim"]][[1]])) x[["zlim"]][[1]] else -zmag
  148.     zmax <- if(is.finite(x[["zlim"]][[2]])) x[["zlim"]][[2]] else zmag
  149.     colidx <- seq(zmin, zmax, length.out=length(col))
  150.      
  151.     plot_node <- function(node) {
  152.         ll <- x[["pts"]][node[["pts_idx"]][1],1:2]
  153.         ur <- x[["pts"]][node[["pts_idx"]][3],1:2]
  154.         rect(ll[1], ll[2], ur[1], ur[2], border=border, col=col[findInterval(node[["value"]], colidx)])
  155.         if(length(node[["refinement"]])>0) lapply(node[["refinement"]], plot_node)
  156.     }
  157.    
  158.     root <- x[["node"]]
  159.     ll <- x[["pts"]][root[["pts_idx"]][1], 1:2]
  160.     ur <- x[["pts"]][root[["pts_idx"]][3], 1:2]
  161.     plot_arg <- list(...)
  162.     plot_arg[["x"]] <- 1; plot_arg[["col"]] <- "white"; plot_arg[["type"]] <- "p"; plot_arg[["xlim"]] <- c(ll[1], ur[1])
  163.     plot_arg[["ylim"]] <- c(ll[2], ur[2]); plot_arg[["xlab"]] <- xlab; plot_arg[["ylab"]] <- ylab
  164.     do.call(plot, plot_arg)
  165.     plot_node(root)
  166. }
  167.  
  168.  
  169. # may need a different method (weighted mean instead of sum) for function
  170. integrate_amr <- function(x, xlim, ylim) {
  171.     eval_rect <- function(node) {
  172.         ll <- x[["pts"]][node[["pts_idx"]][1],1:2]
  173.         ur <- x[["pts"]][node[["pts_idx"]][3],1:2]
  174.         if(xlim[1]>ur[1] || xlim[2]<ll[1] || ylim[1]>ur[2] || ylim[2]<ll[2]) return(NULL)
  175.         if(length(node[["refinement"]])>0) return(unlist(lapply(node[["refinement"]], eval_rect)))
  176.         area <- (min(c(ur[1], xlim[2])) - max(c(ll[1], xlim[1]))) * (min(c(ur[2], ylim[2]))-max(c(ll[2], ylim[1])))
  177.         tot_area <- (ur[1]-ll[1]) * (ur[2]-ll[2])
  178.         return(node[["value"]] * area / tot_area) # this is integreation
  179.     }
  180.     sum(eval_rect(x[["node"]])) # this is integration
  181. }
  182.  
  183. # guess xlab and ylab
  184. # add breaks argument
  185. image.amr <- function(x, col=rev(gray(0:12/12)), border=NA, xlim=NULL, ylim=NULL, log="", xlab="x", ylab="y", zmag=1, length.out=100, breaks=NULL, ...) {
  186.     if(is.null(xlim)) xlim <- c(min(x[["pts"]][,1]), max(x[["pts"]][,1]))
  187.     if(is.null(ylim)) ylim <- c(min(x[["pts"]][,2]), max(x[["pts"]][,2]))
  188.    
  189.     zmin <- if(is.finite(x[["zlim"]][[1]])) x[["zlim"]][[1]] else -zmag
  190.     zmax <- if(is.finite(x[["zlim"]][[2]])) x[["zlim"]][[2]] else zmag
  191.        
  192.     if(length(length.out)==1) length.out <- rep(length.out,2)
  193.    
  194.     loglist <- strsplit(log,"")[[1]]
  195.     logx <- "x" %in% loglist
  196.     logy <- "y" %in% loglist
  197.    
  198.     xgrid <- if(logx) exp(seq(log(xlim[1]), log(xlim[2]), length.out=length.out[1])) else seq(xlim[1], xlim[2], length.out=length.out[1])
  199.     ygrid <- if(logy) exp(seq(log(ylim[1]), log(ylim[2]), length.out=length.out[2])) else seq(ylim[1], ylim[2], length.out=length.out[2])
  200.     z <- matrix(0, nrow=length.out[1], ncol=length.out[2])
  201.    
  202.     for (i in seq.int(2,length.out[1]))
  203.         for (j in seq.int(2,length.out[2]))
  204.             z[i,j] <- integrate_amr(x, xlim=xgrid[seq.int(i-1,i)], ylim=ygrid[seq.int(j-1,j)])
  205.            
  206.     if(!is.null(breaks) && length(breaks)==1)  {
  207.         breaks <- classIntervals(as.vector(z), n=breaks, style="fisher")$brks
  208.         # browser()
  209.         # breaks <- quantile(as.vector(z), seq(0,1, length.out=(breaks+1)))
  210.     }
  211.     # browser()
  212.     image(x=xgrid, y=ygrid, z=z, zlim=c(zmin, zmax), xlim=xlim, ylim=ylim, col=col, xlab=xlab, ylab=ylab, log=log, breaks=breaks, ...)
  213.    
  214.  
  215. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement