Reactant Example

using Lux, Reactant, Random, Statistics, LinearAlgebra
# MODEL DEFINITION

mutable struct MLLayer{StateType, T<:AbstractFloat} <: AbstractLuxLayer
    dt::T                     # Time step
    epsilon::T                # Perturbation parameter
    N::Int                    # Hidden dimension
    N_input::Int              # Input dimension
    batch_size::Int           # Batch size
end

statetype(::MLLayer{StateType, T}) where {StateType, T} = StateType
paramtype(::MLLayer{StateType, T}) where {StateType, T} = T

# Default to Float32 if dtype is not specified
function MLLayer(T::Type = Float32, N::Int = 32, N_input::Int = 80;
                 epsilon::Real=1e-4,
                 dt::Real=0.1,
                 dtype::Type=Float32,
                 batch_size::Int=256)
    return MLLayer{dtype, T}(dt, epsilon, N, N_input, batch_size)
end

function Lux.initialparameters(rng::AbstractRNG, layer::MLLayer{<:Any, T}) where {T}
    B = layer.batch_size
    N = layer.N
    N_input = layer.N_input
    init_scale = one(T)
    ϵ = layer.epsilon

    # W: positive semi-definite matrix (N, N) - matches Python W
    W_temp = rand(rng, T, N, N) * (init_scale * T(0.5)) .+ (init_scale * T(0.5))  # entries between 0.5 and 1.0
    W = (T(0.1) * init_scale / sqrt(T(N))) * (W_temp' * W_temp)

    # W_in: input weight matrix (N, N_input)
    W_in = (rand(rng, T, N, N_input) * T(2) .- T(1)) * init_scale / sqrt(T(N_input))

    # b: bias vector (N,)
    b = ones(T, N) * (init_scale / sqrt(T(N)))

    # α: per-neuron alpha (N,)
    α = rand(rng, T, N) * (init_scale * T(0.9)) .+ (init_scale * T(0.1))  # entries between 0.1 and 1.0

    # β: scalar parameter
    β = (one(T) * init_scale) / T(300)

    return (; W, W_in, b, α, β, ϵ, B)
end

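# Optional sanity check (not in the original paste): the parameter shapes for the default
# layer should match the comments above; the RNG seed here is arbitrary.
let ps = Lux.initialparameters(MersenneTwister(0), MLLayer())
    @assert size(ps.W) == (32, 32)      # (N, N), positive semi-definite
    @assert size(ps.W_in) == (32, 80)   # (N, N_input)
    @assert size(ps.b) == (32,) && size(ps.α) == (32,)
end
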
function Lux.initialstates(rng::AbstractRNG, layer::MLLayer)
    # State = (ϕ, π) for each batch; the actual arrays are allocated later by `allocate_workspace`
    ϕ = nothing
    π = nothing
    return (; ϕ, π)
end


# Move states and parameters to a device (e.g. the Reactant device)
move(states, parameters, device) = states |> device, parameters |> device
export move

# Preallocated forward path. NOTE: this expects a workspace with `alloc.W_dot_us` and an
# `apply_leapfrog!` that are not included in this paste; the compiled path below uses the
# `_no_prealloc` variants instead.
function (layer::MLLayer)(input_sequence, params, state)
    # input is (N_in, T, B)
    N_in, DT, B = size(input_sequence)
    alloc = state.alloc

    (; W_in) = params

    # Apply the input layer to all timesteps and batches:
    # flatten time and batch into one GEMM, then reshape back to (N, DT, B)
    alloc.W_dot_us .= reshape(W_in * reshape(input_sequence, N_in, :), layer.N, DT, B)

    return apply_leapfrog!(layer, state, params)
end

abstract type OutputMode end
struct Classification <: OutputMode end
struct Regression <: OutputMode end

# `C` is an OutputMode singleton (Classification() or Regression()) stored as a value-type
# parameter, so the output branch below can be resolved at compile time.
struct MLModel{L, C, D, O} <: AbstractLuxContainerLayer{(:ml_layer, :dropout, :output_layer)}
    ml_layer::L
    dropout::D
    output_layer::O
    output_step::Int
end


function MLModel(T::Type = Float32, hidden_size::Int = 32, N_input::Int = 80, output_dim::Int = 20;
                 classification::Bool=true,
                 output_step::Int=1,
                 epsilon::Real=0.0001,
                 dt::Real=0.1,
                 dropout_rate::Real=0.05,
                 dtype::Type=Float32,
                 batch_size::Int=256)

    ml_layer = MLLayer(T, hidden_size, N_input;
                       epsilon,
                       dt,
                       dtype,
                       batch_size)

    dropout = Lux.Dropout(dropout_rate)

    # Output layer takes concatenated input + hidden
    output_layer = Lux.Dense(N_input + hidden_size, output_dim)
    c_mode = classification ? Classification() : Regression()
    return MLModel{typeof(ml_layer), c_mode, typeof(dropout), typeof(output_layer)}(ml_layer, dropout, output_layer, output_step)
end

(m::MLModel)(x, p, s) = Lux.apply(m, x, p, s)

output_mode(::MLModel{L, C}) where {L, C} = C

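# Illustration (not in the original paste): the OutputMode singleton lives in the model's
# type, so `output_mode` recovers it without any runtime flag.
let m_cls = MLModel(; classification=true), m_reg = MLModel(; classification=false)
    @assert output_mode(m_cls) isa Classification
    @assert output_mode(m_reg) isa Regression
end
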
function Lux.apply(model::MLModel, u_batch, parameters, state)
    (; ml_layer, dropout, output_layer) = model
    layer_parameters = parameters.ml_layer
    layer_state = state.ml_layer

    layer_state = ml_layer(u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1)  # Concatenate input and hidden along the feature dim: (N_in + N, T, B)

    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)       # (N_in + N, 1, B)
        y = dropdims(y; dims=2)   # (N_in + N, B)
        y, output_state = output_layer(y, parameters.output_layer, state.output_layer)  # (output_dim, B)
    else
        # Regression path not included in this paste
        output_state = state.output_layer
    end

    updated_state = (; state..., ml_layer = layer_state, dropout = dropout_state, output_layer = output_state)
    return y, updated_state
end

function allocate_workspace(model::MLModel, state, input)
    layer = model.ml_layer
    T = paramtype(layer)
    N = layer.N
    B = layer.batch_size

    time_size = size(input, 2)
    ϕ = zeros(T, N, time_size, B)   # hidden trajectory (N, T, B)
    π = zeros(T, N, B)              # conjugate momentum (N, B)
    _ml_layer = (; ϕ, π)
    return (; state..., ml_layer = _ml_layer)
end

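# Illustrative check (not in the original paste): the workspace holds the hidden
# trajectory ϕ with shape (N, T, B) and the momentum π with shape (N, B).
let ws = allocate_workspace(MLModel(), (;), rand(Float32, 80, 10, 256))
    @assert size(ws.ml_layer.ϕ) == (32, 10, 256)
    @assert size(ws.ml_layer.π) == (32, 256)
end
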
# FORWARD PASS DEFINITION
function apply_no_prealloc(model, u_batch, parameters, state)
    (; ml_layer, dropout, output_layer) = model
    layer_parameters = parameters.ml_layer
    layer_state = state.ml_layer

    layer_state = layer_no_prealloc(ml_layer, u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1)  # Concatenate input and hidden along the feature dim: (N_in + N, T, B)

    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)       # (N_in + N, 1, B)
        y = dropdims(y; dims=2)   # (N_in + N, B)
        y, output_state = output_layer(y, parameters.output_layer, state.output_layer)  # (output_dim, B)
    else
        # Regression path not included in this paste
        output_state = state.output_layer
    end

    updated_state = (; state..., ml_layer = layer_state, dropout = dropout_state, output_layer = output_state)
    return y, updated_state
end

function layer_no_prealloc(layer, input_sequence, params, state)
    # input is (N_in, T, B)
    N_in, DT, B = size(input_sequence)

    (; W_in) = params

    # Apply the input layer to all timesteps and batches:
    # flatten time and batch into one GEMM, then reshape back to (N, DT, B)
    W_dot_us = reshape(W_in * reshape(input_sequence, N_in, :), layer.N, DT, B)

    return @trace apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
end

function apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
    _, Tlen, _ = size(W_dot_us)
    for t in 1:Tlen
        @inline leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    end
    return state
end

function leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    (; ϕ, π) = state
    (; W, b, α, β, ϵ) = params
    dt = layer.dt

    half_dt = dt / 2

    W_in_u = @view W_dot_us[:, t, :]

    ϕ_t = @view ϕ[:, t, :]

    # Leapfrog update: half drift, full kick, half drift, then a small ϵ perturbation of π
    ϕhalf = ϕ_t .+ half_dt .* π

    grad = grad_state_potential_hamiltonian_no_prealloc(ϕhalf, W_in_u, W, b, α, β)
    π .-= dt .* grad

    ϕ_t .= ϕhalf .+ half_dt .* π
    π .+= ϵ

    return state
end

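# For reference (not in the original paste): the step above is the standard drift-kick-drift
# leapfrog scheme, with the extra `π .+= ϵ` perturbation applied on top. A minimal scalar
# sketch with a toy quadratic potential V(ϕ) = ϕ²/2 (so ∇V(ϕ) = ϕ); `toy_leapfrog_step`
# is purely illustrative and not part of the model.
function toy_leapfrog_step(ϕ, π, dt)
    ϕhalf = ϕ + (dt / 2) * π    # half drift
    π = π - dt * ϕhalf          # full kick with ∇V(ϕhalf)
    ϕ = ϕhalf + (dt / 2) * π    # half drift
    return ϕ, π
end
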
function grad_state_potential_hamiltonian_no_prealloc(ϕ::AbstractArray{T}, W_in_u, W, b, alpha, beta) where T
    # Gradient of norm term: ∂/∂ϕ [1/2 * α * ϕ²] = α * ϕ

    tβϕ = @. tanh(beta * ϕ)
    state_grad = @. alpha * ϕ

    sech2β = @. T(1) - tβϕ^2  # sech²(βϕ)

    W_tanh_ϕ = W * tβϕ
    WT_tanh_ϕ = W' * tβϕ

    # Gradient of the coupling term through tanh(βϕ)
    state_grad += @. (T(1) / beta) * T(0.5) * sech2β * (W_tanh_ϕ + WT_tanh_ϕ)

    # Gradient of bias term: ∂/∂ϕ [b^T tanh(ϕ)] = b * sech²(ϕ)
    state_grad += @. b * (T(1) - tanh(ϕ)^2)

    # Gradient of input term: ∂/∂ϕ [tanh(ϕ)^T W_in_u] = W_in_u * sech²(ϕ)
    state_grad += @. W_in_u * (T(1) - tanh(ϕ)^2)
    return state_grad
end

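# Illustrative shape check (not in the original paste): the gradient has the same (N, B)
# shape as the half-step state it is evaluated at; all values here are arbitrary.
let N = 4, B = 2
    ϕ = randn(Float32, N, B); u = randn(Float32, N, B)
    W = randn(Float32, N, N); b = ones(Float32, N)
    α = ones(Float32, N); β = 1f-2
    g = grad_state_potential_hamiltonian_no_prealloc(ϕ, u, W, b, α, β)
    @assert size(g) == (N, B)
end
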
# DATA SETUP AND COMPILE
x_data = rand(Float32, 80, 339, 256)   # (N_input, T, batch_size)
model = MLModel()
parameters, state = Lux.setup(MersenneTwister(), model)
state = Lux.testmode(state)
dev = reactant_device()
state = allocate_workspace(model, state, x_data)
parameters = parameters |> dev
state = state |> dev
x_data = x_data |> dev


init_time = time()
f = @compile apply_no_prealloc(model, x_data, parameters, state)
println("Compiled Reactant apply in $(time() - init_time) seconds.")
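
# Usage sketch (not in the original paste): the compiled function is called with the same
# arguments as the traced call; with the classification branch above, the logits come out
# as (output_dim, batch_size) = (20, 256).
y_out, state_out = f(model, x_data, parameters, state)
println("Output size: ", size(y_out))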