Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using Lux, Reactant, Random, Statistics, LinearAlgebra
- # MODEL DEFINITION
# Core recurrent layer whose hidden state is advanced by a leapfrog integrator.
# `StateType` parametrizes the state element type (chosen via the `dtype` kwarg
# of the outer constructor); `T` is the eltype of the trainable parameters.
mutable struct MLLayer{StateType, T<:AbstractFloat} <: AbstractLuxLayer
    dt::T           # Time step of the leapfrog integrator
    epsilon::T      # Perturbation parameter (added to the momentum each step)
    N::Int          # Hidden dimension
    N_input::Int    # Input dimension
    batch_size::Int # Batch size the workspace buffers are sized for
end
# Return the state element type parameter of a layer.
statetype(::MLLayer{StateType, T}) where {StateType, T} = StateType
# Return the trainable-parameter element type of a layer.
paramtype(::MLLayer{StateType, T}) where {StateType, T} = T
- # Default to Float32 if dtype not specified
"""
    MLLayer(T = Float32, N = 32, N_input = 80; epsilon, dt, dtype, batch_size)

Build an `MLLayer` with parameter eltype `T` and state type parameter `dtype`
(defaults to Float32 when `dtype` is not given).
"""
function MLLayer(T::Type = Float32, N::Int = 32, N_input::Int = 80;
                 epsilon::Real = 1e-4,
                 dt::Real = 0.1,
                 dtype::Type = Float32,
                 batch_size::Int = 256)
    # Scalars are converted to the parameter eltype `T` up front; `dtype`
    # only fills the `StateType` slot of the type.
    return MLLayer{dtype, T}(T(dt), T(epsilon), N, N_input, batch_size)
end
"""
    Lux.initialparameters(rng, layer::MLLayer)

Initialize the trainable parameters of an `MLLayer`:
`W` (PSD coupling, N×N), `W_in` (input weights, N×N_input), `b` (bias, N),
`α` (per-neuron scale, N), `β` (scalar), plus the non-trainable `ϵ` and
batch size `B` carried alongside.
"""
function Lux.initialparameters(rng::AbstractRNG, layer::MLLayer)
    # BUG FIX: the previous signature `layer::MLLayer{T} where {T}` bound `T`
    # to the FIRST type parameter (the state type), not the parameter eltype,
    # so a non-default eltype produced parameters of the wrong type.
    T = paramtype(layer)
    B = layer.batch_size
    N = layer.N
    N_input = layer.N_input
    init_scale = one(T)
    ϵ = layer.epsilon
    # W: positive semi-definite matrix (N, N)
    W_temp = rand(rng, T, N, N) * (init_scale * (T(0.5))) .+ (init_scale * (T(0.5))) # entries in [0.5, 1.0]
    W = (T(0.1) * init_scale / sqrt(T(N))) * (W_temp' * W_temp)
    # W_in: input weight matrix (N, N_input), entries in [-1, 1] / sqrt(N_input)
    W_in = (rand(rng, T, N, N_input) * T(2) .- T(1)) * init_scale / sqrt(T(N_input))
    # b: bias vector (N,)
    b = ones(T, N) * (init_scale / sqrt(T(N)))
    # α: per-neuron scale (N,), entries in [0.1, 1.0]
    α = rand(rng, T, N) * (init_scale * T(0.9)) .+ (init_scale * T(0.1))
    # β: scalar coupling parameter
    β = (one(T) * init_scale) / T(300)
    return (;W, W_in, b, α, β, ϵ, B)
end
"""
    Lux.initialstates(rng, layer::MLLayer)

Return the empty state placeholder `(; ϕ = nothing, π = nothing)`.
The actual buffers are allocated later by `allocate_workspace`, once the
input's time dimension is known.
"""
function Lux.initialstates(rng::AbstractRNG, layer::MLLayer)
    # Dead locals (T, N, B) removed: nothing here depends on them.
    ϕ = nothing
    π = nothing
    return (; ϕ, π)
end
"""
    move(states, parameters, device)

Transfer `states` and `parameters` to `device`, returning them as a tuple
`(states_on_device, parameters_on_device)`.
"""
function move(states, parameters, device)
    return (device(states), device(parameters))
end
export move
# Forward pass of the preallocating path: project the whole input sequence
# through W_in with one GEMM, then run the leapfrog integrator in place.
function (layer::MLLayer)(input_sequence, params, state)
    # input_sequence layout: (N_in, T, B)
    n_in, n_steps, n_batch = size(input_sequence)
    # NOTE(review): expects `state.alloc.W_dot_us`, but `allocate_workspace`
    # builds a state with only (ϕ, π) — confirm which state layout is current.
    buffers = state.alloc
    (; W_in) = params
    # Flatten time and batch so the projection is a single matrix multiply.
    flat_inputs = reshape(input_sequence, n_in, :)
    buffers.W_dot_us .= reshape(W_in * flat_inputs, layer.N, n_steps, n_batch)
    return apply_leapfrog!(layer, state, params)
end
# Dispatch tags selecting the readout behavior of `MLModel`.
abstract type OutputMode end
struct Classification <: OutputMode end # mean over time, then linear readout
struct Regression <: OutputMode end     # alternate head (branch not implemented below)
"""
Container model: `ml_layer` → `dropout` → `output_layer`.

The second type parameter `C` carries the `OutputMode` *instance*
(`Classification()` or `Regression()`) chosen at construction; `output_mode`
recovers it for dispatch. `output_step` is deliberately excluded from the
trainable sub-layer tuple.
"""
struct MLModel{L, C, D, O} <: AbstractLuxContainerLayer{(:ml_layer, :dropout, :output_layer)}
    # FIX: the type parameter was named `Classification`, shadowing the
    # `Classification` singleton struct; renamed to `C` (no behavior change —
    # type-parameter names are not part of the interface).
    ml_layer::L
    dropout::D
    output_layer::O
    output_step::Int
end
"""
    MLModel(T = Float32, hidden_size = 32, N_input = 80, output_dim = 20; kwargs...)

Assemble an `MLModel` from an `MLLayer`, a dropout layer, and a dense
readout over the concatenated input + hidden features.
"""
function MLModel(T::Type = Float32, hidden_size::Int = 32, N_input::Int = 80, output_dim::Int = 20;
                 classification::Bool = true,
                 output_step::Int = 1,
                 epsilon::Real = 0.0001,
                 dt::Real = 0.1,
                 dropout_rate::Real = 0.05,
                 dtype::Type = Float32,
                 batch_size::Int = 256)
    core = MLLayer(T, hidden_size, N_input; epsilon, dt, dtype, batch_size)
    drop = Lux.Dropout(dropout_rate)
    # The readout consumes input and hidden features concatenated.
    head = Lux.Dense(N_input + hidden_size, output_dim)
    mode = classification ? Classification() : Regression()
    # The OutputMode instance is stored in the second type-parameter slot.
    return MLModel{typeof(core), mode, typeof(drop), typeof(head)}(core, drop, head, output_step)
end
# Make MLModel callable; delegates to the custom Lux.apply defined below.
(m::MLModel)(x, p, s) = Lux.apply(m, x, p, s)
# Recover the OutputMode instance stored in the second type-parameter slot.
output_mode(::MLModel{L, C}) where {L, C} = C
"""
    Lux.apply(model::MLModel, u_batch, parameters, state)

Run the full forward pass and return `(output, updated_state)`.
`u_batch` has layout (N_in, T, B).
"""
function Lux.apply(model::MLModel, u_batch, parameters, state)
    (;ml_layer, dropout, output_layer) = model
    layer_parameters = parameters.ml_layer
    layer_state = state.ml_layer
    layer_state = ml_layer(u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1) # Concatenate input and hidden along feature dim (N_in + N, T, B)
    output_state = state.output_layer
    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)     # (N_in + N, 1, B)
        y = dropdims(y; dims=2) # (N_in + N, 1, B) -> (N_in + N, B)
        # BUG FIX: Lux layers return (output, state); the tuple was previously
        # assigned to `y` wholesale, so callers received ((out, st), state).
        y, output_state = output_layer(y, parameters.output_layer, state.output_layer) # (output_dim, B)
    else
        # CODE — regression head not implemented; `y` passes through unchanged
    end
    # Carry every updated sub-state forward (the old code kept stale ml_layer
    # and output_layer states).
    updated_state = (; state..., ml_layer = layer_state, dropout = dropout_state, output_layer = output_state)
    return y, updated_state
end
"""
    allocate_workspace(model::MLModel, state, input)

Replace the placeholder ML-layer state with preallocated buffers sized to
`input`: the trajectory `ϕ` of shape (N, T, B) and momentum `π` of shape (N, B).
"""
function allocate_workspace(model::MLModel, state, input)
    layer = model.ml_layer
    T = paramtype(layer)
    B = layer.batch_size
    # Dead local `N = layer.N` removed — `layer.N` is used directly below.
    time_size = size(input, 2)
    ϕ = zeros(T, layer.N, time_size, B)
    π = zeros(T, layer.N, B)
    _ml_layer = (; ϕ, π)
    # NOTE(review): the preallocating MLLayer forward also reads
    # `state.alloc.W_dot_us`, which this workspace does not provide — confirm
    # which forward path is current.
    return (; state..., ml_layer = _ml_layer)
end
- # FORWARD PASS DEFINITION
"""
    apply_no_prealloc(model, u_batch, parameters, state)

Allocation-per-call variant of the forward pass (used for Reactant tracing).
Returns `(output, updated_state)`; `u_batch` has layout (N_in, T, B).
"""
function apply_no_prealloc(model, u_batch, parameters, state)
    (;ml_layer, dropout, output_layer) = model
    layer_parameters = parameters.ml_layer
    layer_state = state.ml_layer
    layer_state = layer_no_prealloc(ml_layer, u_batch, layer_parameters, layer_state)
    y, dropout_state = dropout(layer_state.ϕ, parameters.dropout, state.dropout)
    y = cat(u_batch, y; dims=1) # Concatenate input and hidden along feature dim (N_in + N, T, B)
    output_state = state.output_layer
    if output_mode(model) isa Classification
        # Classification mode: mean over time, then linear layer
        y = mean(y; dims=2)     # (N_in + N, 1, B)
        y = dropdims(y; dims=2) # (N_in + N, 1, B) -> (N_in + N, B)
        # BUG FIX: Lux layers return (output, state); destructure instead of
        # assigning the tuple to `y`.
        y, output_state = output_layer(y, parameters.output_layer, state.output_layer) # (output_dim, B)
    else
        # CODE — regression head not implemented; `y` passes through unchanged
    end
    # Carry every updated sub-state forward.
    updated_state = (; state..., ml_layer = layer_state, dropout = dropout_state, output_layer = output_state)
    return y, updated_state
end
# Non-preallocating layer forward: one GEMM projects the whole sequence
# through W_in, then the leapfrog loop is traced and executed.
function layer_no_prealloc(layer, input_sequence, params, state)
    # input_sequence layout: (N_in, T, B)
    n_in, n_steps, n_batch = size(input_sequence)
    (; W_in) = params
    # Flatten time and batch for a single matrix multiply, then restore layout.
    projected = W_in * reshape(input_sequence, n_in, :)
    W_dot_us = reshape(projected, layer.N, n_steps, n_batch)
    return @trace apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
end
# Advance the leapfrog integrator once per timestep (second dimension of
# W_dot_us), mutating `state` in place.
function apply_leapfrog_no_prealloc!(layer, state, params, W_dot_us)
    for step in axes(W_dot_us, 2)
        @inline leapfrog_time_step_no_prealloc!(layer, state, params, step, W_dot_us)
    end
    return state
end
# One leapfrog step at time index `t`: half-step in ϕ, full step in π using
# the potential gradient at the midpoint, then the second half-step in ϕ.
function leapfrog_time_step_no_prealloc!(layer, state, params, t, W_dot_us)
    (; ϕ, π) = state
    (; W, b, α, β, ϵ) = params
    dt = layer.dt
    half_dt = dt / 2
    drive = @view W_dot_us[:, t, :]
    ϕ_t = @view ϕ[:, t, :]
    # First half-step of the position update
    ϕ_half = ϕ_t .+ half_dt .* π
    # Full momentum update from the potential gradient at the midpoint
    grad = grad_state_potential_hamiltonian_no_prealloc(ϕ_half, drive, W, b, α, β)
    π .-= dt .* grad
    # Second half-step, written back into the trajectory buffer
    ϕ_t .= ϕ_half .+ half_dt .* π
    # NOTE(review): constant ϵ is added to the momentum every step — confirm
    # this perturbation is intended rather than e.g. noise.
    π .+= ϵ
    return state
end
"""
    grad_state_potential_hamiltonian_no_prealloc(ϕ, W_in_u, W, b, alpha, beta)

Gradient of the state potential w.r.t. ϕ: quadratic term `α ϕ`, the
symmetrized `W`/`Wᵀ` coupling through `tanh(βϕ)`, plus bias and input terms
through `tanh(ϕ)`. Returns a new array of the same shape as `ϕ`.
"""
function grad_state_potential_hamiltonian_no_prealloc(ϕ::AbstractArray{T}, W_in_u, W, b, alpha, beta) where T
    tβϕ = @. tanh(beta * ϕ)
    # Gradient of norm term: ∂/∂ϕ [1/2 * α * ϕ²] = α * ϕ
    state_grad = @. alpha * ϕ
    sech2β = @. T(1) - tβϕ^2 # sech²(βϕ)
    W_tanh_ϕ = W*tβϕ
    WT_tanh_ϕ = W'*tβϕ
    # Coupling term, symmetrized over W and Wᵀ
    state_grad += @. (T(1) / beta) * T(0.5) * sech2β * (W_tanh_ϕ + WT_tanh_ϕ)
    # PERF: sech²(ϕ) was computed twice (once for the bias term, once for the
    # input term); hoist it so tanh(ϕ) is evaluated only once. The two terms
    # are still added separately, so results are bit-identical.
    sech2ϕ = @. T(1) - tanh(ϕ)^2
    # Gradient of bias term: ∂/∂ϕ [b^T tanh(ϕ)] = b * sech²(ϕ)
    state_grad += @. b * sech2ϕ
    # Gradient of input term: ∂/∂ϕ [tanh(ϕ)^T W_in_u] = W_in_u * sech²(ϕ)
    state_grad += @. W_in_u * sech2ϕ
    return state_grad
end
# DATA SETUP AND COMPILE
# Synthetic input batch of shape (N_input, timesteps, batch) = (80, 339, 256),
# matching the MLModel defaults.
x_data = rand(Float32, 80,339,256)
model = MLModel()
parameters, state = Lux.setup(MersenneTwister(), model)
# Inference mode: disables dropout.
state = Lux.testmode(state)
dev = reactant_device()
# Preallocate the (ϕ, π) leapfrog buffers sized to this input's time dimension.
state = allocate_workspace(model, state, x_data)
# Move model data to the Reactant device before tracing.
parameters = parameters |> dev
state = state |> dev
x_data = x_data |> dev
init_time = time()
# Trace and compile the non-preallocating forward pass; `f` is the compiled callable.
f = @compile apply_no_prealloc(model, x_data, parameters, state)
println("Compiled Reactant apply in $(time() - init_time) seconds.")
Advertisement
Add Comment
Please, Sign In to add comment