Advertisement
Guest User

Untitled

a guest
Jul 30th, 2019
247
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Julia 2.80 KB | None | 0 0
  1. # Univariate
  2.  
  """
      rbf(x::Float64, y::Float64)

  Unit-bandwidth Gaussian (RBF) kernel on scalars, `exp(-(x - y)^2 / 2)`.
  Equals 1 when `x == y` and is symmetric in its two arguments.
  """
  rbf(x::Float64, y::Float64) = exp(-abs2(x - y) / 2)
  6.  
  """
      stein_kernel(x::Float64, y::Float64, k)

  Univariate Stein kernel built from the base kernel `k` and the score
  `s(t) = -t` (standard normal target):

      s(x) k s(y) + s(x) ∂_y k + s(y) ∂_x k + ∂_x ∂_y k

  The derivative terms are written out for the unit-bandwidth RBF kernel
  `k(x, y) = exp(-(x - y)^2 / 2)`, for which `∂_y k = (x - y) k`,
  `∂_x k = (y - x) k` and `∂_x ∂_y k = (1 - (x - y)^2) k`. Passing any other `k`
  does NOT give a valid Stein kernel. Algebraically the result equals
  `(5xy - 2x^2 - 2y^2 + 1) * k(x, y)`.

  `k` is evaluated exactly once per call.
  """
  function stein_kernel(x::Float64, y::Float64, k)
      kxy = k(x, y)      # hoisted: the term-by-term formula needs it five times
      sx, sy = -x, -y    # score s(t) = -t at both points
      return sx * kxy * sy +                  # s(x) k s(y)
             sx * (x - y) * kxy +             # s(x) ∂_y k
             sy * (y - x) * kxy +             # s(y) ∂_x k
             kxy + (x - y) * (y - x) * kxy    # ∂_x ∂_y k
  end
  10.  
  """
      stein_kernel(x::Float64, y::Float64)

  Univariate Stein kernel using the unit-bandwidth RBF base kernel
  `exp(-(x - y)^2 / 2)`; delegates to the three-argument scalar method.
  """
  stein_kernel(x::Float64, y::Float64) =
      stein_kernel(x, y, (a::Float64, b::Float64) -> exp(-abs2(a - b) / 2))
  17.  
  """
      SteinDiscrepancy(x::Vector{Float64})

  V-statistic estimate of the squared kernel Stein discrepancy of the scalar
  sample `x`: the mean of `stein_kernel(x[i], x[j])` over all `n^2` ordered
  pairs, diagonal included. An empty sample gives `0.0 / 0`, i.e. `NaN`.
  """
  function SteinDiscrepancy(x::Vector{Float64})
      total = 0.0
      # Outer iterator over the first argument, inner over the second.
      for xi in x, xj in x
          total += stein_kernel(xi, xj)
      end
      n = length(x)
      return total / n^2
  end
  27.  
  """
      dxStein2(x, y)

  Partial derivative with respect to `x` of the univariate RBF Stein kernel

      κ(x, y) = (5xy - 2x^2 - 2y^2 + 1) * exp(-(x - y)^2 / 2)

  given in closed form as

      ∂κ/∂x = [(5y - 4x) + (y - x)(5xy - 2x^2 - 2y^2 + 1)] * exp(-(x - y)^2 / 2).

  The first bracket term comes from differentiating the polynomial factor, the
  second from the chain rule on the exponential (`∂/∂x exp(-(x-y)^2/2) = (y-x) exp(...)`).
  """
  function dxStein2(x, y)
      d = y - x
      k = exp(-d^2 / 2)
      poly = 5x * y - 2x^2 - 2y^2 + 1   # κ(x, y) / k
      return ((5y - 4x) + d * poly) * k
  end
  29.  
  """
      dSteinDiscrepancy(x::Vector{Float64})

  Gradient of `SteinDiscrepancy(x)` with respect to each sample point `x[i]`:

      grad[i] = 2 / n^2 * Σ_j dxStein2(x[i], x[j]),   n = length(x)

  `dxStein2(a, b)` is expected to return ∂κ/∂a of the univariate Stein kernel κ.
  The factor 2 appears because κ is symmetric in its arguments, so `x[i]`
  contributes equally through the first and the second slot of every pair.
  """
  function dSteinDiscrepancy(x::Vector{Float64})
      n = length(x)
      # Σ_j ∂κ/∂a(xi, x[j]), accumulated left to right starting from 0.0.
      rowsum(xi) = foldl((acc, xj) -> acc + dxStein2(xi, xj), x; init = 0.0)
      grad = Float64[rowsum(xi) for xi in x]
      return 2 * grad / n^2
  end
  40.  
  41. # Multivariate
  42.  
  """
      stein_kernel(x, y)

  Multivariate Stein kernel using the unit-bandwidth RBF base kernel
  `exp(-norm(x - y)^2 / 2)`; delegates to the three-argument method.

  NOTE(review): `norm` lives in the `LinearAlgebra` standard library, and no
  `using LinearAlgebra` is visible in this file — make sure it is loaded.
  """
  stein_kernel(x, y) = stein_kernel(x, y, (a, b) -> exp(-norm(a - b)^2 / 2))
  49.  
  """
      stein_kernel(x, y, k)

  Multivariate Stein kernel for a standard normal target (score `s(t) = -t`),
  in closed form:

      (5⟨x, y⟩ - 2‖x‖² - 2‖y‖² + d) * k(x, y),   d = length(x)

  This is the expansion of `s(x)ᵀs(y) k + s(x)ᵀ∇_y k + s(y)ᵀ∇_x k + tr(∇_x ∇_y k)`
  for the unit-bandwidth RBF kernel `k(x, y) = exp(-‖x - y‖²/2)`. With any
  other `k` the result is not a valid Stein kernel. `dot` comes from `LinearAlgebra`.
  """
  function stein_kernel(x, y, k)
      cross = dot(x, y)
      xx = dot(x, x)
      yy = dot(y, y)
      dim = length(x)
      coeff = 5.0 * cross - 2.0 * xx - 2.0 * yy + dim
      return coeff * k(x, y)
  end
  53.  
  """
      SteinDiscrepancy(x::Array{Float64,2})

  V-statistic estimate of the squared kernel Stein discrepancy for a sample
  stored column-wise in `x` (column `i` is one point). Averages the multivariate
  `stein_kernel` over all ordered pairs of columns, diagonal included; a matrix
  with no columns gives `0.0 / 0`, i.e. `NaN`.
  """
  function SteinDiscrepancy(x::Array{Float64,2})
      m = size(x, 2)
      total = 0.0
      # Column slices are views, so no copies are made per pair.
      @views for i in 1:m, j in 1:m
          total += stein_kernel(x[:, i], x[:, j])
      end
      return total / m^2
  end
  63.  
  """
      dx_stein_kernel(x, y)

  Gradient with respect to `x` of the multivariate RBF Stein kernel

      κ(x, y) = (5⟨x, y⟩ - 2‖x‖² - 2‖y‖² + d) * exp(-‖x - y‖²/2),   d = length(x)

  namely

      ∇ₓκ = (5y - 4x) * exp(-‖x - y‖²/2) + (y - x) * κ(x, y).

  The RBF factor is computed once and shared by both terms, and the vector
  arithmetic is fused into a single broadcast (one output allocation).
  Requires `LinearAlgebra` (`norm`, `dot`) to be loaded.
  """
  function dx_stein_kernel(x, y)
      k = exp(-norm(x - y)^2 / 2)
      kappa = (5.0 * dot(x, y) - 2.0 * dot(x, x) - 2.0 * dot(y, y) + length(x)) * k
      return @. (5y - 4x) * k + (y - x) * kappa
  end
  67.  
  """
      dSteinDiscrepancy(x::Array{Float64,2})

  Gradient of the matrix `SteinDiscrepancy(x)` with respect to every entry of
  `x` (points stored column-wise). With `m = size(x, 2)`, column `i` is

      2 / m^2 * Σ_j dx_stein_kernel(x[:, i], x[:, j])

  The factor 2 appears because the Stein kernel is symmetric, so column `i`
  contributes equally through both kernel arguments.
  """
  function dSteinDiscrepancy(x::Array{Float64,2})
      m = size(x, 2)
      grad = similar(x)
      @views for i in 1:m
          gi = grad[:, i]          # view into column i of the output
          fill!(gi, 0)
          for j in 1:m
              gi .+= dx_stein_kernel(x[:, i], x[:, j])
          end
      end
      return 2 * grad / m^2
  end
  78.  
  79.  
  80. # Misc
  81.  
  """
      loss_along_direction(direction, state)

  Univariate Stein discrepancy of `state` projected onto the unit vector along
  `direction`: `SteinDiscrepancy(state' * u)` with `u = direction / norm(direction)`.
  When `state` is a matrix, `state' * u` has one entry per column (presumably one
  column per sample point, as in the matrix `SteinDiscrepancy` — NOTE(review):
  confirm with callers). Normalizing keeps the projected sample on the scale the
  univariate standard-normal Stein kernel assumes.

  Throws `ArgumentError` when `direction` has zero norm, since normalizing it
  would turn every projection into `NaN`. `norm` comes from `LinearAlgebra`.
  """
  function loss_along_direction(direction, state)
      direction_norm = norm(direction)
      iszero(direction_norm) && throw(ArgumentError("direction must have nonzero norm"))
      # Skip the division when `direction` is already a unit vector.
      unit_dir = direction_norm == 1 ? direction : direction / direction_norm
      return SteinDiscrepancy(state' * unit_dir)
  end
  89.  
  """
      g!(∇f, x)

  In-place gradient callback: overwrite `∇f` with `dSteinDiscrepancy(x)` via
  `copy!` and return `∇f`. The `g!(G, x)` shape looks like the Optim.jl-style
  gradient convention — NOTE(review): confirm against the optimizer in use.
  """
  function g!(∇f, x)
      grad = dSteinDiscrepancy(x)
      return copy!(∇f, grad)
  end
  93.  
  94.  
  """
      my_test_debug(x)

  Debug helper: print `x` to standard output ten times, one per line.
  Returns `nothing`.
  """
  my_test_debug(x) = foreach(_ -> println(x), 1:10)
  100.  
  # Scratch line left over from interactive use: it profiles the constant
  # expression `3 + 5`, which does no meaningful work.
  # NOTE(review): `@profiler` is not defined or imported anywhere in this file; it
  # appears to be the Juno/Atom IDE profiling macro. Outside that environment,
  # loading the file fails here with an UndefVarError — confirm and consider removing.
  101. @profiler 3 + 5
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement