Page Menu · Home · c4science

example.py
No One · Temporary

File Metadata

Created
Wed, May 1, 19:39

example.py

from itertools import islice

import torch
import torchvision.transforms as transforms
from torchvision import datasets

import hdtorch
class HDParams():
    """Configuration for the hdtorch MNIST example.

    Every setting is a class attribute, so it can be read either from the
    class itself (``HDParams.D``) or from an instance (``HDParams().D``).
    """

    # Hypervector flavor: 'binary' (components 0/1) or 'bipol' (bipolar, -1/1).
    HDFlavor: str = 'binary'
    # Number of components in each hypervector.
    D: int = 10000
    # Input features per sample: one per pixel of a flattened 28x28 MNIST image.
    numFeat: int = 784
    # Number of output classes (MNIST digits).
    numClasses: int = 10
    # Torch device name: 'cpu' or 'cuda'.
    device: str = 'cpu'
    # NOTE(review): presumably selects bit-packed hypervector storage inside hdtorch -- confirm.
    packed: bool = False
    # Passed to normalizeAndDiscretizeData as the number of discretization levels.
    numSegmentationLevels: int = 20
    # Similarity measure used to compare hypervectors: 'hamming' or 'cosine'.
    similarityType: str = 'hamming'
    # Initialization scheme for level hypervectors: 'random', 'sandwich',
    # 'scaleNoRand1', 'scaleNoRand2', 'scaleRand1', 'scaleRand2', ...
    levelVecType: str = 'scaleNoRand1'
    # NOTE(review): presumably the initialization scheme for feature-ID hypervectors -- confirm.
    IDVecType: str = 'random'
    # Encoding (binding) strategy: 'FeatAppend', 'FeatPermute' or 'IDLevelEncoding'.
    bindingStrat: str = 'IDLevelEncoding'
# --- Data, model and feature-range setup ---
hdParams = HDParams()
batchSize = 1000

# Keep raw pixel intensities (0..255). minFeat/maxFeat below are computed from
# the raw uint8 `dataset.data` tensor, so the loaded images must be on that
# same scale. Do not use ToTensor()/ConvertImageDtype here: ToTensor maps
# pixels to [0, 1] floats and ConvertImageDtype rescales integer ranges, which
# would make the normalization use the wrong min/max.
t = transforms.Compose([
    transforms.PILToTensor(),                            # uint8 tensor, values unchanged
    transforms.Lambda(lambda img: img.to(torch.int64)),  # plain dtype cast, no rescaling
])
dataTrain = datasets.MNIST(root='./data', train=True, transform=t, download=True)
dataTest = datasets.MNIST(root='./data', train=False, transform=t, download=True)
trainLoader = torch.utils.data.DataLoader(dataset=dataTrain, batch_size=batchSize, shuffle=True)
testLoader = torch.utils.data.DataLoader(dataset=dataTest, batch_size=batchSize, shuffle=False)

HDModel = hdtorch.HD_classifier(hdParams)

# Per-feature (per-pixel) minimum and maximum over all raw training images;
# used to normalize both training and test batches.
rawTrainFeatures = trainLoader.dataset.data.view(-1, hdParams.numFeat)
minFeat = rawTrainFeatures.min(0)[0]
maxFeat = rawTrainFeatures.max(0)[0]
# --- Training ---
# Number of training batches to use; None trains on the full epoch.
# NOTE(review): the cap of 3 batches (3 * batchSize samples) looks like a way
# to keep this example fast -- raise it or set None for a full run.
maxTrainBatches = 3
for data, labels in islice(trainLoader, maxTrainBatches):
    # Flatten each image into a row of numFeat pixel features.
    data = data.view(-1, HDParams.numFeat)
    data = hdtorch.util.normalizeAndDiscretizeData(data, minFeat, maxFeat, HDParams.numSegmentationLevels)
    HDModel.trainModelVecOnData(data, labels)
# --- Evaluation ---
# Report accuracy for each test batch, then the overall accuracy over the
# whole test set (the per-batch figures alone do not give the test accuracy).
numCorrect = 0
numSeen = 0
for batchIdx, (data, labels) in enumerate(testLoader):
    # Same preprocessing as training: flatten, then normalize/discretize with
    # the ranges computed on the training set.
    data = data.view(-1, HDParams.numFeat)
    data = hdtorch.util.normalizeAndDiscretizeData(data, minFeat, maxFeat, HDParams.numSegmentationLevels)
    # givePrediction returns (predictions, distances); distances are unused here.
    testPredictions, _ = HDModel.givePrediction(data)
    batchCorrect = (testPredictions == labels).sum().item()
    numCorrect += batchCorrect
    numSeen += len(labels)
    print(f'Acc Test (batch {batchIdx}): {batchCorrect / len(labels)}')

# Guard against an empty test set instead of dividing by zero.
acc_test = numCorrect / numSeen if numSeen else float('nan')
print(f'Acc Test: {acc_test}')

Event Timeline