Commit 2d471951 authored by Aditya Ramesh

initialized
# OS specific
*.DS_Store
# Python
/build
/dist
__pycache__
*.ipynb_checkpoints
*.egg-info
# Vim
*.vim
*.swk
*.swl
*.swm
*.swn
*.swo
*.swp
Modified MIT License
Software Copyright (c) 2021 OpenAI
We don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to
do with as you please. We only ask that you use the model responsibly and clearly indicate that it
was used.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
The above copyright notice and this permission notice need not be included
with content created by the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.
# Overview
[[Blog]](https://openai.com/blog/dall-e/) [[Paper]](https://arxiv.org/abs/2102.12092) [[Model Card]](model_card.md) [[Usage]](notebooks/usage.ipynb)
This is the official PyTorch package for the discrete VAE used for DALL·E.
# Installation
Before running [the example notebook](notebooks/usage.ipynb), you will need to install the package using:

    pip install git+https://github.com/openai/DALL-E.git
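Once installed, a quick sanity check is that the helpers used in the notebook import cleanly (a minimal sketch; nothing beyond the installed package is assumed):

    from dall_e import map_pixels, unmap_pixels, load_model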
import io, requests
import torch
import torch.nn as nn
from dall_e.encoder import Encoder
from dall_e.decoder import Decoder
from dall_e.utils import map_pixels, unmap_pixels
def load_model(path: str, device: torch.device = None) -> nn.Module:
    # Load a serialized encoder or decoder, either from a URL or a local path.
    if path.startswith('http://') or path.startswith('https://'):
        resp = requests.get(path)
        resp.raise_for_status()

        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)
    else:
        with open(path, 'rb') as f:
            return torch.load(f, map_location=device)
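A minimal usage sketch for load_model, mirroring the example notebook (the CDN URLs are the ones used there; the local path is a hypothetical placeholder):

import torch
from dall_e import load_model

dev = torch.device('cpu')  # or e.g. torch.device('cuda:0')
enc = load_model('https://cdn.openai.com/dall-e/encoder.pkl', dev)  # load from a URL
dec = load_model('https://cdn.openai.com/dall-e/decoder.pkl', dev)
# enc = load_model('/path/to/encoder.pkl', dev)  # or from a downloaded local copy (hypothetical path)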
import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d
@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 1)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 3)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)
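# Note (illustrative comment, not in the original file): a DecoderBlock preserves the
# spatial size and only changes the channel count; the residual branch output is scaled
# by post_gain = 1 / n_layers ** 2 before being added to the identity path. For example:
#
#     blk = DecoderBlock(n_in=64, n_out=128, n_layers=8)   # post_gain = 1 / 64
#     y = blk(torch.zeros(1, 64, 32, 32))                  # y.shape == (1, 128, 32, 32)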
@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    group_count: int = 4

    n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device,
                           requires_grad=self.requires_grad)

        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.vocab_size:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
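For orientation, a minimal shape check of the default Decoder (illustrative only: the network below is randomly initialized, and the 32x32 one-hot token grid and the [:, :3] slice follow the usage notebook):

import torch
import torch.nn.functional as F
from dall_e.decoder import Decoder

dec = Decoder()  # defaults: vocab_size=8192, output_channels=3
tokens = torch.zeros(1, 32, 32, dtype=torch.int64)  # dummy 32x32 grid of token ids
z = F.one_hot(tokens, num_classes=dec.vocab_size).permute(0, 3, 1, 2).float()  # [1, 8192, 32, 32]
x_stats = dec(z)                        # [1, 6, 256, 256]: three nearest-neighbor 2x upsamples
x_rec = torch.sigmoid(x_stats[:, :3])   # RGB reconstruction in [0, 1] (before unmap_pixels)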
import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 1)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)

@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    group_count: int = 4

    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device,
                           requires_grad=self.requires_grad)

        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
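The corresponding check on the encoder side, again with random weights (the argmax-over-logits step matches the usage notebook):

import torch
from dall_e.encoder import Encoder

enc = Encoder()                     # defaults: input_channels=3, vocab_size=8192
x = torch.rand(1, 3, 256, 256)      # stands in for map_pixels(preprocess(img))
z_logits = enc(x)                   # [1, 8192, 32, 32]: three 2x max-pools
z = torch.argmax(z_logits, dim=1)   # [1, 32, 32] grid of discrete token ids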
import attr
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
logit_laplace_eps: float = 0.1
@attr.s(eq=False)
class Conv2d(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)

    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device, requires_grad=self.requires_grad)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
                        requires_grad=self.requires_grad)
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
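# Note (illustrative comment): with use_float16=True the convolution runs in float16 only
# when the weights live on a CUDA device; on CPU both input and weights stay float32. E.g.:
#
#     conv = Conv2d(3, 64, 3)                 # defaults: device=cpu, use_float16=True
#     y = conv(torch.zeros(1, 3, 32, 32))     # y.dtype == torch.float32 on CPU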
def map_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
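With logit_laplace_eps = 0.1, map_pixels squeezes pixel values from [0, 1] into [0.1, 0.9] via (1 - 2 * 0.1) * x + 0.1, and unmap_pixels inverts the mapping (with clamping). A small round-trip check, assuming nothing beyond the two functions above:

import torch
from dall_e.utils import map_pixels, unmap_pixels

x = torch.rand(1, 3, 8, 8)                            # float32 values in [0, 1)
y = map_pixels(x)                                     # values now lie in [0.1, 0.9)
assert torch.allclose(unmap_pixels(y), x, atol=1e-6)  # round-trip recovers the input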
# Model Card: DALL·E dVAE
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from
Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some information about the discrete
VAE (dVAE) that was used to train DALL·E.
## Model Details
The dVAE was developed by researchers at OpenAI to reduce the memory footprint of the transformer trained on the
text-to-image generation task. The details involved in training the dVAE are described in [the paper][dalle_paper]. This
model card describes the first version of the model, released in February 2021. The model consists of a convolutional
encoder and decoder whose architectures are described [here](dall_e/encoder.py) and [here](dall_e/decoder.py), respectively.
For questions or comments about the models or the code release, please file a GitHub issue.
## Model Use
### Intended Use
The model is intended for others to use for training their own generative models.
### Out-of-Scope Use Cases
This model is inappropriate for high-fidelity image processing applications. We also do not recommend its use as a
general-purpose image compressor.
## Training Data
The model was trained on publicly available text-image pairs collected from the internet. This data consists partly of
[Conceptual Captions][cc] and a filtered subset of [YFCC100M][yfcc100m]. We used a subset of the filters described in
[Sharma et al.][cc_paper] to construct this dataset; further details are described in [our paper][dalle_paper]. We will
not be releasing the dataset.
## Performance and Limitations
The heavy compression from the encoding process results in a noticeable loss of detail in the reconstructed images. This
renders it inappropriate for applications that require fine-grained details of the image to be preserved.
[dalle_paper]: https://arxiv.org/abs/2102.12092
[cc]: https://ai.google.com/research/ConceptualCaptions
[cc_paper]: https://www.aclweb.org/anthology/P18-1238/
[yfcc100m]: http://projects.dfki.uni-kl.de/yfcc100m/
# Code cells from the usage notebook (notebooks/usage.ipynb):

# %% Cell 1: imports and image preprocessing
import io
import os, sys
import requests
import PIL

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from dall_e import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

target_image_size = 256

def download_image(url):
    resp = requests.get(url)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess(img):
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

# %% Cell 2: load the pretrained encoder and decoder
# This can be changed to a GPU, e.g. 'cuda:0'.
dev = torch.device('cpu')

# For faster load times, download these files locally and use the local paths instead.
enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)

# %% Cell 3: download and display the original image
x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))
display_markdown('Original image:')
display(T.ToPILImage(mode='RGB')(x[0]))

# %% Cell 4: encode to discrete tokens, then decode and display the reconstruction
import torch.nn.functional as F

z_logits = enc(x)
z = torch.argmax(z_logits, axis=1)
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

x_stats = dec(z).float()
x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

display_markdown('Reconstructed image:')
display(x_rec)
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from setuptools import setup

def parse_requirements(filename):
    lines = (line.strip() for line in open(filename))
    return [line for line in lines if line and not line.startswith("#")]

setup(name='DALL-E',
      version='0.1',
      description='PyTorch package for the discrete VAE used for DALL·E.',
      url='http://github.com/openai/DALL-E',
      author='Aditya Ramesh',
      author_email='aramesh@openai.com',
      license='BSD',
      packages=['dall_e'],
      install_requires=parse_requirements('requirements.txt'),
      zip_safe=True)