Page Menu Home c4science

sociallstm.py
No One Temporary

File Metadata

Created
Wed, May 29, 11:37

sociallstm.py

import torch
import torch.nn as nn
from torch.autograd import Variable
import project.data_utils as du
class SocialLSTM(nn.Module):
    """Simplified Social LSTM trajectory model.

    Each input position is embedded by a linear layer + ReLU and run through
    an LSTM; a linear head maps every LSTM output back to 2 values. When a
    ``social_tensor`` is supplied it is embedded with the same linear layer and
    concatenated with the position embedding before a wider LSTM (``lstm``).

    Grid-based social pooling (grid / neighborhood / pooling sizes) is not
    implemented in this version.
    """

    def __init__(self, hidden_size=128, embedding_size=64):
        """Build the layers.

        Args:
            hidden_size: LSTM hidden-state size (default 128).
            embedding_size: size of each linear input embedding (default 64).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.input_size = 2
        # Two values per step. This must equal input_size because predictions
        # are fed back through input_embedding_layer during the future rollout.
        self.output_size = 2

        # dropout is deliberately not passed: nn.LSTM applies it only between
        # stacked layers, so with num_layers=1 it would be a no-op (and PyTorch
        # emits a warning about it).
        # LSTM used when a social tensor is given: its input is the
        # concatenation of the position embedding and the social embedding.
        self.lstm = nn.LSTM(2 * self.embedding_size, self.hidden_size)
        # LSTM used for position-only input and for every future-rollout pass.
        self.lstm2 = nn.LSTM(self.embedding_size, self.hidden_size)
        # Embeds a 2-value position (also used for the social tensor).
        self.input_embedding_layer = nn.Linear(self.input_size, self.embedding_size)
        # Maps each LSTM output to the 2-value prediction.
        self.output_layer = nn.Linear(self.hidden_size, self.output_size)
        self.relu = nn.ReLU()

    def forward(self, peds, social_tensor, future=0):
        """Run the model over an observed sequence and optionally roll out.

        Args:
            peds: pedestrian coordinates. Assumed to be (seq_len, batch, 2),
                since nn.LSTM is used with its default ``batch_first=False``
                and the concatenation below is along dim 2.
            social_tensor: tensor with the same leading dims as ``peds`` and a
                last dim of 2 (it goes through ``input_embedding_layer``), or
                None to use the position-only LSTM.
            future: number of extra prediction passes to append.

        Returns:
            Tensor of shape (seq_len * (future + 1), batch, 2): the outputs of
            the initial pass followed by those of each future pass,
            concatenated along dim 0.
        """
        batch_size = peds.size(1)
        # Zero initial (h_0, c_0) sized to the batch and allocated on the
        # input's device/dtype, so batches > 1 and GPU inputs work.
        h0 = peds.new_zeros(1, batch_size, self.hidden_size)
        c0 = peds.new_zeros(1, batch_size, self.hidden_size)

        pos_embed = self.relu(self.input_embedding_layer(peds))
        if social_tensor is None:
            lstm_out, _ = self.lstm2(pos_embed, (h0, c0))
        else:
            social_embed = self.relu(self.input_embedding_layer(social_tensor))
            concat_embed = torch.cat((pos_embed, social_embed), 2)
            lstm_out, _ = self.lstm(concat_embed, (h0, c0))
        output = self.output_layer(lstm_out)
        outputs = [output]

        for _ in range(future):
            # NOTE(review): each pass re-feeds the whole previous prediction
            # sequence through lstm2 starting from zero state. The recurrent
            # state is not carried over and the social embedding is not used.
            # This is kept as-is because the returned shape depends on it;
            # confirm the intended rollout.
            step_embed = self.relu(self.input_embedding_layer(output))
            lstm_out, _ = self.lstm2(step_embed, (h0, c0))
            output = self.output_layer(lstm_out)
            outputs.append(output)
        return torch.cat(outputs, 0)

Event Timeline