Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Import necessary libraries
- import torch
- from torch.utils.data import DataLoader
- from transformers import CLIPVisionModel, AutoModelForCausalLM
- from peft import LoraConfig, get_peft_model
# 1. Model definition
class CarLLaVA(torch.nn.Module):
    """Driving model: frozen CLIP vision encoder + LoRA-adapted causal LLM.

    The camera image is encoded by a frozen CLIP vision tower and projected
    into the LLM embedding space. It is concatenated with one speed token,
    one token per target point and two sets of learnable query tokens. The
    sequence is fed to the causal LLM as ``inputs_embeds``. The LLM's
    final-layer hidden states at the query positions are decoded into 2-D
    path points and waypoints by two linear heads.

    All projection widths are read from the loaded models' configs instead of
    being hard-coded, so any CLIP vision tower / causal LM pair can be used.

    Args:
        vision_model_name: Hugging Face id of the CLIP vision encoder.
        llm_model_name: Hugging Face id of the causal language model.
        num_path_queries: Number of path queries (= path points predicted).
        num_wp_queries: Number of waypoint queries (= waypoints predicted).
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
    """

    def __init__(
        self,
        vision_model_name="openai/clip-vit-large-patch14-336",
        llm_model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        num_path_queries=10,
        num_wp_queries=20,
        lora_r=8,
        lora_alpha=32,
    ):
        super().__init__()
        # 1. Vision encoder (CLIP ViT), fully frozen.
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        self.vision_encoder.requires_grad_(False)
        # Feature width of the vision tower (e.g. 1024 for CLIP ViT-L/14, not 768).
        vision_dim = self.vision_encoder.config.hidden_size

        # 2. Decoder LLM with LoRA adapters on the attention q/v projections.
        # The output heads are attributes of this wrapper, not of the LLM, so
        # they are trainable already and must not go into modules_to_save.
        llm = AutoModelForCausalLM.from_pretrained(llm_model_name)
        # LLM embedding width (e.g. 2048 for TinyLlama-1.1B, not 4096).
        hidden = llm.config.hidden_size
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj"],
        )
        self.llm = get_peft_model(llm, lora_config)

        # 3. Projection of visual features into the LLM embedding space.
        self.vision_proj = torch.nn.Linear(vision_dim, hidden)

        # 4. Two-layer MLPs embedding the scalar speed and each (x, y) target.
        self.speed_mlp = torch.nn.Sequential(
            torch.nn.Linear(1, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
        )
        self.target_mlp = torch.nn.Sequential(
            torch.nn.Linear(2, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
        )

        # 5. Learnable query tokens; their output states feed the heads.
        self.path_queries = torch.nn.Parameter(torch.randn(num_path_queries, hidden))
        self.wp_queries = torch.nn.Parameter(torch.randn(num_wp_queries, hidden))

        # 6. Output heads: one 2-D coordinate per query.
        self.waypoint_head = torch.nn.Linear(hidden, 2)
        self.path_head = torch.nn.Linear(hidden, 2)

    def forward(self, image, speed, target_points):
        """Predict path points and waypoints for a batch.

        Args:
            image: Pixel tensor for the CLIP vision encoder, (B, 3, H, W).
            speed: One speed value per sample, shaped (B,) or (B, 1).
            target_points: (x, y) target points per sample, shaped (B, T, 2).
                (T, 2) is also accepted when B == 1.

        Returns:
            ``(path_pred, wp_pred)`` shaped (B, num_path_queries, 2) and
            (B, num_wp_queries, 2).
        """
        batch = image.size(0)
        dtype = self.vision_proj.weight.dtype

        # 1. Visual tokens projected into the LLM embedding space: (B, N, D).
        vision_tokens = self.vision_proj(
            self.vision_encoder(pixel_values=image).last_hidden_state
        )

        # 2. One speed token (B, 1, D) and one token per target point (B, T, D).
        #    reshape (unlike view) also accepts non-contiguous inputs.
        speed_tokens = self.speed_mlp(speed.reshape(batch, 1, 1).to(dtype))
        target_tokens = self.target_mlp(target_points.reshape(batch, -1, 2).to(dtype))

        # 3. Queries are placed last so their positions can be sliced off the end.
        path_q = self.path_queries.unsqueeze(0).expand(batch, -1, -1)
        wp_q = self.wp_queries.unsqueeze(0).expand(batch, -1, -1)
        inputs = torch.cat(
            [vision_tokens, speed_tokens, target_tokens, path_q, wp_q], dim=1
        )  # (B, N + 1 + T + P + W, D)

        # 4. A causal-LM output exposes logits/hidden_states, not
        #    last_hidden_state: request all hidden states and take the last one.
        outputs = self.llm(inputs_embeds=inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # (B, S, D)

        # 5. Decode the query positions (explicit indices also work for 0 queries).
        wp_start = hidden.size(1) - self.wp_queries.size(0)
        path_start = wp_start - self.path_queries.size(0)
        path_pred = self.path_head(hidden[:, path_start:wp_start, :])
        wp_pred = self.waypoint_head(hidden[:, wp_start:, :])
        return path_pred, wp_pred
# 2. Training
def train(
    model,
    train_loader,
    optimizer,
    scheduler,
    num_epochs=30,
    wp_weight=0.7,
    path_weight=0.3,
):
    """Train ``model`` with a weighted MSE loss on waypoints and path points.

    Each batch yielded by ``train_loader`` must unpack as
    ``(image, speed, target_points, gt_path, gt_waypoints)``. The model is
    called as ``model(image, speed, target_points)`` and must return
    ``(path_pred, wp_pred)`` with the same shapes as ``gt_path`` /
    ``gt_waypoints``. ``scheduler.step()`` is called once per epoch.

    Args:
        model: Module to train; it is put into training mode.
        train_loader: Iterable of batches as described above.
        optimizer: Optimizer over the model's trainable parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        wp_weight: Weight of the waypoint MSE term.
        path_weight: Weight of the path MSE term.

    Returns:
        List with the mean total loss of each epoch. An epoch in which the
        loader yielded no batches records ``nan``.
    """
    criterion = torch.nn.MSELoss()  # stateless, so build it once
    model.train()
    history = []
    for epoch in range(num_epochs):
        loss_sum = 0.0
        n_batches = 0
        for image, speed, target_points, gt_path, gt_waypoints in train_loader:
            path_pred, wp_pred = model(image, speed, target_points)
            total_loss = (
                wp_weight * criterion(wp_pred, gt_waypoints)
                + path_weight * criterion(path_pred, gt_path)
            )

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            loss_sum += total_loss.item()
            n_batches += 1

        scheduler.step()

        # Report the epoch mean, not just the last batch. Guard against an
        # empty loader, which previously hit an unbound name here.
        epoch_loss = loss_sum / n_batches if n_batches else float("nan")
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Mean loss: {epoch_loss}")
    return history
# 3. Inference
def infer(model, image, speed, target_points):
    """Run ``model`` once in eval mode without gradient tracking.

    The model's previous train/eval mode is restored afterwards, even if the
    forward pass raises. Calling this in the middle of training therefore
    does not silently leave the model in eval mode.

    Args:
        model: Module called as ``model(image, speed, target_points)`` that
            returns a ``(path_pred, wp_pred)`` pair.
        image, speed, target_points: Inputs forwarded to ``model`` unchanged.

    Returns:
        The ``(path_pred, wp_pred)`` pair produced by ``model``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            path_pred, wp_pred = model(image, speed, target_points)
    finally:
        model.train(was_training)
    return path_pred, wp_pred
# Training configuration
def main():
    """Train CarLLaVA on a synthetic dataset, save it, and run one inference.

    The random tensors stand in for a real driving dataset. Replace
    ``dataset`` with your own ``torch.utils.data.Dataset`` yielding
    ``(image, speed, target_points, gt_path, gt_waypoints)`` with the same
    per-sample shapes.
    """
    num_epochs = 30
    num_samples = 32

    # 1. Synthetic dataset. Per-sample shapes:
    #    image (3, 336, 336), speed (1,), 2 target points (2, 2),
    #    10 path points (10, 2), 20 waypoints (20, 2).
    dataset = torch.utils.data.TensorDataset(
        torch.randn(num_samples, 3, 336, 336),
        torch.rand(num_samples, 1) * 100.0,
        torch.randn(num_samples, 2, 2),
        torch.randn(num_samples, 10, 2),
        torch.randn(num_samples, 20, 2),
    )
    train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    # 2. Model.
    model = CarLLaVA()

    # 3. Optimize only the trainable parameters. The vision encoder is frozen
    #    and the LLM is wrapped with LoRA, so many parameters have
    #    requires_grad=False.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=3e-5, weight_decay=0.1)
    # train() steps the scheduler once per epoch, so one cosine period spans
    # num_epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # 4. Train.
    train(model, train_loader, optimizer, scheduler, num_epochs=num_epochs)

    # 5. Save the trained weights.
    torch.save(model.state_dict(), "carllava.pth")

    # 6. Inference on a single sample (batch size 1, same layout as training).
    image = torch.randn(1, 3, 336, 336)
    speed = torch.tensor([[50.0]])  # (1, 1): one speed value for the sample
    target_points = torch.tensor([[[10.0, 20.0], [30.0, 40.0]]])  # (1, 2, 2)
    path_pred, wp_pred = infer(model, image, speed, target_points)
    print("Predicted path:", path_pred)
    print("Predicted waypoints:", wp_pred)
# Run the train -> save -> infer demo only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement