diff --git a/env.yml b/env.yml index a6d14d0..1ca7a3c 100644 --- a/env.yml +++ b/env.yml @@ -1,72 +1,70 @@ name: weightmatrices channels: - pytorch - conda-forge - brian-team - defaults dependencies: - blas=1.0=mkl - ca-certificates=2020.6.20=hecda079_0 - certifi=2020.6.20=py37hc8dfbb8_0 - cloudpickle=1.4.1=py_0 - cycler=0.10.0=py_2 - cytoolz=0.10.1=py37h0b31af3_0 - dask-core=2.19.0=py_0 - decorator=4.4.2=py_0 - freetype=2.10.2=ha233b18_0 - imagecodecs-lite=2019.12.3=py37h10e2902_1 - imageio=2.8.0=py_0 - intel-openmp=2019.4=233 - joblib=0.15.1=py_0 - jpeg=9b=he5867d9_2 - kiwisolver=1.2.0=py37ha1cc60f_0 - libcxx=10.0.0=1 - libedit=3.1.20191231=haf1e3a3_0 - libffi=3.3=h0a44026_1 - libgfortran=3.0.1=h93005f0_2 - libpng=1.6.37=ha441bb4_0 - libtiff=4.1.0=hcb84e12_1 - llvm-openmp=10.0.0=h28b9765_0 - lz4-c=1.9.2=h0a44026_0 - matplotlib=3.2.1=0 - matplotlib-base=3.2.1=py37hddda452_0 - mkl=2019.4=233 - mkl-service=2.3.0=py37hfbe908c_0 - mkl_fft=1.1.0=py37hc64f4ea_0 - mkl_random=1.1.1=py37h959d312_0 - ncurses=6.2=h0a44026_1 - networkx=2.4=py_1 - ninja=1.9.0=py37h04f5b5a_0 - numpy=1.18.1=py37h7241aed_0 - numpy-base=1.18.1=py37h3304bdc_1 - olefile=0.46=py37_0 - openssl=1.1.1g=h0b31af3_0 - pillow=7.1.2=py37h4655f20_0 - pip=20.1.1=py37_1 - pyparsing=2.4.7=pyh9f0ad1d_0 - python=3.7.7=hf48f09d_4 - python-dateutil=2.8.1=py_0 - python_abi=3.7=1_cp37m - pytorch=1.5.1=py3.7_0 - pywavelets=1.1.1=py37h10e2902_1 - pyyaml=5.3.1=py37h9bfed18_0 - readline=8.0=h1de35cc_0 - scikit-image=0.17.2=py37h94625e5_1 - scikit-learn=0.22.1=py37h27c97d8_0 - scipy=1.4.1=py37h9fa6033_0 - setuptools=47.3.0=py37_0 - six=1.15.0=py_0 - sqlite=3.32.2=hffcf06c_0 - tifffile=2020.6.3=py_0 - tk=8.6.10=hb0a8c7a_0 - toolz=0.10.0=py_0 - torchvision=0.6.1=py37_cpu - tornado=6.0.4=py37h9bfed18_1 - tqdm=4.46.1=py_0 - wheel=0.34.2=py37_0 - xz=5.2.5=h1de35cc_0 - yaml=0.2.5=h0b31af3_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.4=h1990bb4_3 -prefix: /Users/Bernd/anaconda3/envs/weightmatrices - diff --git a/main.py b/main.py index 2103d74..bb739e8 100644 --- a/main.py +++ b/main.py @@ -1,79 +1,82 @@ import torch import numpy as np import random -import matplotlib -matplotlib.use('TkAgg') -import matplotlib.pyplot as plt import argparse # own package from weightmatrices.utils import utils # read command line args and kwargs parser = argparse.ArgumentParser() parser.add_argument("--nhidden", nargs="*", type=int, help="number of hidden neurons", default=100) parser.add_argument("--methods", nargs="*", help="methods to be applied to create weight matrix", default='all') parser.add_argument("--save", type=bool, help="whether to save results or not", default=True) args = parser.parse_args() # data import print("loading data") data_loader = utils.load_data() data_matrix = utils.getbigdatamatrix(data_loader) n_in_features = data_matrix.shape[1] -# weight matrix creation - -if 'pca' in args.methods or args.methods == 'all': +# weight matrix creation with different methods +if 'pca' in args.methods or 'all' in args.methods: from weightmatrices.algos import pca print("creating weight matrix using PCA") for n_h in args.nhidden: if n_h <= n_in_features: # Number of requested components <= input dimensionality W_pca = pca.get_weightmatrices_pca(data_loader, n_h) if args.save: utils.saveweightmatrix('pca'+str(n_h), W_pca) -if 'ica' in args.methods or args.methods == 'all': +if 'ica' in args.methods or 'all' in args.methods: from weightmatrices.algos import ica print("creating weight matrix using ICA") for n_h in args.nhidden: if n_h <= n_in_features: # Number of requested components <= input dimensionality W_ica = ica.get_weightmatrices_ica(data_matrix, n_h) if args.save: utils.saveweightmatrix('ica'+str(n_h), W_ica) -if 'sc' in args.methods or args.methods == 'all': +if 'sc' in args.methods or 'all' in args.methods: from weightmatrices.algos import sc print("creating weight matrix using SC") for n_h in args.nhidden: W_sc = sc.get_weightmatrices_sc(data_matrix, n_h, getsparsity=True) if args.save: utils.saveweightmatrix('sc'+str(n_h), W_sc) -if 'rg' in args.methods or args.methods == 'all': +if 'rg' in args.methods or 'all' in args.methods: from weightmatrices.algos import rg print("creating weight matrix using RG") for n_h in args.nhidden: W_rg = rg.get_weightmatrices_rg(data_matrix, n_h) if args.save: utils.saveweightmatrix('rg'+str(n_h), W_rg) -if 'rp' in args.methods or args.methods == 'all': +if 'rp' in args.methods or 'all' in args.methods: from weightmatrices.algos import rp print("creating weight matrix using RP") for n_h in args.nhidden: W_rp = rp.get_weightmatrices_rp(data_matrix, n_h) if args.save: utils.saveweightmatrix('rp'+str(n_h), W_rp) -# plotting -#plt.ion() -#plt.imshow(W_sc[random.sample(range(0, args.nhidden), 1)[0], :].reshape(28, 28), cmap = 'gray') -#plt.show() # jump to interactive mode import code code.interact(local=locals()) + + +# sample plotting +#import matplotlib +#matplotlib.use('TkAgg') +#import matplotlib.pyplot as plt + +#W = W_pca +#plt.ion() +#plt.imshow(W[random.sample(range(0, args.nhidden[-1]), 1)[0], :].reshape(28, 28), cmap = 'gray') +#plt.show()