diff --git a/sociallstm.py b/sociallstm.py index 9adba9a..a1d7884 100644 --- a/sociallstm.py +++ b/sociallstm.py @@ -1,69 +1,68 @@ import torch import torch.nn as nn from torch.autograd import Variable import project.data_utils as du class SocialLSTM(nn.Module): ''' Class representing the Social LSTM model ''' def __init__(self): super(SocialLSTM,self).__init__() # Store required sizes self.hidden_size = 128 #self.grid_size = args.grid_size self.embedding_size = 64 #self.pooling_size = args.pooling_size #pooling window self.input_size = 2 self.output_size = 2 #parameters of bivariate distribution #self.neighborhood_size = args.neighborhood_size # The LSTM cell. (Social LSTM) embedding size = 64 self.lstm= nn.LSTM(2*self.embedding_size, self.hidden_size, dropout = 0.2) self.lstm2 = nn.LSTM(self.embedding_size, self.hidden_size, dropout = 0.2) # Linear layer to embed the input position into LSTM self.input_embedding_layer = nn.Linear(self.input_size, self.embedding_size) # Linear layer to embed the social tensor #self.tensor_embedding_layer = nn.Linear(self.neighborhood_size*self.neighborhood_size*self.hidden_size, self.embedding_size) # Linear layer to map the hidden state of LSTM to output self.output_layer = nn.Linear(self.hidden_size, self.output_size) # ReLU and dropout unit self.relu = nn.ReLU() #self.dropout = nn.Dropout(0.5) def forward(self, peds, social_tensor, future = 0): ''' Forward pass for the model params: peds: pedestrian coords ''' outputs = [] hidden_states = Variable(torch.zeros(1,1,self.hidden_size)) cell_states = Variable(torch.zeros(1,1,self.hidden_size)) input = self.relu(self.input_embedding_layer(peds)) if social_tensor is None: h_peds, c_peds = self.lstm2(input, (hidden_states, cell_states)) else: social_embed = self.relu(self.input_embedding_layer(social_tensor)) - concat_embed = torch.cat((input,social_embed),1) - print(concat_embed.size()) + concat_embed = torch.cat((input,social_embed),2) h_peds, c_peds = self.lstm(concat_embed, (hidden_states, cell_states)) output = self.output_layer(h_peds) outputs += [output] for i in range(future): #predict future new_out = self.relu(self.input_embedding_layer(output)) #concat_embed = torch.cat((new_out,social_embed),0) h_peds, c_peds = self.lstm2(new_out, (hidden_states, cell_states)) output = self.output_layer(h_peds) outputs += [output] outputs = torch.cat(outputs, 0) return outputs