Advertisement
Guest User

Untitled

a guest
Feb 25th, 2025
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.22 KB | Source Code | 0 0
  1. import os
  2.  
  3. class JanusModelLoader:
  4.     def __init__(self):
  5.         pass
  6.        
  7.     @classmethod
  8.     def INPUT_TYPES(s):
  9.         return {
  10.             "required": {
  11.                 "model_name": (["deepseek-ai/Janus-Pro-1B", "deepseek-ai/Janus-Pro-7B","Janus-Pro-1B","Janus-Pro-7B"],),
  12.                 "device": (["cuda","cpu"],),
  13.             },
  14.         }
  15.    
  16.     RETURN_TYPES = ("JANUS_MODEL", "JANUS_PROCESSOR")
  17.     RETURN_NAMES = ("model", "processor")
  18.     FUNCTION = "load_model"
  19.     CATEGORY = "Janus-Pro"
  20.  
  21.     def load_model(self, model_name, device):
  22.         try:
  23.             from janus.models import MultiModalityCausalLM, VLChatProcessor
  24.             from transformers import AutoModelForCausalLM
  25.             import torch
  26.         except ImportError:
  27.             raise ImportError("Please install Janus using 'pip install -r requirements.txt'")
  28.  
  29.         #device = "cuda" if torch.cuda.is_available() else "cpu"
  30.         try:
  31.             if torch.cuda.is_available() and device=="cuda": device = "cuda"
  32.             elif device=="cpu": device = "cpu"
  33.         except ImportError:
  34.             raise ImportError("[Janus model loader] Error selecting device: {device} ")
  35.  
  36.         try:
  37.             dtype = torch.bfloat16
  38.             torch.zeros(1, dtype=dtype, device=device)
  39.         except RuntimeError:
  40.             dtype = torch.float16
  41.  
  42.         # 获取ComfyUI根目录
  43.         comfy_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
  44.         # 构建模型路径
  45.         model_dir = os.path.join(comfy_path,
  46.                                "models",
  47.                                "Janus-Pro",
  48.                                os.path.basename(model_name))
  49.         if not os.path.exists(model_dir):
  50.             raise ValueError(f"Local model not found at {model_dir}. Please download the model and place it in the ComfyUI/models/Janus-Pro folder.")
  51.            
  52.         vl_chat_processor = VLChatProcessor.from_pretrained(model_dir)
  53.        
  54.         vl_gpt = AutoModelForCausalLM.from_pretrained(
  55.             model_dir,
  56.             trust_remote_code=True
  57.         )
  58.        
  59.         vl_gpt = vl_gpt.to(dtype).to(device).eval()
  60.        
  61.         return (vl_gpt, vl_chat_processor)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement