Advertisement
Guest User

Untitled

a guest
Jun 8th, 2022
205
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.34 KB | None | 0 0
  1. #!python3
  2. from dalle_mini import DalleBart, DalleBartProcessor
  3. from datetime import datetime
  4. from flax.jax_utils import replicate
  5. from flax.training.common_utils import shard_prng_key
  6. from functools import partial
  7. import jax
  8. import jax.numpy as jnp
  9. import numpy as np
  10. import os
  11. from PIL import Image
  12. import random
  13. import sys
  14. from vqgan_jax.modeling_flax_vqgan import VQModel
  15.  
  16. # Parameters: https://huggingface.co/blog/how-to-generate).
  17. inputs = " ".join(sys.argv[1:])
  18. gen_top_k = None
  19. gen_top_p = None
  20. temperature = None
  21. cond_scale = 10.0
  22.  
  23. #DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest" # high precision
  24. DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:latest" # low precision
  25. DALLE_COMMIT_ID = None
  26. VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
  27. VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
  28.  
  29. print(f"Inputs: {inputs}\n")
  30. start = datetime.now()
  31. # Prevents warnings when showing image. Might impact performance, but
  32. # tokenizers appear to only constitute a small portion of runtime.
  33. os.environ["TOKENIZERS_PARALLELISM"] = "false"
  34. jax.local_device_count()
  35.  
  36. print(f"{datetime.now() - start} Loading Dalle model ...", flush=True, end="")
  37. model, params = DalleBart.from_pretrained(
  38.     DALLE_MODEL,
  39.     revision=DALLE_COMMIT_ID,
  40.     dtype=jnp.float16,
  41.     _do_init=False)
  42. params = replicate(params)
  43. print ("done.")
  44.  
  45. print(f"{datetime.now() - start} Loading VQGAN model ...", flush=True, end="")
  46. vqgan, vqgan_params = VQModel.from_pretrained(
  47.     VQGAN_REPO,
  48.     revision=VQGAN_COMMIT_ID,
  49.     _do_init=False)
  50. vqgan_params = replicate(vqgan_params)
  51. print ("done.")
  52.  
  53. print(f"{datetime.now() - start} Processing inputs ... ", flush=True, end="")
  54. processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
  55. tokens = replicate(processor([inputs]))
  56. print ("done.")
  57.  
  58. print(f"{datetime.now() - start} Generating image ... ", flush=True, end="")
  59. @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
  60. def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
  61.     return model.generate(
  62.         **tokenized_prompt,
  63.         prng_key=key,
  64.         params=params,
  65.         top_k=top_k,
  66.         top_p=top_p,
  67.         temperature=temperature,
  68.         condition_scale=condition_scale)
  69.  
  70. seed = random.randint(0, 2**32 - 1)
  71. key = jax.random.PRNGKey(seed)
  72. _, subkey = jax.random.split(key)
  73. shared_key = shard_prng_key(subkey)
  74. encoded_images = p_generate(
  75.     tokens,
  76.     shared_key,
  77.     params,
  78.     gen_top_k,
  79.     gen_top_p,
  80.     temperature,
  81.     cond_scale)
  82. print(f"done.")
  83.  
  84. print(f"{datetime.now() - start} Encoding image ... ", flush=True, end="")
  85. encoded_images = encoded_images.sequences[..., 1:] # remove BOS
  86. print(f"done.")
  87.  
  88. print(f"{datetime.now() - start} Decoding image ... ", flush=True, end="")
  89. @partial(jax.pmap, axis_name="batch")
  90. def p_decode(indices, params):
  91.     return vqgan.decode_code(indices, params=params)
  92.  
  93. decoded_images = p_decode(encoded_images, vqgan_params)
  94. decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
  95. print(f"done.")
  96.  
  97. print(f"{datetime.now() - start} Showing image ... ", flush=True, end="")
  98. for decoded_img in decoded_images:
  99.     img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
  100.     img.save(f"dall-e - {inputs}.png")
  101.     img.show()
  102. print(f"done.")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement