This is a PyTorch implementation of the paper "Progressive Learning and Disentanglement of Hierarchical Representations" by Li et al., ICLR 2020. The official TensorFlow implementation of proVLAE by the original authors is available here.
☝️ Visualization of results when traversing the latent space (-1.5 to +1.5) of pytorch-proVLAE trained on 3D Shapes.
We recommend using mamba (via miniforge) for faster installation of dependencies, but you can also use conda.
git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
cd pytorch-proVLAE
mamba env create -f env.yaml # or conda
mamba activate torch-provlae
You can train pytorch-proVLAE with the following command. Sample hyperparameters and training configurations are provided in the scripts directory. If you have a checkpoint file from a previous pytorch-proVLAE training run, setting the mode argument to "traverse" allows you to inspect the latent traversal (see the example after the training command below). Please ensure that the parameter settings match those used to produce the checkpoint when running this mode.
# training with distributed data parallel
# tested on NVIDIA V100 PCIe 16GB + 32GB, and on 2x NVIDIA A6000 48GB
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
--distributed \
--mode seq_train \
--dataset shapes3d \
--optim adamw \
--num_ladders 3 \
--batch_size 128 \
--num_epochs 15 \
--learning_rate 5e-4 \
--beta 8 \
--z_dim 3 \
--coff 0.5 \
--pre_kl \
--hidden_dim 32 \
--fade_in_duration 5000 \
--output_dir ./output/shapes3d/ \
--data_path ./data
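If you already have a checkpoint, a latent-traversal run might look like the sketch below. Only the "traverse" mode itself is documented above; the checkpoint flag name (--checkpoint_path) and the placeholder path are assumptions for illustration, so check the argument parser in src/train.py for the exact option names. The remaining flags are reused from the training command and must match the values used when the checkpoint was created.

# latent traversal from an existing checkpoint
# NOTE: --checkpoint_path is an assumed flag name; verify it in src/train.py
python src/train.py \
--mode traverse \
--dataset shapes3d \
--num_ladders 3 \
--beta 8 \
--z_dim 3 \
--hidden_dim 32 \
--checkpoint_path <path/to/checkpoint.pt> \
--output_dir ./output/shapes3d/ \
--data_path ./data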
This repository is licensed under the MIT License - see the LICENSE file for details. This follows the license of the original TensorFlow implementation by Zhiyuan Li.
*This repository is a contribution to an AIST (National Institute of Advanced Industrial Science and Technology) project.
Human Informatics and Interaction Research Institute, Neurorehabilitation Research Group
Shosuke Suzuki, Ryusuke Hayashi