Guest User

Untitled

a guest
Oct 15th, 2025
95
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.35 KB | None | 0 0
  1. from typing import Callable
  2. from tinygrad import nn, Tensor, TinyJit
  3. from tinygrad.nn.datasets import mnist
  4.  
  5. class LeNet5:
  6.     def __init__(self, num_classes: int=10):
  7.         self.l1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
  8.         self.l2 = nn.Conv2d(6, 16, kernel_size=2, stride=2) # 14x14->10x10
  9.         #self.l3 = nn.Linear(16, 120)
  10.         self.l3 = nn.Linear(144, 120)
  11.         self.l4 = nn.Linear(120, 84)
  12.         self.l5 = nn.Linear(84, num_classes)
  13.  
  14.     def __call__(self, x):
  15.         x = self.l1(x).tanh().avg_pool2d(kernel_size=2, stride=2) # 28x28->14x14
  16.         x = self.l2(x).tanh().avg_pool2d(kernel_size=2, stride=2) # 10x10->5x5
  17.         x = self.l3(x.flatten(1)).tanh()
  18.         x = self.l4(x).tanh()
  19.         x = self.l5(x)
  20.         return x
  21.  
  22.     @TinyJit
  23.     def jit(self, x):
  24.         return self(x).realize()
  25.  
  26. net = LeNet5()
  27.  
  28. learning_rate = 3e-4
  29. batch_size = 64
  30.  
  31. opt = nn.optim.SGD(nn.state.get_parameters(net), lr=learning_rate)
  32.  
  33. X_train, y_train, X_valid, y_valid = mnist()
  34.  
  35. @TinyJit
  36. @Tensor.train()
  37. def train_step() -> Tensor:
  38.     samples = Tensor.randint(batch_size, high=X_train.shape[0])
  39.     X, y = X_train[samples], y_train[samples]
  40.    
  41.     labels = y
  42.     out = net.jit(X)
  43.     opt.zero_grad()
  44.     loss = out.sparse_categorical_crossentropy(labels)
  45.     loss.backward()
  46.     opt.step()
  47.    
  48.     return loss
Advertisement
Add Comment
Please, Sign In to add comment