Advertisement
Guest User

Untitled

a guest
Dec 10th, 2019
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.24 KB | None | 0 0
  1. using Flux, DiffEqFlux, DifferentialEquations, Plots
  2.  
  3. u0 = Float32[2.; 0.]
  4. datasize = 30
  5. tspan = (0.0f0,1.5f0)
  6.  
  7. function trueODEfunc(du,u,p,t)
  8. true_A = [-0.1 2.0; -2.0 -0.1]
  9. du .= ((u.^3)'true_A)'
  10. end
  11. t = range(tspan[1],tspan[2],length=datasize)
  12. prob = ODEProblem(trueODEfunc,u0,tspan)
  13. ode_data = Array(solve(prob,Tsit5(),saveat=t))
  14.  
  15. sol = solve(prob, Tsit5())
  16. plot(sol)
  17.  
  18. l1 = Chain(
  19. x -> x.^3,
  20. Dense(2, 10, tanh)
  21. )
  22.  
  23. dudt = Chain(
  24. Dense(10, 10, tanh),
  25. Dense(10, 10, tanh),
  26. Dense(10, 10, tanh)
  27. )
  28.  
  29. l3 = Chain(
  30. Dense(10, 2)
  31. )
  32.  
  33. nn_ode(x) = neural_ode(dudt,x,tspan,Tsit5(),saveat=t,reltol=1e-7,abstol=1e-9)
  34.  
  35. m = Chain(
  36. l1,
  37. nn_ode,
  38. l3
  39. )
  40.  
  41. function predict_n_ode()
  42. m(u0)
  43. end
  44. loss_n_ode() = sum(abs2,ode_data .- predict_n_ode())
  45.  
  46. data = Iterators.repeated((), 1000)
  47. opt = ADAM(0.1)
  48. cb = function () #callback function to observe training
  49. display(loss_n_ode())
  50. # plot current prediction against data
  51. cur_pred = Flux.data(predict_n_ode())
  52. pl = scatter(t,ode_data[1,:],label="data")
  53. scatter!(pl,t,cur_pred[1,:],label="prediction")
  54. display(plot(pl))
  55. end
  56.  
  57. # Display the ODE with the initial parameter values.
  58. cb()
  59.  
  60. Flux.train!(loss_n_ode, params(l1,dudt,l3), data, opt, cb = cb)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement