Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Univariate
"""
    rbf(x::Float64, y::Float64)

Unit-bandwidth Gaussian radial basis function kernel, `exp(-(x - y)^2 / 2)`.
Symmetric in its arguments and equal to `1.0` when `x == y`.
"""
rbf(x::Float64, y::Float64) = exp(-(x - y)^2 / 2)
"""
    stein_kernel(x::Float64, y::Float64, k)

Evaluate the scalar Stein kernel built from base kernel `k` with score `-x`:

    (-x) k (-y) + (-x)(x - y) k + (-y)(y - x) k + k + (x - y)(y - x) k,   k = k(x, y)

The derivative factors `(x - y)`, `(y - x)` and `1 + (x - y)(y - x)` are hard-coded
for the unit-bandwidth Gaussian RBF `exp(-(x - y)^2 / 2)`. Passing any other `k`
produces an expression that is NOT the Stein kernel of that `k`. For the Gaussian
RBF the result equals `(5xy - 2x^2 - 2y^2 + 1) * k(x, y)` algebraically.
"""
function stein_kernel(x::Float64, y::Float64, k)
    # Evaluate the base kernel once and reuse it in every term; the term order and
    # grouping below match the original expression so results are unchanged.
    kxy = k(x, y)
    return (-x) * kxy * (-y) + (-x) * (x - y) * kxy + (-y) * (y - x) * kxy + kxy +
           (x - y) * (y - x) * kxy
end

"""
    stein_kernel(x::Float64, y::Float64)

Scalar Stein kernel using the unit-bandwidth Gaussian RBF as base kernel.
"""
function stein_kernel(x::Float64, y::Float64)
    # Local copy of the Gaussian RBF keeps this method self-contained.
    gaussian(a, b) = exp(-(a - b)^2 / 2)
    return stein_kernel(x, y, gaussian)
end
"""
    SteinDiscrepancy(x::Vector{Float64})

Return the mean of the scalar `stein_kernel` over all ordered pairs `(x[i], x[j])`
of the sample `x`, diagonal pairs included (a V-statistic; no square root is taken).
An empty `x` gives `0.0 / 0`, i.e. `NaN`.
"""
function SteinDiscrepancy(x::Vector{Float64})
    # `i` is the outer index and `j` the inner one, so pairs are summed in the
    # same order as a nested i/j loop.
    pair_terms = (stein_kernel(x[i], x[j]) for i in eachindex(x) for j in eachindex(x))
    total = foldl(+, pair_terms; init = 0.0)
    return total / length(x)^2
end
"""
    dxStein2(x, y)

Return the partial derivative with respect to `x` of the scalar Stein kernel
`stein_kernel(x, y)` (score `-x`, unit-bandwidth Gaussian RBF base kernel).

With `k = exp(-(x - y)^2 / 2)` and `stein_kernel(x, y) == (5xy - 2x^2 - 2y^2 + 1) * k`,
the derivative is `((5y - 4x) + (y - x) * (5xy - 2x^2 - 2y^2 + 1)) * k`. This is
the one-dimensional case of the multivariate derivative `(5y - 4x) k + (y - x) k_p`.
It replaces an expanded symbolic expression that evaluated `exp` many times and
multiplied by `log(ℯ) == 1`.
"""
function dxStein2(x, y)
    k = exp(-(x - y)^2 / 2)
    poly = 5x * y - 2x^2 - 2y^2 + 1
    return ((5y - 4x) + (y - x) * poly) * k
end
"""
    dSteinDiscrepancy(x::Vector{Float64})

Gradient of the univariate `SteinDiscrepancy(x)` with respect to each sample.

Entry `i` is `(2 / n^2) * Σ_j dxStein2(x[i], x[j])` with `n = length(x)`. The
factor 2 relies on the Stein kernel being symmetric, so the derivative through the
second kernel argument equals the one through the first. Returns a new
`Vector{Float64}`; an empty `x` gives an empty vector.
"""
function dSteinDiscrepancy(x::Vector{Float64})
    partial_sums = Float64[
        foldl(+, (dxStein2(xi, xj) for xj in x); init = 0.0) for xi in x
    ]
    return 2 * partial_sums / length(x)^2
end
- # Multivariate
"""
    stein_kernel(x, y)

Multivariate Stein kernel using the unit-bandwidth Gaussian RBF
`exp(-norm(x - y)^2 / 2)` as base kernel.

NOTE(review): `norm` and `dot` come from LinearAlgebra, which is not loaded in the
visible part of this file — confirm it is imported elsewhere.
"""
function stein_kernel(x, y)
    return stein_kernel(x, y, (u, v) -> exp(-norm(u - v)^2 / 2))
end

"""
    stein_kernel(x, y, k)

Return `(5⟨x, y⟩ - 2⟨x, x⟩ - 2⟨y, y⟩ + d) * k(x, y)` with `d = length(x)`.

The polynomial factor is the closed form of the Stein kernel for score `-x` with the
unit-bandwidth Gaussian RBF. It is hard-coded for that base kernel, so any other `k`
gives a value that is not that kernel's Stein kernel.
"""
function stein_kernel(x, y, k)
    xy = dot(x, y)
    xx = dot(x, x)
    yy = dot(y, y)
    d = length(x)
    return (5.0 * xy - 2.0 * xx - 2.0 * yy + d) * k(x, y)
end
"""
    SteinDiscrepancy(x::Array{Float64,2})

Return the mean of the multivariate `stein_kernel` over all ordered pairs of columns
of `x`, diagonal pairs included. Each column is treated as one sample, and no
square root is taken. A matrix with zero columns gives `0.0 / 0`, i.e. `NaN`.
"""
function SteinDiscrepancy(x::Array{Float64,2})
    samples = eachcol(x)  # column views, same as view(x, :, j)
    acc = 0.0
    for a in samples, b in samples
        acc += stein_kernel(a, b)
    end
    return acc / size(x, 2)^2
end
"""
    dx_stein_kernel(x, y)

Gradient with respect to `x` of the multivariate `stein_kernel(x, y)`.

With `k = exp(-norm(x - y)^2 / 2)` this is `(5y - 4x) * k + (y - x) * stein_kernel(x, y)`,
i.e. the derivative of the polynomial factor times `k` plus the polynomial factor
times the derivative of `k`. Returns a vector of the same length as `x`.
"""
function dx_stein_kernel(x, y)
    gauss = exp(-norm(x - y)^2 / 2)
    return (5y - 4x) * gauss + (y - x) * stein_kernel(x, y)
end
"""
    dSteinDiscrepancy(x::Array{Float64,2})

Gradient of the multivariate `SteinDiscrepancy(x)` with respect to every entry of `x`.

Column `i` of the result is `(2 / n^2) * Σ_j dx_stein_kernel(x[:, i], x[:, j])` with
`n = size(x, 2)`. The factor 2 uses the symmetry of the Stein kernel: each sample
appears in both kernel arguments. Returns a new matrix of the same size as `x`.
"""
function dSteinDiscrepancy(x::Array{Float64,2})
    n = size(x, 2)
    grad = zeros(size(x))
    for i in axes(x, 2)
        # Accumulate in place through a view. `grad[:, i] += v` would copy the
        # column slice and allocate a new vector on every inner iteration.
        gi = view(grad, :, i)
        xi = view(x, :, i)
        for j in axes(x, 2)
            gi .+= dx_stein_kernel(xi, view(x, :, j))
        end
    end
    return 2 * grad / n^2
end
- # Misc
"""
    loss_along_direction(direction, state)

Project `state` onto the unit vector along `direction` and return the univariate
`SteinDiscrepancy` of the projections, `SteinDiscrepancy(state' * û)` with
`û = direction / norm(direction)`.

NOTE(review): presumably `state` is a d×n matrix with one sample per column, as in
the multivariate methods, so `state' * û` is a length-n `Vector{Float64}` — confirm
with callers.

Throws `ArgumentError` if `direction` has zero norm, since it cannot be normalized.
"""
function loss_along_direction(direction, state)
    direction_norm = norm(direction)
    iszero(direction_norm) &&
        throw(ArgumentError("direction must have nonzero norm"))
    # Dividing by a norm of exactly 1.0 leaves the entries unchanged, so no special
    # case is needed for directions that are already normalized.
    unit_direction = direction / direction_norm
    return SteinDiscrepancy(state' * unit_direction)
end
"""
    g!(∇f, x)

Write the gradient `dSteinDiscrepancy(x)` into the preallocated `∇f` and return `∇f`.

NOTE(review): the `(storage, x)` signature looks like an in-place gradient callback
for an optimizer (e.g. Optim.jl) — confirm against the call site.
"""
function g!(∇f, x)
    gradient = dSteinDiscrepancy(x)
    return copy!(∇f, gradient)
end
"""
    my_test_debug(x)

Debug helper: print `x` to standard output ten times, one per line. Returns `nothing`.
"""
function my_test_debug(x)
    foreach(_ -> println(x), 1:10)
    return nothing
end
# NOTE(review): `@profiler` is not defined anywhere in the visible file — presumably
# the Juno/Atom IDE profiler macro, so this line would fail to load in plain Julia.
# Profiling the constant expression `3 + 5` measures nothing useful; this looks like
# leftover scratch code and is a candidate for removal. (The leading `- ` appears to
# be a paste artifact.)
- @profiler 3 + 5
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement