# Untitled

a guest
Sep 20th, 2020
21
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. using DifferentialEquations, Flux, Optim, DiffEqFlux, Plots
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra predator–prey system, in the
`f(du, u, p, t)` form accepted by `ODEProblem`.

# Arguments
- `du`: output vector, overwritten with the derivatives `[dx/dt, dy/dt]`.
- `u`: current state `[x, y]`. Row 1 and row 2 of the data are plotted as
  "rabbit" and "wolves" respectively.
- `p`: rate parameters `[α, β, δ, γ]`.
- `t`: current time. It is unused because the system is autonomous.

# Returns
`nothing`; the result is written into `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y   # x: linear growth α·x minus interaction loss β·x·y
    du[2] = -δ * y + γ * x * y  # y: linear decay δ·y plus interaction gain γ·x·y
    return nothing
end
# Synthetic-data setup: the "true" system whose trajectory the fit below tries
# to recover. Parameter order matches `lotka_volterra`: [α, β, δ, γ].
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Solve once with the true parameters, saving every 0.1 time units.
sol = solve(prob, Tsit5(), saveat = 0.1)
# `Plots` is already loaded by the `using` line at the top of the file. A bare
# top-level `plot(...)` is not shown when the file is run as a script or
# `include`d, so display it explicitly.
display(plot(sol))

# Matrix of saved states: row 1 = x, row 2 = y, one column per saved time point.
measurements = Array(sol)
16.
"""
    predict_adjoint(p)

Solve the global `prob` from initial state `u0` with rate parameters `p` using
`Tsit5`. Return the trajectory sampled on `0.0:0.1:10.0` as a plain matrix
(one row per state variable), so it lines up entry-wise with `measurements`.
"""
function predict_adjoint(p)
    # This is the mechanistic ODE model, not a neural network; `p` holds the
    # four Lotka–Volterra rates being fitted. The `u0`/`p` keyword form of
    # `solve` is the replacement for the deprecated `concrete_solve`.
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.0:0.1:10.0)
    return Array(sol)
end
20.
"""
    loss_adjoint(p)

Sum of squared errors between the model prediction for parameters `p` and
`measurements`.

Returns the tuple `(loss, prediction)`. The optimizer minimizes the first
element, and the tuple is splatted into the callback as `(p, loss, prediction)`.
"""
function loss_adjoint(p)
    prediction = predict_adjoint(p)
    # If the solver stops early (e.g. parameters that make the ODE blow up),
    # the prediction has fewer time points than the data. Report an infinitely
    # bad loss instead of throwing a DimensionMismatch that aborts the fit.
    if size(prediction) != size(measurements)
        return Inf, prediction
    end
    loss = sum(abs2, prediction - measurements)
    return loss, prediction
end
26.
"""
    negative_loss_adjoint(p)

Alias of `loss_adjoint(p)`: returns the same `(loss, prediction)` tuple.

Despite its name, the loss is NOT negated. It is the positive sum of squared
errors, which is what a minimizer such as `ADAM` requires. The alias is kept
only so that existing references to this name keep working; it delegates
instead of duplicating the loss code.
"""
negative_loss_adjoint(p) = loss_adjoint(p)
32.
# Training callback: receives the current parameters `p`, the loss `l` and the
# prediction `pred` (unused here). It prints the loss and redraws the model
# against the data.
cb = function (p, l, pred)
    display(l)
    # `remake` copies `prob` with only its parameters replaced by `p`; the
    # re-solve gives a densely interpolated curve for plotting.
    fig = plot(solve(remake(prob, p = p), Tsit5()), ylim = (0, 6))
    grid = 0.0:0.1:10.0
    for (row, series_label) in ((1, "rabbit data"), (2, "wolves data"))
        plot!(fig, (grid, measurements[row, :]), label = series_label)
    end
    display(fig)
    # `false` = keep optimizing; returning `true` would halt the optimization.
    return false
end
41.
guess = [1.15814, 1.898547, 3.26339, 0.443906]
# Display the ODE with the initial parameter values.
cb(guess, loss_adjoint(guess)...)

# Stage 1: gradient-free global search. The optimizer is `Optim.ParticleSwarm()`;
# `Optim.ParticleSwarmState` is Optim's internal per-run state type and has no
# zero-argument constructor, so passing `ParticleSwarmState()` fails.
prescan = DiffEqFlux.sciml_train(loss_adjoint, guess, Optim.ParticleSwarm(),
                                 cb = cb, maxiters = 1000)

refined_guess = prescan.minimizer
cb(refined_guess, loss_adjoint(refined_guess)...)

# Stage 2: gradient-based refinement from the swarm's best point. This minimizes
# the same positive sum-of-squares loss as stage 1.
res = DiffEqFlux.sciml_train(loss_adjoint, refined_guess, ADAM(0.1),
                             cb = cb, maxiters = 1000)

cb(res.minimizer, loss_adjoint(res.minimizer)...)
# Show the fitted trajectory explicitly; a bare top-level plot is not displayed
# when the file is run as a script.
display(plot(solve(remake(prob, p = res.minimizer), Tsit5(), saveat = 0.0:0.1:10.0),
             ylim = (0, 6)))
RAW Paste Data