using Zygote, ComponentArrays, Lux, SciMLSensitivity, Optimisers,
      OrdinaryDiffEq, Random, Statistics, OneHotArrays, InteractiveUtils
import MLDatasets: MNIST
import MLUtils: DataLoader, splitobs

# using LuxCUDA
# CUDA.allowscalar(false)

using LuxAMDGPU
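
# `gpu_device()` / `cpu_device()` below are Lux's device helpers; `gpu_device()`
# dispatches to whichever GPU backend package is loaded (LuxAMDGPU here; load
# LuxCUDA instead for NVIDIA hardware).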

function loadmnist(batchsize, train_split)
    # Load MNIST: only 1500 samples for demonstration purposes
    N = 1500
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
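
# Quick shape sanity check (a sketch; all names are from this script). Each
# training batch is a 28x28x1xB image array with a 10xB one-hot label matrix:
#
#   train_dataloader, _ = loadmnist(128, 0.9)
#   x, y = first(train_dataloader)
#   @assert size(x) == (28, 28, 1, 128) && size(y) == (10, 128)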

struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, Se, T, K} <:
       Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    solver::So
    sensealg::Se
    tspan::T
    kwargs::K
end

function NeuralODE(model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0),
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), kwargs...)
    return NeuralODE(model, solver, sensealg, tspan, kwargs)
end
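
# The ODE solver integrates a flat state vector, so the forward pass below
# `vec`s the input and reshapes it back to the batched array shape inside
# `dudt` before calling the inner model.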

function (n::NeuralODE)(x, ps, st)
    function dudt(u, p, t)
        u_, st = n.model(reshape(u, size(x)), p, st)
        return vec(u_)
    end
    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; sensealg=n.sensealg, n.kwargs...), st
end

@views diffeqsol_to_array(l::Int, x::ODESolution) = reshape(last(x.u), (l, :))
@views diffeqsol_to_array(l::Int, x::AbstractMatrix) = reshape(x[:, end], (l, :))
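
# Both methods pull out the state at the final time point and reshape it to
# (features, batch); the `AbstractMatrix` method covers solves that come back
# as a plain array instead of an `ODESolution`.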
function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Bool=false,
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()))
    # Construct the Neural ODE model
    model = Chain(FlattenLayer(),
        Dense(784 => 20, tanh),
        model_fn(Chain(Dense(20 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 20, tanh));
            save_everystep=false, reltol=1.0f-3, abstol=1.0f-3, save_start=false,
            sensealg),
        Base.Fix1(diffeqsol_to_array, 20),
        Dense(20 => 10))

    rng = Random.default_rng()
    Random.seed!(rng, 0)

    ps, st = Lux.setup(rng, model)
    ps = (use_named_tuple ? ps : ComponentArray(ps)) |> dev
    st = st |> dev

    return model, ps, st
end
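
# Smoke test (a sketch; pass `dev=cpu_device()` to stay off the GPU):
#
#   model, ps, st = create_model(NeuralODE; dev=cpu_device())
#   y_pred, _ = model(randn(Float32, 28, 28, 1, 4), ps, st)
#   @assert size(y_pred) == (10, 4)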

logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1))
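# i.e. cross-entropy computed directly on the logits: for one-hot `y` this is
# the batch mean of -log(softmax(y_pred)) at the true class, which is more
# numerically stable than applying `softmax` first and then taking `log`.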

function loss(x, y, model, ps, st)
    y_pred, st = model(x, ps, st)
    # Return the updated state alongside the loss so it can be threaded
    # through training
    return logitcrossentropy(y_pred, y), st
end

function accuracy(model, ps, st, dataloader; dev=gpu_device())
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    cpu_dev = cpu_device()
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
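
# Usage (sketch): `accuracy(model, ps, st, test_dataloader)` returns the
# fraction of correctly classified digits in [0, 1]. Predictions are moved back
# to the CPU before `onecold`, which avoids scalar indexing on GPU arrays.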

function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9)

    opt = Adam(0.001f0)
    st_opt = Optimisers.setup(opt, ps)

    ### Warm up the model (one forward and backward pass on a single sample)
    img = dev(train_dataloader.data[1][:, :, :, 1:1])
    lab = dev(train_dataloader.data[2][:, 1:1])
    loss(img, lab, model, ps, st)
    (l, _), back = pullback(p -> loss(img, lab, model, p, st), ps)
    back((one(l), nothing))

    ### Train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            x = dev(x)
            y = dev(y)
            (l, st), back = pullback(p -> loss(x, y, model, p, st), ps)
            ### The pullback seed needs one `nothing` per returned value beyond
            ### the loss (here the state `st`)
            gs = back((one(l), nothing))[1]
            st_opt, ps = Optimisers.update(st_opt, ps, gs)
        end
        ttime = time() - stime

        println("[$epoch/$nepochs] \t Time $(round(ttime; digits=2))s \t Training Accuracy: " *
                "$(round(accuracy(model, ps, st, train_dataloader; dev) * 100; digits=2))% \t " *
                "Test Accuracy: $(round(accuracy(model, ps, st, test_dataloader; dev) * 100; digits=2))%")
    end
end

train(NeuralODE)
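
# To run on the CPU instead (e.g. when no supported GPU is available):
# train(NeuralODE; cpu=true)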