pytorch-proVLAE



This is a PyTorch implementation of the paper "Progressive Learning and Disentanglement of Hierarchical Representations" by Zhiyuan Li et al., ICLR 2020. The official code for proVLAE, implemented in TensorFlow, is available here.


☝️ Visualization of results when traversing the latent space (-1.5 to +1.5) of pytorch-proVLAE trained on 3D Shapes.
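
A latent traversal varies one latent coordinate at a time while holding the others fixed, then decodes each resulting vector. The snippet below is a minimal sketch of how such a grid of latent vectors could be built for the -1.5 to +1.5 range mentioned above. The function name `traversal_grid`, the number of steps, and the zero base vector are assumptions for illustration and are not taken from this repository, which may start from an encoded image instead.

```python
def traversal_grid(z_dim, steps=7, low=-1.5, high=1.5):
    """Build latent vectors for a traversal (hypothetical helper).

    Returns a list with one row per latent dimension. Each row holds
    `steps` vectors in which only that dimension varies from `low`
    to `high`; all other coordinates stay at 0.0 (an assumption).
    """
    # Evenly spaced traversal values, including both endpoints.
    values = [low + (high - low) * i / (steps - 1) for i in range(steps)]
    grid = []
    for dim in range(z_dim):
        row = []
        for value in values:
            z = [0.0] * z_dim  # base vector: all zeros
            z[dim] = value     # vary a single coordinate
            row.append(z)
        grid.append(row)
    return grid


grid = traversal_grid(z_dim=3)
print(len(grid), len(grid[0]))  # rows = latent dims, columns = steps
```

Each vector in `grid` would then be passed through the trained decoder, and the decoded images arranged with one row per latent dimension.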

 

Installation

We recommend using mamba (via miniforge) for faster installation of dependencies, but you can also use conda.

git clone https://github.com/suzuki-2001/pytorch-proVLAE.git
cd pytorch-proVLAE

mamba env create -f env.yaml # or conda
mamba activate torch-provlae

 

Usage

You can train pytorch-proVLAE with the following command. Sample hyperparameters and training configurations are provided in the scripts directory. If you have a checkpoint file from a pytorch-proVLAE training run, setting the mode argument to "traverse" lets you inspect the latent traversal. When running this mode, make sure the parameter settings match those used to produce the checkpoint file.


# training with distributed data parallel
# tested on NVIDIA V100 PCIe 16GB+32GB and NVIDIA A6000 48GB x2
torchrun --nproc_per_node=2 --master_port=29501 src/train.py \
    --distributed \
    --mode seq_train \
    --dataset shapes3d \
    --optim adamw \
    --num_ladders 3 \
    --batch_size 128 \
    --num_epochs 15 \
    --learning_rate 5e-4 \
    --beta 8 \
    --z_dim 3 \
    --coff 0.5 \
    --pre_kl \
    --hidden_dim 32 \
    --fade_in_duration 5000 \
    --output_dir ./output/shapes3d/ \
    --data_path ./data
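
The `--fade_in_duration 5000` argument suggests that each newly added ladder is blended in gradually during progressive training, as described in the proVLAE paper. The snippet below is a minimal sketch of one such schedule, assuming a linear ramp measured in training steps. The function name `fade_in_alpha` and the linear form are assumptions for illustration; the schedule actually used in `src/train.py` may differ.

```python
def fade_in_alpha(step, fade_in_duration=5000):
    """Blend weight for a newly added ladder (hypothetical helper).

    Ramps linearly from 0.0 at step 0 to 1.0 at `fade_in_duration`
    and stays at 1.0 afterwards. The linear ramp is an assumption.
    """
    if fade_in_duration <= 0:
        return 1.0  # no fade-in: the ladder is fully active at once
    return min(1.0, max(0.0, step / fade_in_duration))


for step in (0, 2500, 5000, 10000):
    print(step, fade_in_alpha(step))
```

Under this assumption, the output of a new ladder would be multiplied by the returned weight, so that it contributes nothing at first and is fully active once the fade-in period ends.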

 

License

This repository is licensed under the MIT License; see the LICENSE file for details. This follows the license of the original implementation by Zhiyuan Li.

 


*This repository is a contribution to a project at AIST (National Institute of Advanced Industrial Science and Technology).

Human Informatics and Interaction Research Institute, Neurorehabilitation Research Group
Shosuke Suzuki, Ryusuke Hayashi
