CoModGAN-PyTorch-Implementation
CoModGAN PyTorch
- Built on top of the StyleGAN2-ADA PyTorch implementation
Compatibility
- Compatible with old network pickles created using the TensorFlow version.
- New ZIP/PNG based dataset format for maximal interoperability with existing 3rd party tools.
- TFRecords datasets are no longer supported — they need to be converted to the new format.
- New JSON-based format for logs, metrics, and training curves.
- Training curves are also exported in the old TFEvents format if TensorBoard is installed.
- Command line syntax is mostly unchanged, with a few exceptions (e.g., dataset_tool.py).
- Comparison methods are not supported (--cmethod, --dcap, --cfg=cifarbaseline, --aug=adarv).
- Truncation is now disabled by default.
Data repository
| Path | Description |
|---|---|
| stylegan2-ada-pytorch | Main directory hosted on Amazon S3 |
| ├ ada-paper.pdf | Paper PDF |
| ├ images | Curated example images produced using the pre-trained models |
| ├ videos | Curated example interpolation videos |
| └ pretrained | Pre-trained models |
| ├ ffhq.pkl | FFHQ at 1024x1024, trained using original StyleGAN2 |
| ├ brecahad.pkl | BreCaHAD at 512x512, trained from scratch using ADA |
| ├ paper-fig7c-training-set-sweeps | Models used in Fig.7c (sweep over training set size) |
| ├ paper-fig11a-small-datasets | Models used in Fig.11a (small datasets & transfer learning) |
| ├ paper-fig11b-cifar10 | Models used in Fig.11b (CIFAR-10) |
| ├ transfer-learning-source-nets | Models used as starting point for transfer learning |
| └ metrics | Feature detectors used by the quality metrics |
Getting started
Pre-trained networks are stored as *.pkl files that can be referenced using local filenames or URLs:
# Generate curated MetFaces images without truncation (Fig.10 left)
python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# Generate uncurated MetFaces images with truncation (Fig.12 upper left)
python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
# Generate class conditional CIFAR-10 images (Fig.17 left, Car)
python generate.py --outdir=out --seeds=0-35 --class=1 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
# Style mixing example
python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
Outputs from the above commands are placed under out/*.png, controlled by --outdir. Downloaded network pickles are cached under $HOME/.cache/dnnlib, which can be overridden by setting the DNNLIB_CACHE_DIR environment variable. The default PyTorch extension build directory is $HOME/.cache/torch_extensions, which can be overridden by setting TORCH_EXTENSIONS_DIR.
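The two cache locations can also be redirected from Python rather than from the shell. The following is a minimal sketch, assuming the variables are set before the first network download or extension build happens in the process; the helper name `redirect_caches` and the subfolder names are hypothetical, not part of the repository.

```python
import os
import tempfile
from pathlib import Path


def redirect_caches(root):
    """Point both caches at subfolders of `root` (a location you choose).

    DNNLIB_CACHE_DIR and TORCH_EXTENSIONS_DIR are the variable names given
    in this README; the subfolder names below are illustrative only.
    """
    root = Path(root).expanduser()
    dirs = {
        "DNNLIB_CACHE_DIR": root / "dnnlib",
        "TORCH_EXTENSIONS_DIR": root / "torch_extensions",
    }
    for var, path in dirs.items():
        path.mkdir(parents=True, exist_ok=True)  # create the folder up front
        os.environ[var] = str(path)              # visible to this process and its children
    return {var: str(path) for var, path in dirs.items()}


if __name__ == "__main__":
    # Demo with a throwaway directory; use a persistent path in practice.
    for var, path in redirect_caches(tempfile.mkdtemp()).items():
        print(f"{var}={path}")
```

Because the variables are only changed for the current process and its children, this is best placed at the very top of a launcher script.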
Projecting images to latent space
To find the matching latent vector for a given image file, run:
python projector.py --outdir=out --target=~/mytargetimg.png \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
For optimal results, the target image should be cropped and aligned similar to the FFHQ dataset. The above command saves the projection target out/target.png, result out/proj.png, latent vector out/projected_w.npz, and progression video out/proj.mp4. You can render the resulting latent vector by specifying --projected_w for generate.py:
python generate.py --outdir=out --projected_w=out/projected_w.npz \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
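The saved latent can also be inspected directly with NumPy. This is a sketch under an assumption: that the archive stores a single array under the key `w`, which is how the upstream StyleGAN2-ADA projector appears to write it. The helper `load_projected_w` is hypothetical, and the demo uses a synthetic file rather than a real projection.

```python
import os
import tempfile

import numpy as np


def load_projected_w(path, key="w"):
    """Load a projected latent from an .npz archive.

    ASSUMPTION: the array lives under `key` ("w"). If it does not, the
    error message lists the keys that are actually present.
    """
    with np.load(path) as archive:
        if key not in archive.files:
            raise KeyError(f"{key!r} not found; available keys: {archive.files}")
        return archive[key]


if __name__ == "__main__":
    # Synthetic stand-in for out/projected_w.npz; the shape is illustrative.
    demo_path = os.path.join(tempfile.mkdtemp(), "projected_w.npz")
    np.savez(demo_path, w=np.zeros((1, 18, 512), dtype=np.float32))
    w = load_projected_w(demo_path)
    print(w.shape, w.dtype)
```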
Using networks from Python
You can use pre-trained networks in your own Python code as follows:
import pickle
import torch

with open('ffhq.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].cuda() # torch.nn.Module
z = torch.randn([1, G.z_dim]).cuda() # latent codes
c = None # class labels (not used in this example)
img = G(z, c) # NCHW, float32, dynamic range [-1, +1]
The above code requires torch_utils and dnnlib to be accessible via PYTHONPATH. It does not need source code for the networks themselves — their class definitions are loaded from the pickle via torch_utils.persistence.
The pickle contains three networks. 'G' and 'D' are instantaneous snapshots taken during training, and 'G_ema' represents a moving average of the generator weights over several training steps. The networks are regular instances of torch.nn.Module, with all of their parameters and buffers placed on the CPU at import and gradient computation disabled by default.
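To make the idea of a moving average of weights concrete, here is a sketch of one common form, an exponential moving average. It is an illustration only: the flat dictionary of floats stands in for real parameter tensors, and the decay value `beta` is an arbitrary choice, not the setting used by the training loop.

```python
def ema_update(ema_weights, weights, beta=0.999):
    """Blend the current weights into the running average, in place.

    Update rule: ema <- beta * ema + (1 - beta) * current
    A larger beta averages over more training steps.
    """
    for name, value in weights.items():
        ema_weights[name] = beta * ema_weights[name] + (1.0 - beta) * value
    return ema_weights


if __name__ == "__main__":
    g = {"w": 0.0}      # stand-in for the instantaneous generator 'G'
    g_ema = dict(g)     # stand-in for 'G_ema', initialized as a copy
    for step in range(1, 6):
        g["w"] = float(step)            # pretend an optimizer step moved the weight
        ema_update(g_ema, g, beta=0.5)  # small beta so the effect is visible
    print(g["w"], g_ema["w"])           # the average lags behind the latest value
```

The averaged copy changes more smoothly than the instantaneous snapshot, which is why 'G_ema' is the network used in the examples here.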
The generator consists of two submodules, G.mapping and G.synthesis, that can be executed separately. They also support various additional options:
w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
img = G.synthesis(w, noise_mode='const', force_fp32=True)
Please refer to generate.py, style_mixing.py, and projector.py for further examples.
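The effect of truncation_psi and truncation_cutoff can be sketched in plain Python. This follows the usual description of the truncation trick (interpolating each latent toward an average latent) and is an assumption about the behavior, not the repository's code; `truncate`, `ws`, and `w_avg` are hypothetical names.

```python
def truncate(ws, w_avg, truncation_psi=1.0, truncation_cutoff=None):
    """Pull per-layer latents toward the average latent.

    ws:    list of per-layer latent vectors (each a list of floats)
    w_avg: average latent vector (list of floats)
    Layers with index < truncation_cutoff are moved to
    w_avg + psi * (w - w_avg); the remaining layers are left unchanged.
    A cutoff of None applies truncation to every layer.
    """
    cutoff = len(ws) if truncation_cutoff is None else truncation_cutoff
    out = []
    for layer, w in enumerate(ws):
        if layer < cutoff:
            out.append([a + truncation_psi * (x - a) for x, a in zip(w, w_avg)])
        else:
            out.append(list(w))
    return out


if __name__ == "__main__":
    ws = [[2.0, 4.0], [2.0, 4.0], [2.0, 4.0]]  # three layers, two dimensions
    w_avg = [0.0, 0.0]
    print(truncate(ws, w_avg, truncation_psi=0.5, truncation_cutoff=2))
```

With psi = 1 the latents are unchanged, and with psi = 0 every truncated layer collapses onto the average latent.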
Preparing datasets
Datasets are stored as uncompressed ZIP archives containing uncompressed PNG files and a metadata file dataset.json for labels.
Custom datasets can be created from a folder containing images; see python dataset_tool.py --help for more information. Alternatively, the folder can also be used directly as a dataset, without running it through dataset_tool.py first, but doing so may lead to suboptimal performance.
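A dataset archive can be inspected with the standard library alone. The sketch below lists the PNG entries and reads labels from dataset.json, assuming the metadata has the form {"labels": [[filename, label], ...]} (or null when there are no labels); that layout and the helper name `read_dataset_index` are assumptions. The demo archive contains placeholder bytes, not real images.

```python
import json
import os
import tempfile
import zipfile


def read_dataset_index(zip_path):
    """Return (png_names, labels) for a dataset ZIP.

    labels maps file name -> label and is empty when dataset.json is
    missing or carries no labels.
    ASSUMPTION: dataset.json looks like {"labels": [[name, label], ...]}.
    """
    with zipfile.ZipFile(zip_path) as archive:
        names = archive.namelist()
        png_names = sorted(n for n in names if n.lower().endswith(".png"))
        labels = {}
        if "dataset.json" in names:
            meta = json.loads(archive.read("dataset.json"))
            labels = {name: label for name, label in (meta.get("labels") or [])}
    return png_names, labels


if __name__ == "__main__":
    # Build a toy archive; ZIP_STORED keeps the entries uncompressed.
    path = os.path.join(tempfile.mkdtemp(), "toy.zip")
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_STORED) as archive:
        archive.writestr("00000/img00000000.png", b"placeholder, not a real PNG")
        archive.writestr(
            "dataset.json",
            json.dumps({"labels": [["00000/img00000000.png", 3]]}),
        )
    print(read_dataset_index(path))
```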
Legacy TFRecords datasets are not supported — see below for instructions on how to convert them.
FFHQ:
Step 1: Download the Flickr-Faces-HQ dataset as TFRecords.
Step 2: Extract images from TFRecords using dataset_tool.py from the TensorFlow version of StyleGAN2-ADA:
# Using dataset_tool.py from TensorFlow version at
# https://github.com/NVlabs/stylegan2-ada/
python ../stylegan2-ada/dataset_tool.py unpack \
--tfrecord_dir=~/ffhq-dataset/tfrecords/ffhq --output_dir=/tmp/ffhq-unpacked
Step 3: Create ZIP archive using dataset_tool.py from this repository:
# Original 1024x1024 resolution.
python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq.zip
# Scaled down 256x256 resolution.
python dataset_tool.py --source=/tmp/ffhq-unpacked --dest=~/datasets/ffhq256x256.zip \
--width=256 --height=256
Training new networks
In its most basic form, training new networks boils down to:
python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 --dry-run
python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
The first command is optional; it validates the arguments, prints out the training configuration, and exits. The second command kicks off the actual training.
In this example, the results are saved to a newly created directory ~/training-runs/<ID>-mydataset-auto1, controlled by --outdir. The training exports network pickles (network-snapshot-<INT>.pkl) and example images (fakes<INT>.png) at regular intervals (controlled by --snap). For each pickle, it also evaluates FID (controlled by --metrics) and logs the resulting scores in metric-fid50k_full.jsonl (as well as TFEvents if TensorBoard is installed).
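Since the metric log is a JSON-lines file, the best snapshot so far can be picked out with a few lines of Python. This is a sketch that assumes each line is a JSON object with a "results" dictionary keyed by metric name and a "snapshot_pkl" field; those field names and the helper `best_snapshot` are assumptions, so check them against an actual log before relying on this.

```python
import json
import os
import tempfile


def best_snapshot(jsonl_path, metric="fid50k_full"):
    """Return (snapshot_pkl, score) for the lowest score, or None if empty.

    ASSUMPTION: each non-empty line looks like
    {"results": {"<metric>": <float>}, "snapshot_pkl": "<file name>", ...}
    """
    best = None
    with open(jsonl_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            score = record["results"][metric]
            if best is None or score < best[1]:  # lower FID is better
                best = (record.get("snapshot_pkl"), score)
    return best


if __name__ == "__main__":
    # Synthetic log with made-up scores, for demonstration only.
    path = os.path.join(tempfile.mkdtemp(), "metric-fid50k_full.jsonl")
    with open(path, "w") as f:
        for name, fid in [("network-snapshot-000000.pkl", 300.0),
                          ("network-snapshot-000200.pkl", 42.5),
                          ("network-snapshot-000400.pkl", 57.1)]:
            f.write(json.dumps({"results": {"fid50k_full": fid},
                                "snapshot_pkl": name}) + "\n")
    print(best_snapshot(path))
```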
The name of the output directory reflects the training configuration. For example, 00000-mydataset-auto1 indicates that the base configuration was auto1, meaning that the hyperparameters were selected automatically for training on one GPU. The base configuration is controlled by --cfg:
| Base config | Description |
|---|---|
| auto (default) | Automatically select reasonable defaults based on resolution and GPU count. Serves as a good starting point for new datasets but does not necessarily lead to optimal results. |
| stylegan2 | Reproduce results for StyleGAN2 config F at 1024x1024 using 1, 2, 4, or 8 GPUs. |
| paper256 | Reproduce results for FFHQ and LSUN Cat at 256x256 using 1, 2, 4, or 8 GPUs. |
| paper512 | Reproduce results for BreCaHAD and AFHQ at 512x512 using 1, 2, 4, or 8 GPUs. |
| paper1024 | Reproduce results for MetFaces at 1024x1024 using 1, 2, 4, or 8 GPUs. |
| cifar | Reproduce results for CIFAR-10 (tuned configuration) using 1 or 2 GPUs. |
The training configuration can be further customized with additional command line options:
- --aug=noaug disables ADA.
- --cond=1 enables class-conditional training (requires a dataset with labels).
- --mirror=1 amplifies the dataset with x-flips. Often beneficial, even with ADA.
- --resume=ffhq1024 --snap=10 performs transfer learning from FFHQ trained at 1024x1024.
- --resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl resumes a previous training run.
- --gamma=10 overrides R1 gamma. We recommend trying a couple of different values for each new dataset.
- --aug=ada --target=0.7 adjusts ADA target value (default: 0.6).
- --augpipe=blit enables pixel blitting but disables all other augmentations.
- --augpipe=bgcfnc enables all available augmentations (blit, geom, color, filter, noise, cutout).
Please refer to python train.py --help for the full list.
References:
- GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium, Heusel et al. 2017
- Demystifying MMD GANs, Bińkowski et al. 2018
- Improved Precision and Recall Metric for Assessing Generative Models, Kynkäänniemi et al. 2019
- Improved Techniques for Training GANs, Salimans et al. 2016
- A Style-Based Generator Architecture for Generative Adversarial Networks, Karras et al. 2018
Citation
@inproceedings{Karras2020ada,
title = {Training Generative Adversarial Networks with Limited Data},
author = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
booktitle = {Proc. NeurIPS},
year = {2020}
}
@inproceedings{zhao2021comodgan,
title={Large Scale Image Completion via Co-Modulated Generative Adversarial Networks},
author={Zhao, Shengyu and Cui, Jonathan and Sheng, Yilun and Dong, Yue and Liang, Xiao and Chang, Eric I and Xu, Yan},
booktitle={International Conference on Learning Representations (ICLR)},
year={2021}
}