Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F60903105
pplm_classification_head.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, May 3, 06:53
Size
655 B
Mime Type
text/x-python
Expires
Sun, May 5, 06:53 (2 d)
Engine
blob
Format
Raw Data
Handle
17437518
Attached To
R11484 ADDI
pplm_classification_head.py
View Options
import
torch
class ClassificationHead(torch.nn.Module):
    """Linear classification head for transformer encoders.

    Maps a hidden-state vector of size ``embed_size`` to ``class_size``
    unnormalized scores (logits) with a single ``torch.nn.Linear`` layer.

    Args:
        class_size: Number of output classes (``out_features`` of the layer).
        embed_size: Size of the incoming hidden state (``in_features`` of
            the layer).
    """

    def __init__(self, class_size: int, embed_size: int) -> None:
        super().__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # Keep the attribute name ``mlp``: it determines the state_dict keys
        # ("mlp.weight", "mlp.bias") that saved checkpoints are stored under.
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``hidden_state``.

        Args:
            hidden_state: Tensor whose last dimension is ``embed_size``;
                ``torch.nn.Linear`` passes any leading dimensions through.

        Returns:
            Tensor of raw logits whose last dimension is ``class_size``
            (no softmax or other normalization is applied here).
        """
        return self.mlp(hidden_state)
Event Timeline
Log In to Comment