PyTorch implementation of TIGraNet
This project implements the TIGraNet model proposed by Khasanova and Frossard (2017) using the PyTorch framework.
Project structure
The project is organized into the following parts.
Folders:
- saved_data: preprocessed data exported from the Theano implementation, together with the corresponding pretrained weights.
- data: raw data for each dataset used in the experiment (mnist, eth80).
- saved_models: models saved during training.
- figures: plots of performance and some comparison figures.
- debug_mnist_012: files with all intermediary steps saved individually for mnist_012 dataset (from both PyTorch and Theano frameworks).
- debug_mnist_rot: files with all intermediary steps saved individually for mnist_rot dataset (from both PyTorch and Theano frameworks).
- debug_mnist_eth80: files with all intermediary steps saved individually for eth80 dataset (from both PyTorch and Theano frameworks).
And several Python modules covering:
- the datasets:
- datasets.py: loads the raw datasets with the specific transformations and splits.
- saved_datasets.py: loads the saved datasets from the Theano implementation.
- custom_transform.py: custom transformations applied to images during preprocessing.
- loader.py: custom loader for mnist_012 dataset.
- the models:
- models.py: description of the model for each dataset (mnist_012, mnist_rot, mnist_trans, eth80).
- layers.py: description of each new layer (spectral convolutional, dynamic pooling and statistical).
- graph.py: main functions applied on the graph (Laplacian, normalized Laplacian, Chebyshev polynomials, ...).
- the debugging process:
- comparison_debug.py: comparison of intermediary steps between PyTorch and Theano implementations.
- layers_debug.py: debugging module for layers.
- the training and testing:
- train.py: training module for the models.
- evaluate.py: inference module for testing set.
- the analysis tools:
- plot.py: plot some metrics like the loss and % of errors.
- the configurations:
- configurations.py: configuration values for the project (batch size, learning rate, ...).
- paths.py: different paths for the project.
- the utility functions:
- utils.py: auxiliary functions (train/valid/test split, save/load models, ...)
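As an illustration of the kind of graph utilities graph.py provides, here is a minimal sketch of a symmetric normalized Laplacian and the Chebyshev polynomial recurrence in PyTorch. The function names and the eigenvalue-rescaling convention (assuming lambda_max ≈ 2) are assumptions for this sketch, not the repository's actual API.

```python
import torch

def normalized_laplacian(adjacency):
    """Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}."""
    degree = adjacency.sum(dim=1)
    d_inv_sqrt = degree.pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # guard isolated nodes (degree 0)
    d_mat = torch.diag(d_inv_sqrt)
    identity = torch.eye(adjacency.size(0))
    return identity - d_mat @ adjacency @ d_mat

def chebyshev_polynomials(laplacian, order):
    """List [T_0(L~), ..., T_order(L~)] via T_k = 2 L~ T_{k-1} - T_{k-2},
    where L~ = L - I rescales the spectrum into [-1, 1] (assuming lambda_max ≈ 2)."""
    n = laplacian.size(0)
    rescaled = laplacian - torch.eye(n)
    polys = [torch.eye(n), rescaled]
    for _ in range(2, order + 1):
        polys.append(2 * rescaled @ polys[-1] - polys[-2])
    return polys[:order + 1]
```

For example, for the two-node graph with adjacency [[0, 1], [1, 0]] the normalized Laplacian is [[1, -1], [-1, 1]].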
In addition, every function comes with a one-line description of its purpose.
Instructions to run the code
To train or evaluate on a specific dataset, use one of the following commands.
terminal >> python3 train.py mnist_012
terminal >> python3 evaluate.py mnist_012
Note that exactly one additional argument is allowed. The available datasets are: mnist_012, mnist_rot, mnist_trans, eth80. If the command is malformed, a message with instructions and/or information is printed.
The remaining modules must be called without any additional argument.
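The argument handling described above could be sketched as follows; the helper name and message wording are hypothetical, not taken from the repository:

```python
import sys

AVAILABLE_DATASETS = ['mnist_012', 'mnist_rot', 'mnist_trans', 'eth80']

def parse_dataset_argument(argv):
    """Return the dataset name, or None (after printing usage) on bad input."""
    if len(argv) != 2 or argv[1] not in AVAILABLE_DATASETS:
        # Exactly one argument is expected, and it must be a known dataset.
        print('usage: python3 {} <dataset>'.format(argv[0]))
        print('available datasets: ' + ', '.join(AVAILABLE_DATASETS))
        return None
    return argv[1]

if __name__ == '__main__':
    dataset = parse_dataset_argument(sys.argv)
```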
Instructions to install required libraries
Conda
Anaconda is used as the package manager, so all required libraries are installed through the conda command line. If you don't have it installed, you can easily follow the steps on their download page. The version with Python 3.6 is recommended.
PyTorch
Once Conda is installed, we can proceed with the installation of PyTorch. Run the following command.
terminal >> conda install pytorch torchvision -c pytorch
tqdm
Finally, we install tqdm, a smart progress meter. Since it does not come pre-installed, it can be added with the following command:
terminal >> conda install tqdm
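Once installed, tqdm is used by simply wrapping any iterable; a minimal example:

```python
from tqdm import tqdm

# Wrapping an iterable in tqdm() yields its items unchanged
# while rendering a progress bar on stderr.
total = 0
for i in tqdm(range(5), desc='epoch'):
    total += i
```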
You are now ready to run the code!
NOTE
Code is runnable with the following configuration:
- conda 4.5.4
- conda-build 3.10.5
- python 3.6.5
- pytorch 0.4.0
- torchvision 0.2.1
- tqdm 4.23.4
CUDA support
The code runs on CUDA without requiring any particular modification.
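This works through the usual device-agnostic PyTorch pattern, sketched below; the linear model and batch are placeholders, not the repository's actual code:

```python
import torch

# Select CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model and each batch to the selected device before use.
model = torch.nn.Linear(4, 2).to(device)
batch = torch.randn(8, 4, device=device)
output = model(batch)
```

Because every tensor and parameter is created on (or moved to) `device`, the same script runs unchanged on both CPU and GPU machines.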