notebooks/usage.ipynb (116 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import os, sys\n",
"import requests\n",
"import PIL\n",
"\n",
"import torch\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"\n",
"from dall_e import map_pixels, unmap_pixels, load_model\n",
"from IPython.display import display, display_markdown\n",
"\n",
"target_image_size = 256\n",
"\n",
"def download_image(url, timeout=30):\n",
"    \"\"\"Download the image at ``url`` and open it with PIL.\n",
"\n",
"    Args:\n",
"        url: HTTP(S) address of the image file.\n",
"        timeout: Seconds to wait for the server before giving up. Without\n",
"            it, ``requests.get`` may block indefinitely on a stalled\n",
"            connection and freeze the kernel.\n",
"\n",
"    Returns:\n",
"        The PIL image opened from the downloaded bytes.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: If the server answers with an error status.\n",
"        requests.Timeout: If the server does not respond within ``timeout``.\n",
"    \"\"\"\n",
"    resp = requests.get(url, timeout=timeout)\n",
"    resp.raise_for_status()\n",
"    # The full response body is wrapped in an in-memory buffer for PIL to read.\n",
"    return PIL.Image.open(io.BytesIO(resp.content))\n",
"\n",
"def preprocess(img):\n",
"    \"\"\"Resize, center-crop and pixel-map a PIL image into an encoder input.\n",
"\n",
"    The shorter side is scaled to ``target_image_size`` with Lanczos\n",
"    resampling, the result is center-cropped to a square, converted to a\n",
"    tensor with a leading batch dimension of 1, and passed to ``map_pixels``.\n",
"\n",
"    Args:\n",
"        img: PIL image whose shorter side is at least ``target_image_size``.\n",
"\n",
"    Returns:\n",
"        The result of ``map_pixels`` applied to the batched image tensor.\n",
"\n",
"    Raises:\n",
"        ValueError: If the shorter side is smaller than ``target_image_size``\n",
"            (the image is never upscaled).\n",
"    \"\"\"\n",
"    shorter_side = min(img.size)\n",
"\n",
"    if shorter_side < target_image_size:\n",
"        raise ValueError(f'min dim for image {shorter_side} < {target_image_size}')\n",
"\n",
"    # Grayscale, palette or RGBA files would otherwise produce 1- or 4-channel\n",
"    # tensors, while the notebook displays the result with mode='RGB'.\n",
"    img = img.convert('RGB')\n",
"\n",
"    scale = target_image_size / shorter_side\n",
"    # PIL reports size as (width, height); TF.resize takes (height, width).\n",
"    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))\n",
"    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)\n",
"    img = TF.center_crop(img, output_size=2 * [target_image_size])\n",
"    img = torch.unsqueeze(T.ToTensor()(img), 0)\n",
"    return map_pixels(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Device handed to load_model for both networks; switch to e.g. 'cuda:0' for a GPU.\n",
"dev = torch.device('cpu')\n",
"\n",
"# Loading from the CDN is slow: for faster load times, download these files\n",
"# locally and use the local paths instead.\n",
"enc, dec = (\n",
"    load_model(url, dev)\n",
"    for url in (\n",
"        \"https://cdn.openai.com/dall-e/encoder.pkl\",\n",
"        \"https://cdn.openai.com/dall-e/decoder.pkl\",\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))\n",
"# raw=True is required for a plain string: without it display_markdown looks\n",
"# for a markdown repr on the object, finds none, and renders nothing.\n",
"display_markdown('Original image:', raw=True)\n",
"# x is in map_pixels space; undo that for display so the picture is comparable\n",
"# with the reconstruction, which is also passed through unmap_pixels.\n",
"display(T.ToPILImage(mode='RGB')(unmap_pixels(x)[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"# Pure inference: no_grad avoids building an autograd graph for the two passes.\n",
"with torch.no_grad():\n",
"    # Move the input to the models' device so changing `dev` to a GPU works.\n",
"    z_logits = enc(x.to(dev))\n",
"    # Pick the highest-scoring code at every position (dim 1 holds the logits).\n",
"    z = torch.argmax(z_logits, dim=1)\n",
"    # one_hot appends the class axis last; permute puts it back in position 1.\n",
"    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()\n",
"\n",
"    x_stats = dec(z).float()\n",
"    # Only the first three channels of the decoder output are used for the image.\n",
"    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))\n",
"\n",
"x_rec = T.ToPILImage(mode='RGB')(x_rec[0].cpu())\n",
"\n",
"# raw=True is required for a plain string; otherwise no caption is rendered.\n",
"display_markdown('Reconstructed image:', raw=True)\n",
"display(x_rec)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}