# ICLR’22:
import pytorch_lightning as pl
import torch

# Fix the CPU RNG seed so the RNN weight initialization (and hence the
# printed losses) are reproducible across runs.
torch.random.manual_seed(0)
class ToyModel(pl.LightningModule):
    """Minimal LightningModule: a 1-in/1-hidden RNN scored with MSE.

    Used only through ``Trainer.test`` — there is no training_step.
    """

    def __init__(self):
        super().__init__()
        # Single-layer RNN, input size 1, hidden size 1, expecting
        # (batch, seq, feature)-shaped inputs.
        self.rnn = torch.nn.RNN(1, 1, batch_first=True)
        self.lossf = torch.nn.MSELoss()

    def test_step(self, batch, batchidx):
        """Run the RNN over one padded batch and print shapes and MSE loss.

        NOTE(review): the loss here also covers padded timesteps, since no
        lengths/mask are available — presumably intentional for the demo.
        """
        print("IN", batch.shape)
        outputs, _ = self.rnn(batch)
        print("OUT", outputs.shape)
        # Gold value is 1.0 at every timestep of every sequence.
        target = torch.full(outputs.size(), 1.)
        step_loss = self.lossf(outputs, target)
        print("LOSS", step_loss.item())

    def configure_optimizers(self):
        """Plain SGD over all parameters (unused by Trainer.test)."""
        return torch.optim.SGD(self.parameters(), lr=0.01)
class ToyDataset(torch.utils.data.Dataset):
    """Two variable-length toy sequences: [0,1,2] and [0,1,2,3].

    Each item is a ``(seq_len, 1)`` float tensor, i.e. one feature per
    timestep, so the lengths deliberately differ (3 vs 4) to exercise
    padding in the collate function.
    """

    def __init__(self):
        super().__init__()
        self.data = [
            torch.tensor([0., 1., 2.]).view(-1, 1),
            torch.tensor([0., 1., 2., 3.]).view(-1, 1),
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def custom_collate(data):
    """Zero-pad a list of (seq_len, 1) tensors into one (batch, max_len, 1) batch.

    Shorter sequences are right-padded with zeros so the batch is rectangular.
    """
    padded = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)
    return padded
# Wire everything together: model, dataset, padded DataLoader, then run a
# single test pass (prints shapes and the unmasked loss per batch).
model = ToyModel()
dataset = ToyDataset()
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=custom_collate)
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, loader)
def masked_rnn_loss(rnn, batch, lengths, target_value=1.0):
    """Run ``rnn`` over a padded batch and return the MSE loss averaged over
    real (non-padded) timesteps only.

    Fixes over the original fragment:
    - it referenced undefined names (``self``, ``x``, ``gld``) at module level;
    - it applied an element-wise mask to the output of ``MSELoss()`` whose
      default ``reduction='mean'`` yields a scalar, so the masking math was
      wrong — element-wise masking requires ``reduction='none'``;
    - the mask hard-coded position ``[0, 3, 0]`` and lengths ``[3, 4]``;
      both are now derived from the ``lengths`` argument.

    Args:
        rnn: a ``batch_first=True`` recurrent module (e.g. ``torch.nn.RNN``).
        batch: padded input of shape ``(batch, max_len, features)``.
        lengths: true length of each sequence in ``batch``.
        target_value: gold value expected at every real timestep (default 1.0).

    Returns:
        Scalar tensor: sum of squared errors on real timesteps divided by the
        number of real elements.
    """
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        batch, lengths, batch_first=True, enforce_sorted=False)
    packed_out, _ = rnn(packed)
    # Re-pad: padded positions of `outputs` come back as zeros, which is why
    # they must be masked out of the loss below.
    outputs, out_lengths = torch.nn.utils.rnn.pad_packed_sequence(
        packed_out, batch_first=True)
    gold = torch.full(outputs.size(), target_value)
    # mask is 1.0 on real timesteps, 0.0 on padding, per-sequence.
    mask = torch.zeros_like(outputs)
    for i, n in enumerate(out_lengths):
        mask[i, :n] = 1.0
    non_zero_elements = mask.sum()
    # Per-element loss (reduction='none'), then zero the padded elements.
    elementwise = torch.nn.functional.mse_loss(outputs, gold, reduction="none")
    return (elementwise * mask).sum() / non_zero_elements