Advertisement
Guest User

CarLLAVA TEST

a guest
Jan 27th, 2025
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.03 KB | Source Code | 0 0
  1. # Import necessary libraries
  2. import torch
  3. from torch.utils.data import DataLoader
  4. from transformers import CLIPVisionModel, AutoModelForCausalLM
  5. from peft import LoraConfig, get_peft_model
  6.  
  7. # 1. Model Definition
  8. class CarLLaVA(torch.nn.Module):
  9.     def __init__(self):
  10.         super().__init__()
  11.        
  12.         # 1. Vision Encoder (CLIPViT)
  13.         self.vision_encoder = CLIPVisionModel.from_pretrained("llava-hf/CLIP-ViT-L-336px")
  14.         for param in self.vision_encoder.parameters():
  15.             param.requires_grad = False  # Freeze the vision encoder
  16.        
  17.         # 2. Decoder (LLaMA with LoRA)
  18.         self.llm = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B")
  19.         lora_config = LoraConfig(
  20.             r=8, lora_alpha=32,
  21.             target_modules=["q_proj", "v_proj"],  # Layers to adapt
  22.             modules_to_save=["waypoint_head", "path_head"]  # Output heads
  23.         )
  24.         self.llm = get_peft_model(self.llm, lora_config)
  25.        
  26.         # 3. Projection layers
  27.         self.vision_proj = torch.nn.Linear(768, 4096)  # Project visual features
  28.        
  29.         # 4. Two-layer MLPs for speed and target points
  30.         self.speed_mlp = torch.nn.Sequential(
  31.             torch.nn.Linear(1, 4096),  # First layer
  32.             torch.nn.ReLU(),           # Non-linear activation
  33.             torch.nn.Linear(4096, 4096)  # Second layer
  34.         )
  35.        
  36.         self.target_mlp = torch.nn.Sequential(
  37.             torch.nn.Linear(2, 4096),  # First layer
  38.             torch.nn.ReLU(),           # Non-linear activation
  39.             torch.nn.Linear(4096, 4096)  # Second layer
  40.         )
  41.        
  42.         # 5. Learnable queries
  43.         self.path_queries = torch.nn.Parameter(torch.randn(10, 4096))  # 10 path queries
  44.         self.wp_queries = torch.nn.Parameter(torch.randn(20, 4096))    # 20 waypoint queries
  45.        
  46.         # 6. Output heads
  47.         self.waypoint_head = torch.nn.Linear(4096, 2)  # Predicts Δx, Δy for waypoints
  48.         self.path_head = torch.nn.Linear(4096, 2)      # Predicts path coordinates
  49.  
  50.     def forward(self, image, speed, target_points):
  51.         # 1. Extract visual features
  52.         vision_features = self.vision_encoder(image).last_hidden_state  # (B, N, 768)
  53.         vision_features = self.vision_proj(vision_features)             # (B, N, 4096)
  54.        
  55.         # 2. Encode speed and target points using two-layer MLPs
  56.         speed_emb = self.speed_mlp(speed.unsqueeze(-1))        # (B, 1, 4096)
  57.         target_emb = self.target_mlp(target_points.view(-1, 2))# (B*2, 4096)
  58.         target_emb = target_emb.view(-1, 2, 4096)              # (B, 2, 4096)
  59.        
  60.         # 3. Concatenate everything + queries
  61.         inputs = torch.cat([
  62.             vision_features,
  63.             speed_emb,
  64.             target_emb,
  65.             self.path_queries.unsqueeze(0).repeat(image.size(0), 1, 1),  # (B, 10, 4096)
  66.             self.wp_queries.unsqueeze(0).repeat(image.size(0), 1, 1)     # (B, 20, 4096)
  67.         ], dim=1)  # (B, N+1+2+10+20, 4096)
  68.        
  69.         # 4. Pass through the LLM
  70.         outputs = self.llm(inputs_embeds=inputs).last_hidden_state  # (B, seq_len, 4096)
  71.        
  72.         # 5. Extract outputs
  73.         path_pred = self.path_head(outputs[:, -30:-20, :])  # (B, 10, 2)
  74.         wp_pred = self.waypoint_head(outputs[:, -20:, :])   # (B, 20, 2)
  75.        
  76.         return path_pred, wp_pred
  77.  
  78. # 2. Training
  79. def train(model, train_loader, optimizer, scheduler, num_epochs=30):
  80.     model.train()  # Set the model to training mode
  81.    
  82.     for epoch in range(num_epochs):
  83.         for batch in train_loader:
  84.             # 1. Retrieve batch data
  85.             image, speed, target_points, gt_path, gt_waypoints = batch
  86.            
  87.             # 2. Forward pass: Predict path and waypoints
  88.             path_pred, wp_pred = model(image, speed, target_points)
  89.            
  90.             # 3. Calculate loss
  91.             loss_path = torch.nn.MSELoss()(path_pred, gt_path)      # Loss for path
  92.             loss_wp = torch.nn.MSELoss()(wp_pred, gt_waypoints)     # Loss for waypoints
  93.             total_loss = 0.7 * loss_wp + 0.3 * loss_path            # Combine losses
  94.            
  95.             # 4. Backpropagation
  96.             optimizer.zero_grad()  # Reset gradients
  97.             total_loss.backward()  # Compute gradients
  98.             optimizer.step()       # Update weights
  99.            
  100.             # 5. Update learning rate
  101.             scheduler.step()
  102.        
  103.         # 6. Logs and evaluation
  104.         print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss.item()}")
  105.  
  106. # 3. Inference
  107. def infer(model, image, speed, target_points):
  108.     model.eval()  # Set the model to evaluation mode
  109.    
  110.     with torch.no_grad():  # Disable gradient calculation
  111.         # 1. Predict path and waypoints
  112.         path_pred, wp_pred = model(image, speed, target_points)
  113.        
  114.         # 2. Return predictions
  115.         return path_pred, wp_pred
  116.  
  117. # Training configuration
  118. def main():
  119.     # 1. Load dataset (simulated)
  120.     # Replace with your actual dataset using DataLoader
  121.     train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
  122.    
  123.     # 2. Initialize model
  124.     model = CarLLaVA()
  125.    
  126.     # 3. Configure optimizer and scheduler
  127.     optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.1)
  128.     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
  129.    
  130.     # 4. Train the model
  131.     train(model, train_loader, optimizer, scheduler, num_epochs=30)
  132.    
  133.     # 5. Save the trained model
  134.     torch.save(model.state_dict(), "carllava.pth")
  135.    
  136.     # 6. Inference example
  137.     image = torch.randn(1, 3, 336, 336)  # Example image (1 image, 3 channels, 336x336)
  138.     speed = torch.tensor([50.0])         # Example speed (50 km/h)
  139.     target_points = torch.tensor([[10.0, 20.0], [30.0, 40.0]])  # Example target points
  140.    
  141.     path_pred, wp_pred = infer(model, image, speed, target_points)
  142.     print("Predicted path:", path_pred)
  143.     print("Predicted waypoints:", wp_pred)
  144.  
  145. # Execute training and inference
  146. if __name__ == "__main__":
  147.     main()
  148.  
Tags: CarLLAVA
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement