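# Requires an older Flux release that still ships Flux.Tracker (before the move to Zygote).
using Flux
using Flux.Tracker: TrackedArray, data, nobacksies

# Dot product of each row of v with the next row: sums[i] = Σ_j v[i,j] * v[i+1,j].
function dot2(v)
    sums = fill(0.0, size(v, 1) - 1)
    for i = 1:size(v, 1)-1
        for j = 1:size(v, 2)
            sums[i] += v[i, j] * v[i+1, j]
        end
    end
    return sums
end

# Pullback of dot2: given the output sensitivity Δ (length size(v,1)-1),
# return the gradient with respect to v. Row i appears in sums[i-1] and sums[i].
function dot2Δ(v, Δ)
    g = fill(0.0, size(v)...)
    g[1, :] .+= Δ[1] .* v[2, :]              # first row only appears in sums[1]
    for i = 2:size(v, 1)-1
        g[i, :] .+= Δ[i-1] * v[i-1, :] .+ Δ[i] * v[i+1, :]
    end
    g[end, :] .+= Δ[end] .* v[end-1, :]      # last row only appears in sums[end]
    return g
end

# Route tracked inputs through Tracker so the custom gradient below is used.
dot2(v::TrackedArray) = Flux.Tracker.track(dot2, v)

# Custom gradient: the forward value plus a pullback. nobacksies marks the
# result as not differentiable again (no nested AD through dot2Δ).
Flux.Tracker.@grad function dot2(v)
    dot2(data(v)),
    Δ -> (nobacksies(:dot2, dot2Δ(data.((v, Δ))...)),)
end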
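
A hand-written gradient like the one above is easy to get wrong at the first and last rows, so a finite-difference comparison can be a useful sanity check. The sketch below does not need Flux. The names `dot2_ref`, `dot2Δ_ref` and `fd_grad` are hypothetical helpers introduced here for illustration: the first two are vectorized restatements of `dot2` and `dot2Δ`, and `fd_grad` approximates the gradient of the scalar `dot(Δ, f(v))` by central differences.

```julia
using LinearAlgebra: dot

# Hypothetical vectorized restatement of dot2: dot products of adjacent rows.
dot2_ref(v) = vec(sum(v[1:end-1, :] .* v[2:end, :], dims=2))

# Hypothetical vectorized restatement of dot2Δ.
function dot2Δ_ref(v, Δ)
    g = zeros(size(v))
    g[1:end-1, :] .+= Δ .* v[2:end, :]     # contribution of row i to sums[i]
    g[2:end, :]   .+= Δ .* v[1:end-1, :]   # contribution of row i to sums[i-1]
    return g
end

# Central-difference gradient of the scalar dot(Δ, f(v)) with respect to v.
function fd_grad(f, v, Δ; ε = 1e-6)
    g = zeros(size(v))
    for k in eachindex(v)
        vp = copy(v); vp[k] += ε
        vm = copy(v); vm[k] -= ε
        g[k] = (dot(Δ, f(vp)) - dot(Δ, f(vm))) / (2ε)
    end
    return g
end

v = [1.0 2.0; 3.0 4.0; 5.0 6.0]
Δ = [1.0, -1.0]
g_analytic = dot2Δ_ref(v, Δ)
g_numeric  = fd_grad(dot2_ref, v, Δ)
@assert isapprox(g_analytic, g_numeric; atol = 1e-6)
```

With `dot2` and `dot2Δ` from the paste in scope, the same check can be run as `fd_grad(dot2, v, Δ)` against `dot2Δ(v, Δ)`.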