Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- from typing import Callable
- from tinygrad import nn, Tensor, TinyJit
- from tinygrad.nn.datasets import mnist
class LeNet5:
  """Classic LeNet-5 (LeCun et al., 1998) for 1x28x28 MNIST inputs.

  conv(5, pad=2) -> tanh -> avgpool2 -> conv(5) -> tanh -> avgpool2
  -> fc 400->120 -> tanh -> fc 120->84 -> tanh -> fc 84->num_classes.
  """
  def __init__(self, num_classes: int = 10):
    # padding=2 lifts 28x28 to an effective 32x32, as in the original paper.
    self.l1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # 28x28 -> 28x28
    # kernel_size=5 gives the 14x14 -> 10x10 map the shape comments below
    # describe; the pooled 5x5x16 feature map then flattens to 400 features.
    self.l2 = nn.Conv2d(6, 16, kernel_size=5)            # 14x14 -> 10x10
    self.l3 = nn.Linear(16 * 5 * 5, 120)
    self.l4 = nn.Linear(120, 84)
    self.l5 = nn.Linear(84, num_classes)

  def __call__(self, x):
    """Forward pass: (B, 1, 28, 28) image batch -> (B, num_classes) logits."""
    x = self.l1(x).tanh().avg_pool2d(kernel_size=2, stride=2)  # 28x28 -> 14x14
    x = self.l2(x).tanh().avg_pool2d(kernel_size=2, stride=2)  # 10x10 -> 5x5
    x = self.l3(x.flatten(1)).tanh()
    x = self.l4(x).tanh()
    return self.l5(x)

  @TinyJit
  def jit(self, x):
    # JIT-compiled inference path; .realize() forces execution so the jit
    # captures the kernels. Use only outside other TinyJit functions.
    return self(x).realize()
net = LeNet5()
# Hyperparameters.
learning_rate = 3e-4
batch_size = 64
# Plain SGD over every trainable parameter discovered in the model.
opt = nn.optim.SGD(nn.state.get_parameters(net), lr=learning_rate)
# mnist() yields train images/labels and a held-out validation split as Tensors.
X_train, y_train, X_valid, y_valid = mnist()
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
  """Run one SGD step on a random minibatch and return the loss tensor."""
  # Sample a random batch of indices from the training set.
  samples = Tensor.randint(batch_size, high=X_train.shape[0])
  X, y = X_train[samples], y_train[samples]
  opt.zero_grad()
  # Use the plain forward here, not net.jit: train_step is itself TinyJit-wrapped
  # (jit-inside-jit), and net.jit realizes its output, which would detach the
  # logits from the autograd graph and break backward().
  loss = net(X).sparse_categorical_crossentropy(y).backward()
  opt.step()
  return loss
Advertisement
Add Comment
Please sign in to add a comment.