{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run this line in Colab to install the package if it is\n",
"# not already installed.\n",
"!pip install git+https://github.com/openai/glide-text2im"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"from IPython.display import display\n",
"import torch as th\n",
"\n",
"from glide_text2im.download import load_checkpoint\n",
"from glide_text2im.model_creation import (\n",
" create_model_and_diffusion,\n",
" model_and_diffusion_defaults,\n",
" model_and_diffusion_defaults_upsampler\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This notebook supports both CPU and GPU.\n",
"# On CPU, generating one sample may take on the order of 20 minutes.\n",
"# On a GPU, it should be under a minute.\n",
"\n",
"has_cuda = th.cuda.is_available()\n",
"device = th.device('cpu' if not has_cuda else 'cuda')"
]
},
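{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not in the original notebook): seed the RNG so runs are\n",
"# repeatable. th.manual_seed seeds the CPU generator and all CUDA devices;\n",
"# the seed value here is an arbitrary choice.\n",
"th.manual_seed(0)"
]
},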
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create base model.\n",
"options = model_and_diffusion_defaults()\n",
"options['use_fp16'] = has_cuda\n",
"options['timestep_respacing'] = '100' # use 100 diffusion steps for fast sampling\n",
"model, diffusion = create_model_and_diffusion(**options)\n",
"model.eval()\n",
"if has_cuda:\n",
" model.convert_to_fp16()\n",
"model.to(device)\n",
"model.load_state_dict(load_checkpoint('base', device))\n",
"print('total base parameters', sum(x.numel() for x in model.parameters()))"
]
},
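{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not in the original notebook): the base model samples at\n",
"# a small 64x64 resolution with a fixed text context length; both values\n",
"# come from the options dict populated above.\n",
"print('base image size:', options['image_size'])\n",
"print('text context length:', options['text_ctx'])"
]
},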
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create upsampler model.\n",
"options_up = model_and_diffusion_defaults_upsampler()\n",
"options_up['use_fp16'] = has_cuda\n",
"options_up['timestep_respacing'] = 'fast27' # use 27 diffusion steps for very fast sampling\n",
"model_up, diffusion_up = create_model_and_diffusion(**options_up)\n",
"model_up.eval()\n",
"if has_cuda:\n",
" model_up.convert_to_fp16()\n",
"model_up.to(device)\n",
"model_up.load_state_dict(load_checkpoint('upsample', device))\n",
"print('total upsampler parameters', sum(x.numel() for x in model_up.parameters()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def show_images(batch: th.Tensor):\n",
" \"\"\" Display a batch of images inline. \"\"\"\n",
" scaled = ((batch + 1)*127.5).round().clamp(0,255).to(th.uint8).cpu()\n",
" reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])\n",
" display(Image.fromarray(reshaped.numpy()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sampling parameters\n",
"prompt = \"an oil painting of a corgi\"\n",
"batch_size = 1\n",
"guidance_scale = 3.0\n",
"\n",
"# Tune this parameter to control the sharpness of 256x256 images.\n",
"# A value of 1.0 is sharper, but sometimes results in grainy artifacts.\n",
"upsample_temp = 0.997"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"##############################\n",
"# Sample from the base model #\n",
"##############################\n",
"\n",
"# Create the text tokens to feed to the model.\n",
"tokens = model.tokenizer.encode(prompt)\n",
"tokens, mask = model.tokenizer.padded_tokens_and_mask(\n",
" tokens, options['text_ctx']\n",
")\n",
"\n",
"# Create the classifier-free guidance tokens (empty)\n",
"full_batch_size = batch_size * 2\n",
"uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(\n",
" [], options['text_ctx']\n",
")\n",
"\n",
"# Pack the tokens together into model kwargs.\n",
"model_kwargs = dict(\n",
" tokens=th.tensor(\n",
" [tokens] * batch_size + [uncond_tokens] * batch_size, device=device\n",
" ),\n",
" mask=th.tensor(\n",
" [mask] * batch_size + [uncond_mask] * batch_size,\n",
" dtype=th.bool,\n",
" device=device,\n",
" ),\n",
")\n",
"\n",
"# Create a classifier-free guidance sampling function\n",
"def model_fn(x_t, ts, **kwargs):\n",
" half = x_t[: len(x_t) // 2]\n",
" combined = th.cat([half, half], dim=0)\n",
" model_out = model(combined, ts, **kwargs)\n",
" eps, rest = model_out[:, :3], model_out[:, 3:]\n",
" cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)\n",
" half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
" eps = th.cat([half_eps, half_eps], dim=0)\n",
" return th.cat([eps, rest], dim=1)\n",
"\n",
"# Sample from the base model.\n",
"model.del_cache()\n",
"samples = diffusion.p_sample_loop(\n",
" model_fn,\n",
" (full_batch_size, 3, options[\"image_size\"], options[\"image_size\"]),\n",
" device=device,\n",
" clip_denoised=True,\n",
" progress=True,\n",
" model_kwargs=model_kwargs,\n",
" cond_fn=None,\n",
")[:batch_size]\n",
"model.del_cache()\n",
"\n",
"# Show the output\n",
"show_images(samples)"
]
},
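{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (not in the original notebook) of the guidance rule used\n",
"# in model_fn above, on dummy tensors: the guided prediction moves away from\n",
"# the unconditional one toward the conditional one by guidance_scale.\n",
"cond_eps = th.randn(1, 3, 8, 8)\n",
"uncond_eps = th.randn(1, 3, 8, 8)\n",
"guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps)\n",
"\n",
"# At a scale of 1.0 the rule reduces to the conditional prediction.\n",
"assert th.allclose(uncond_eps + 1.0 * (cond_eps - uncond_eps), cond_eps)"
]
},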
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"##############################\n",
"# Upsample the 64x64 samples #\n",
"##############################\n",
"\n",
"tokens = model_up.tokenizer.encode(prompt)\n",
"tokens, mask = model_up.tokenizer.padded_tokens_and_mask(\n",
" tokens, options_up['text_ctx']\n",
")\n",
"\n",
"# Create the model conditioning dict.\n",
"model_kwargs = dict(\n",
" # Low-res image to upsample.\n",
" low_res=((samples+1)*127.5).round()/127.5 - 1,\n",
"\n",
" # Text tokens\n",
" tokens=th.tensor(\n",
" [tokens] * batch_size, device=device\n",
" ),\n",
" mask=th.tensor(\n",
" [mask] * batch_size,\n",
" dtype=th.bool,\n",
" device=device,\n",
" ),\n",
")\n",
"\n",
"# Sample from the base model.\n",
"model_up.del_cache()\n",
"up_shape = (batch_size, 3, options_up[\"image_size\"], options_up[\"image_size\"])\n",
"up_samples = diffusion_up.ddim_sample_loop(\n",
" model_up,\n",
" up_shape,\n",
" noise=th.randn(up_shape, device=device) * upsample_temp,\n",
" device=device,\n",
" clip_denoised=True,\n",
" progress=True,\n",
" model_kwargs=model_kwargs,\n",
" cond_fn=None,\n",
")[:batch_size]\n",
"model_up.del_cache()\n",
"\n",
"# Show the output\n",
"show_images(up_samples)"
]
},
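{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not in the original notebook): write the upsampled batch\n",
"# to disk using the same conversion as show_images. The output filename is\n",
"# an arbitrary choice.\n",
"scaled = ((up_samples + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()\n",
"reshaped = scaled.permute(2, 0, 3, 1).reshape([up_samples.shape[2], -1, 3])\n",
"Image.fromarray(reshaped.numpy()).save('glide_sample.png')"
]
}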
],
"metadata": {
"interpreter": {
"hash": "e7d6e62d90e7e85f9a0faa7f0b1d576302d7ae6108e9fe361594f8e1c8b05781"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 2
}