Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!python3
import os
import random
import re
import sys
from datetime import datetime
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from flax.jax_utils import replicate
from flax.training.common_utils import shard_prng_key
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel
# --- Model checkpoints -------------------------------------------------------
# The original author also listed "dalle-mini/dalle-mini/mega-1-fp16:latest"
# (noted as "high precision") as an alternative to the mini-1 checkpoint.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:latest"
DALLE_COMMIT_ID = None  # None: no pinned revision for the DALL-E checkpoint
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# --- Sampling parameters -----------------------------------------------------
# See https://huggingface.co/blog/how-to-generate for top-k / top-p /
# temperature. None is passed straight through to model.generate.
gen_top_k = gen_top_p = temperature = None
# Passed as `condition_scale` to model.generate (presumably the guidance
# strength toward the prompt -- confirm against dalle_mini docs).
cond_scale = 10.0

# --- Prompt ------------------------------------------------------------------
# Every command-line argument, joined with single spaces, forms the prompt.
inputs = " ".join(sys.argv[1:])
# Refuse to run without a prompt instead of asking the model to draw an
# empty string; sys.exit(str) prints the hint to stderr and exits with 1.
if not inputs.strip():
    sys.exit(f"usage: {os.path.basename(sys.argv[0])} PROMPT...")
print(f"Inputs: {inputs}\n")
start = datetime.now()
# Prevents warnings when showing image. Might impact performance, but
# tokenizers appear to only constitute a small portion of runtime.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# A bare `jax.local_device_count()` expression discards its result when run
# as a script. Report it instead: the pmapped generation step produces one
# image per local device.
print(f"Local devices: {jax.local_device_count()}")
print(f"{datetime.now() - start} Loading Dalle model ...", flush=True, end="")
# With `_do_init=False` the weights come back separately from the module,
# hence the (model, params) unpacking. `dtype=jnp.float16` is forwarded to
# from_pretrained (in HF Flax models this is the computation dtype -- confirm
# for DalleBart).
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
# One copy of the weights per local device, for the pmapped generator.
params = replicate(params)
print("done.")
print(f"{datetime.now() - start} Loading VQGAN model ...", flush=True, end="")
# The VQGAN checkpoint is pinned to VQGAN_COMMIT_ID; as with the DALL-E model,
# `_do_init=False` returns the weights separately from the module.
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)
vqgan_params = replicate(vqgan_params)  # per-device copy for the pmapped decoder
print("done.")
print(f"{datetime.now() - start} Processing inputs ... ", flush=True, end="")
# The processor comes from the same checkpoint and revision as the model.
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
# A batch holding the single prompt, replicated so every device gets a copy.
prompt_batch = [inputs]
tokens = replicate(processor(prompt_batch))
print("done.")
print(f"{datetime.now() - start} Generating image ... ", flush=True, end="")


def p_generate(tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale):
    """Sample image-token sequences for one device's slice of the inputs.

    Wrapped with ``jax.pmap`` just below, so ``tokenized_prompt``, ``key`` and
    ``params`` arrive per device. ``top_k`` through ``condition_scale``
    (positional indices 3-6) are broadcast to every device as static values.
    """
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )


p_generate = jax.pmap(
    p_generate, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6)
)

# A fresh seed on every run, so repeated invocations give different images.
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
subkey = jax.random.split(key)[1]
# Split the subkey into one PRNG key per local device.
shared_key = shard_prng_key(subkey)
encoded_images = p_generate(
    tokens, shared_key, params, gen_top_k, gen_top_p, temperature, cond_scale
)
print("done.")
print(f"{datetime.now() - start} Encoding image ... ", flush=True, end="")
# Despite the progress label, nothing is encoded here. The generated
# `sequences` begin with a BOS token, so drop the first position on the last
# axis before the codes are decoded.
generated_sequences = encoded_images.sequences
encoded_images = generated_sequences[..., 1:]
print("done.")
print(f"{datetime.now() - start} Decoding image ... ", flush=True, end="")


def p_decode(indices, params):
    """Decode one device's VQGAN code indices back into pixel arrays."""
    return vqgan.decode_code(indices, params=params)


p_decode = jax.pmap(p_decode, axis_name="batch")
decoded_images = p_decode(encoded_images, vqgan_params)
# Clamp pixel values into [0, 1], then fold all leading axes (device, and
# presumably batch) into one axis of 256x256 3-channel images.
decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
print("done.")
print(f"{datetime.now() - start} Showing image ... ", flush=True, end="")
# The prompt becomes part of the file name. Replace characters that are
# invalid or special in paths ("/", "\\", ":", "?", "*", quotes, ...), cap
# the length to stay clear of file-system name limits, and trim edge
# spaces/dots. Fall back to a fixed name if nothing usable is left.
safe_prompt = re.sub(r"[^\w\-,. ]+", "_", inputs)[:100].strip(" .") or "untitled"
image_count = len(decoded_images)
for index, decoded_img in enumerate(decoded_images):
    img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
    # Without a per-image suffix every image would be written to the same
    # path, leaving only the last one. A single image keeps the plain name.
    suffix = f" ({index + 1})" if image_count > 1 else ""
    img.save(f"dall-e - {safe_prompt}{suffix}.png")
    img.show()
print("done.")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement