Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
using Flux, DiffEqFlux, DifferentialEquations, Plots

# Problem setup for the reference trajectory.
# The state and the time window are both written as Float32 values.
u0 = Float32[2.0, 0.0]      # initial state (two components)
datasize = 30               # number of sample points taken along the trajectory
tspan = (0.0f0, 1.5f0)      # integration window (start, stop)
"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference system: writes `true_A' * u.^3` into `du`.

# Arguments
- `du`: output buffer, overwritten with the derivative.
- `u`: current state; each component is cubed before the linear map is applied.
- `p`, `t`: accepted for the in-place ODE signature but not read by this function.

# Returns
The mutated `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    # Same values as `((u.^3)' * true_A)'`, written as a matrix-vector product.
    du .= true_A' * cubed
    return du
end
# Sample times: `datasize` evenly spaced points covering `tspan`.
t = range(first(tspan), last(tspan); length=datasize)

# Reference problem built from the known dynamics.
prob = ODEProblem(trueODEfunc, u0, tspan)

# Training target: the reference solution saved at the sample times, as a plain Array.
ode_data = Array(solve(prob, Tsit5(); saveat=t))

# Second solve without `saveat`, used only for plotting the reference trajectory.
sol = solve(prob, Tsit5())
plot(sol)
# Input stage: cube the state element-wise, then a Dense(2, 10) layer with tanh.
l1 = Chain(x -> x .^ 3, Dense(2, 10, tanh))

# Network passed to `neural_ode` as the learned dynamics: three Dense(10, 10) tanh layers.
dudt = Chain(Dense(10, 10, tanh), Dense(10, 10, tanh), Dense(10, 10, tanh))

# Output stage: a single Dense(10, 2) layer with no explicit activation.
l3 = Chain(Dense(10, 2))
# Neural-ODE layer: passes `dudt` and the starting value `x` to `neural_ode`, solved over
# `tspan` with Tsit5, saving at the sample times `t` with the tolerances given below.
function nn_ode(x)
    return neural_ode(dudt, x, tspan, Tsit5(); saveat=t, reltol=1e-7, abstol=1e-9)
end
# Full model, applied in order: input stage -> neural ODE -> output stage.
m = Chain(l1, nn_ode, l3)
# Model prediction obtained by running the full chain `m` on the initial state `u0`.
predict_n_ode() = m(u0)
# Sum of squared differences between the reference data and the current prediction.
function loss_n_ode()
    residual = ode_data .- predict_n_ode()
    return sum(abs2, residual)
end

# The loss takes no arguments, so each "batch" is an empty tuple; 1000 repetitions.
data = Iterators.repeated((), 1000)
opt = ADAM(0.1)
# Training callback: shows the current loss, then plots the first state component of
# the current prediction against the reference data.
cb = () -> begin
    display(loss_n_ode())

    # `Flux.data` presumably strips the tracking wrapper so the values can be indexed
    # and plotted as a plain array (legacy Tracker-based Flux) — confirm for your version.
    cur_pred = Flux.data(predict_n_ode())

    # Only row 1 (the first state component) is shown.
    pl = scatter(t, ode_data[1, :]; label="data")
    scatter!(pl, t, cur_pred[1, :]; label="prediction")
    display(plot(pl))
end
# Run the callback once before training to show the fit with the initial parameters.
cb()

# Train the parameters of the input stage, the ODE network, and the output stage,
# invoking `cb` during training.
Flux.train!(loss_n_ode, params(l1, dudt, l3), data, opt; cb=cb)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement