Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- import torch
- import torch.nn as nn
class Model(nn.Module):
    """Minimal module wrapping a single ``nn.Linear(1, 1)`` layer."""

    def __init__(self):
        super().__init__()
        # One input feature mapped to one output feature.
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return ``self.linear(x)``.

        ``x`` must have a trailing dimension of size 1 to match the layer's
        ``in_features``.
        """
        return self.linear(x)
# Example input for tracing: a 1-element tensor, matching the model's single
# input feature (the Model class above uses nn.Linear(1, 1)).
x = torch.zeros(1)
model = Model()
# torch.jit.trace takes the callable first and the example input(s) second.
# The curried form ``torch.jit.trace(x)(model)`` is the pre-1.0 API; on current
# PyTorch it treats ``x`` as the function to trace and fails for lack of
# example inputs.
trace = torch.jit.trace(model, x)
# Serialize the traced ScriptModule to disk.
trace.save('trace.pt')
Add Comment
Please sign in to add a comment.