import torch
from torch import nn


class Swish(nn.Module):
    """Swish activation module: ``f(x) = x * sigmoid(x)``.

    Applied element-wise, so the output has the same shape as the input.
    The sigmoid is held as an ``nn.Sigmoid`` submodule named ``sigmoid``.
    """

    def __init__(self):
        super(Swish, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale each element of ``x`` by its own sigmoid."""
        gate = self.sigmoid(x)
        return x * gate