# Untitled — Pastebin paste by a guest, Sep 20th, 2020 (expiry: never).
  1. using DifferentialEquations, Flux, Optim, DiffEqFlux, Plots
  """
      lotka_volterra(du, u, p, t)

  In-place right-hand side of the Lotka–Volterra system, in the
  `f(du, u, p, t)` form passed to `ODEProblem`.

  - `u = [x, y]`: the two populations (plotted elsewhere in this file as
    "rabbit" and "wolves" data respectively).
  - `p = [α, β, δ, γ]`: `α` scales `x`'s self-growth, `β` the `x*y` loss term of
    `x`, `δ` the decay of `y`, and `γ` the `x*y` gain term of `y`.
  - `t`: time; not used by the equations.

  Writes `α*x - β*x*y` into `du[1]` and `-δ*y + γ*x*y` into `du[2]`, then
  returns `nothing` (the result is communicated only through `du`).
  """
  function lotka_volterra(du, u, p, t)
      x, y = u
      α, β, δ, γ = p
      du[1] = α * x - β * x * y
      du[2] = -δ * y + γ * x * y
      return nothing
  end
  # Ground truth: integrate the system from `u0` over `tspan` with the "true"
  # parameters `p`, sampling every 0.1 time units. The resulting trajectory is
  # the synthetic data that the parameter fit below tries to reproduce.
  u0 = [1.0, 1.0]
  tspan = (0.0, 10.0)
  p = [1.5, 1.0, 3.0, 1.0]  # [α, β, δ, γ], same order as in `lotka_volterra`
  prob = ODEProblem(lotka_volterra, u0, tspan, p)
  sol = solve(prob, Tsit5(), saveat = 0.1)
  using Plots
  # A bare top-level `plot(...)` is only shown automatically in the REPL;
  # `display` also shows it when this file is run as a script.
  display(plot(sol))
  # Matrix of saved states: row 1 is x, row 2 is y, one column per saved time.
  measurements = Array(sol)
  16.  
  """
      predict_adjoint(p)

  Solve the global problem `prob` from the global initial condition `u0` with
  parameters `p`. Return the states sampled on `0.0:0.1:10.0` as a matrix with
  one row per state component and one column per time point, laid out like
  `measurements`.

  This is an ODE solve, not a neural network. It uses
  `solve(prob, alg; u0, p, ...)`, the documented replacement for the
  deprecated `concrete_solve`, so the solve stays usable by DiffEqFlux's
  gradient-based training.
  """
  function predict_adjoint(p)
      return Array(solve(prob, Tsit5(), u0 = u0, p = p, saveat = 0.0:0.1:10.0))
  end
  20.  
  """
      loss_adjoint(p)

  Sum of squared differences between the trajectory predicted with parameters
  `p` and `measurements`.

  Returns the tuple `(loss, prediction)`. The second element is forwarded to
  the training callback, e.g. `cb(p, loss_adjoint(p)...)`.
  """
  function loss_adjoint(p)
      pred = predict_adjoint(p)
      residual = pred - measurements
      return sum(abs2, residual), pred
  end
  26.  
  """
      negative_loss_adjoint(p)

  Return `(loss, prediction)` exactly as `loss_adjoint(p)` does.

  NOTE(review): despite its name, nothing is negated here. The value is the
  same non-negative sum-of-squares error. That is what the ADAM stage needs,
  because `sciml_train` minimizes its objective; a negated value would make it
  maximize the misfit. This delegates to `loss_adjoint` so the two objectives
  cannot silently drift apart.
  """
  function negative_loss_adjoint(p)
      return loss_adjoint(p)
  end
  32.  
  # Training callback with the `(p, l, pred)` signature used by `sciml_train`
  # and by the manual `cb(p, loss_adjoint(p)...)` calls. It prints the current
  # loss, re-solves `prob` with parameters `p`, and plots that trajectory over
  # the measured data. It returns `false` so optimization continues; returning
  # `true` would stop it.
  cb = function (p, l, pred)
      display(l)
      t_data = 0.0:0.1:10.0  # the time grid `measurements` was sampled on
      # `remake` re-creates `prob` with the current parameters `p`.
      plt = plot(solve(remake(prob, p = p), Tsit5()), ylim = (0, 6))
      # Pass x and y as separate positional arguments, the standard
      # `plot!(plt, x, y; kwargs...)` form, rather than one `(x, y)` tuple.
      plot!(plt, t_data, measurements[1, :], label = "rabbit data")
      display(plot!(plt, t_data, measurements[2, :], label = "wolves data"))
      return false
  end
  41.  
  # Initial parameter guess [α, β, δ, γ]; differs from the true `p` used to
  # generate `measurements`.
  guess = [1.15814, 1.898547, 3.26339, 0.443906]
  # Display the ODE with the initial parameter values.
  cb(guess, loss_adjoint(guess)...)

  # Stage 1: gradient-free particle-swarm search. `sciml_train` needs an
  # optimizer object such as `Optim.ParticleSwarm()`.
  # `Optim.ParticleSwarmState` is the solver's internal state type, not an
  # optimizer, and has no zero-argument constructor.
  prescan = DiffEqFlux.sciml_train(loss_adjoint, guess, Optim.ParticleSwarm(),
                                   cb = cb, maxiters = 1000)

  # Stage 2: refine the swarm's best point with gradient-based ADAM.
  refined_guess = prescan.minimizer
  cb(refined_guess, loss_adjoint(refined_guess)...)
  res = DiffEqFlux.sciml_train(negative_loss_adjoint, refined_guess, ADAM(0.1),
                               cb = cb, maxiters = 1000)

  # Final fit: show the callback view and the fitted trajectory on the data
  # grid. `display` makes the last plot appear when run as a script too.
  cb(res.minimizer, loss_adjoint(res.minimizer)...)
  display(plot(solve(remake(prob, p = res.minimizer), Tsit5(),
                     saveat = 0.0:0.1:10.0),
               ylim = (0, 6)))
# End of raw paste data.