Advertisement
binarydepth

Untitled

Mar 12th, 2024
719
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.73 KB | Source Code | 0 0
  1. import os
  2. import time
  3.  
  4. import torch
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. import torch.nn.utils.prune as prune
  8.  
  9. from modules import timer
  10. from modules import initialize_util
  11. from modules import initialize
  12.  
  13. startup_timer = timer.startup_timer
  14. startup_timer.record("launcher")
  15.  
  16. initialize.imports()
  17.  
  18. initialize.check_versions()
  19.  
  20. def optimize_model(model):
  21.     # Prune unnecessary connections from the model
  22.     prune.l1_unstructured(model.fc1, name='weight', amount=0.5)
  23.     # Split the model into smaller modules
  24.     model_chunk1 = nn.Sequential(model.fc1, model.relu)
  25.     model_chunk2 = model.fc2
  26.     # Share parameters between layers of the model
  27.     model.fc2.weight = nn.Parameter(model.fc1.weight.clone())
  28.     model.fc2.bias = nn.Parameter(model.fc1.bias.clone())
  29.  
  30. def create_model(input_dim, hidden_dim, output_dim):
  31.     # Define a simple neural network architecture
  32.     model = SimpleNet(input_dim, hidden_dim, output_dim)
  33.     # Optimize the model for memory usage
  34.     optimize_model(model)
  35.     return model
  36.  
  37. def process_prompt(prompt):
  38.     # Process the prompt here
  39.     pass
  40.  
  41. def webui():
  42.     from modules.shared_cmd_options import cmd_opts
  43.  
  44.     launch_api = cmd_opts.api
  45.     initialize.initialize()
  46.  
  47.     from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
  48.  
  49.     if shared.opts.clean_temp_dir_at_start:
  50.         ui_tempdir.cleanup_tmpdr()
  51.         startup_timer.record("cleanup temp dir")
  52.  
  53.     script_callbacks.before_ui_callback()
  54.     startup_timer.record("scripts before_ui_callback")
  55.  
  56.     if not cmd_opts.no_gradio_queue:
  57.         shared.demo.queue(64)
  58.  
  59.     gradio_auth_creds = list(initialize_util.get_gradio_auth_creds()) or None
  60.  
  61.     auto_launch_browser = False
  62.     if os.getenv('SD_WEBUI_RESTARTING') != '1':
  63.         if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
  64.             auto_launch_browser = True
  65.         elif shared.opts.auto_launch_browser == "Local":
  66.             auto_launch_browser = not cmd_opts.webui_is_non_local
  67.  
  68.     app, local_url, share_url = shared.demo.launch(
  69.         share=cmd_opts.share,
  70.         server_name=initialize_util.gradio_server_name(),
  71.         server_port=cmd_opts.port,
  72.         ssl_keyfile=cmd_opts.tls_keyfile,
  73.         ssl_certfile=cmd_opts.tls_certfile,
  74.         ssl_verify=cmd_opts.disable_tls_verify,
  75.         debug=cmd_opts.gradio_debug,
  76.         auth=gradio_auth_creds,
  77.         inbrowser=auto_launch_browser,
  78.         prevent_thread_lock=True,
  79.         allowed_paths=cmd_opts.gradio_allowed_path,
  80.         app_kwargs={
  81.             "docs_url": "/docs",
  82.             "redoc_url": "/redoc",
  83.         },
  84.         root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
  85.     )
  86.  
  87.     startup_timer.record("gradio launch")
  88.  
  89.     # Remove the CORS middleware to enhance security
  90.     app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
  91.  
  92.     initialize_util.setup_middleware(app)
  93.  
  94.     progress.setup_progress_api(app)
  95.     ui.setup_ui_api(app)
  96.  
  97.     if launch_api:
  98.         create_api(app)
  99.  
  100.     ui_extra_networks.add_pages_to_demo(app)
  101.  
  102.     startup_timer.record("add APIs")
  103.  
  104.     with startup_timer.subcategory("app_started_callback"):
  105.         script_callbacks.app_started_callback(shared.demo, app)
  106.  
  107.     timer.startup_record = startup_timer.dump()
  108.     print(f"Startup time: {startup_timer.summary()}.")
  109.  
  110.     try:
  111.         while True:
  112.             server_command = shared.state.wait_for_server_command(timeout=5)
  113.             if server_command:
  114.                 if server_command in ("stop", "restart"):
  115.                     break
  116.                 else:
  117.                     print(f"Unknown server command: {server_command}")
  118.  
  119.             # Process prompt
  120.             prompt = # Get the prompt data
  121.             process_prompt(prompt)
  122.     except KeyboardInterrupt:
  123.         print('Caught KeyboardInterrupt, stopping...')
  124.         server_command = "stop"
  125.  
  126.     if server_command == "stop":
  127.         print("Stopping server...")
  128.         # If we catch a keyboard interrupt, we want to stop the server and exit.
  129.         shared.demo.close()
  130.  
  131.     # disable auto launch webui in browser for subsequent UI Reload
  132.     os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
  133.  
  134.     print('Restarting UI...')
  135.     shared.demo.close()
  136.     time.sleep(0.5)
  137.     startup_timer.reset()
  138.     script_callbacks.app_reload_callback()
  139.     startup_timer.record("app reload callback")
  140.     script_callbacks.script_unloaded_callback()
  141.     startup_timer.record("scripts unloaded callback")
  142.     initialize.initialize_rest(reload_script_modules=True)
  143.  
  144.  
  145. if __name__ == "__main__":
  146.     from modules.shared_cmd_options import cmd_opts
  147.  
  148.     if cmd_opts.nowebui:
  149.         api_only()
  150.     else:
  151.         webui()
  152.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement