Page Menu · Home · c4science

test.py
No One · Temporary

File Metadata

Created
Tue, Feb 25, 01:00
import unittest
import numpy as np
from PCAnet import *
class TestFilterBank(unittest.TestCase):
    """Tests for the PCAnet module (patch extraction, filter bank, feature extraction).

    Several tests compare results against text fixtures read with
    ``np.loadtxt`` from ``./test/`` (resolved against the current working
    directory). Those reference tests were disabled in the original file by
    wrapping them in a string literal; they are kept here as explicitly
    skipped tests so the runner reports them instead of hiding them.
    """

    # Reason shown by the runner for the tests that were disabled in the original file.
    _DISABLED = "disabled in the original source; reason not recorded - re-enable once verified"

    @staticmethod
    def _random_images(seed, size):
        """Draw two ``size`` x ``size`` x 1 images with integer values in [0, 255].

        The images are drawn in order (first, then second) from a fresh
        ``np.random.RandomState(seed)``.

        Returns:
            (rng, im1, im2, stacked) where ``rng`` is the RandomState after both
            draws (callers pass it on to ``PCAnet(rdm=...)``), and ``stacked`` is
            ``np.array([im1, im2])`` of shape (2, size, size, 1).
        """
        rng = np.random.RandomState(seed)
        # floor(rand * 256) maps [0, 1) onto the integers 0..255.
        im1 = np.floor(rng.rand(size, size) * 256).reshape(size, size, 1)
        im2 = np.floor(rng.rand(size, size) * 256).reshape(size, size, 1)
        return rng, im1, im2, np.array([im1, im2])

    @unittest.skip(_DISABLED)
    def test_im2col_gen(self):
        """im2col_gen patches match stored fixtures, without and with a 2-pixel stride."""
        _, im1, _, _ = self._random_images(7, 10)

        patches = im2col_gen(im1, 7)
        s = patches.shape
        patches_ref = np.loadtxt('./test/im2colgen_nosliding.txt').reshape((s[0], s[1], s[2]))
        np.testing.assert_array_equal(patches, patches_ref)

        patches = im2col_gen(im1, 7, 2)
        s = patches.shape
        patches_ref = np.loadtxt('./test/im2colgen_sliding2x2.txt').reshape((s[0], s[1], s[2]))
        np.testing.assert_array_equal(patches, patches_ref)

    @unittest.skip(_DISABLED)
    def test_im2col_mean_removal(self):
        """im2col_mean_removal patches match stored fixtures, without and with a 3-pixel stride."""
        # The first image is drawn only so the second comes from the same RNG state as before.
        _, _, im2, _ = self._random_images(7, 10)

        patches = im2col_mean_removal(im2, 7)
        s = patches.shape
        patches_ref = np.loadtxt('./test/im2col_mean_rem_nosliding.txt').reshape((s[0], s[1], s[2]))
        np.testing.assert_array_almost_equal(patches, patches_ref)

        patches = im2col_mean_removal(im2, 7, 3)
        s = patches.shape
        patches_ref = np.loadtxt('./test/im2col_mean_rem_sliding3x3.txt').reshape((s[0], s[1], s[2]))
        np.testing.assert_array_almost_equal(patches, patches_ref)

    @unittest.skip(_DISABLED)
    def test_filter_bank(self):
        """PCAnet.filter_bank output matches the stored fixture to 5 decimals."""
        rng, _, _, in_im = self._random_images(7, 10)
        pcanet = PCAnet(rdm=rng)
        filters = pcanet.filter_bank(in_im, 7, 8)
        filters_ref = np.loadtxt('./test/filter_bank.txt')
        np.testing.assert_array_almost_equal(filters, filters_ref, decimal=5)

    @unittest.skip(_DISABLED)
    def test_filter_output(self):
        """Smoke test: PCAnet.filter_output runs on a filter bank built from the same images."""
        rng, _, _, in_im = self._random_images(7, 10)
        pcanet = PCAnet(rdm=rng)
        filters = pcanet.filter_bank(in_im, 7, 8)
        # NOTE(review): the original made no assertion on these results.
        out, out_idx = pcanet.filter_output(in_im, np.array([0, 1]), 7, 8, filters)

    def test_extract(self):
        """Train a two-stage PCAnet on two random 40x40 images and extract features."""
        rng, _, _, in_im = self._random_images(1, 40)
        params = {
            'num_stages': 2,
            'patch_dim': np.array([7, 7]),
            'num_filters': np.array([8, 8]),
            'hist_size': np.array([7, 7]),
            'overlap_ratio': 0.5,
        }
        # The RandomState that drew the images is handed to PCAnet, exactly as before.
        pcanet = PCAnet(rdm=rng, params=params)
        V = pcanet.train(in_im)
        features = pcanet.extract(in_im, V)
        # NOTE(review): the expected shape/values of extract()'s output are not
        # known here; at minimum, extraction must produce a result.
        self.assertIsNotNone(features)
# Script entry point: hand control to unittest's command-line runner, which
# discovers and runs the TestCase classes defined in this module.
if __name__ == "__main__":
    unittest.main()

Event Timeline