import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super().__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    # VGG-16 layer spec: conv output channels, with "pool" marking 2x2 max-pool layers
    vgg16_cfg = [
        64, 64, "pool",
        128, 128, "pool",
        256, 256, 256, "pool",
        512, 512, 512, "pool",
        512, 512, 512, "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)
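
# Quick shape sanity check (illustrative, safe to remove): a CIFAR-10-sized batch
# should map to (N, num_classes).
# _m = vgg16(num_classes=10)
# assert _m(torch.randn(2, 3, 32, 32)).shape == (2, 10)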


PARSER = argparse.ArgumentParser(
    description="Load a pre-trained VGG model and then tune it with PTQ (INT8 or FP8). For obtaining a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ",
)
PARSER.add_argument(
    "--quantize-type",
    default="int8",
    type=str,
    choices=["int8", "fp8"],
    help="Quantization type for PTQ: int8 or fp8",
)
args = PARSER.parse_args()
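
# Example invocation (the script and checkpoint names are illustrative):
#   python vgg16_ptq.py --ckpt vgg16_ckpt.pth --batch-size 128 --quantize-type fp8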

model = vgg16(num_classes=1000, init_weights=False)
model = model.cuda()

# Load the pre-trained weights; the checkpoint is assumed to be a plain state_dict.
ckpt = torch.load(args.ckpt)
model.load_state_dict(ckpt)

# Don't forget to set the model to evaluation mode!
model.eval()

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Grab one batch to use as the example input for export/compilation later
data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()


def calibrate_loop(model):
    # Run forward passes over the training set so the quantizers can collect
    # activation statistics; the loss/accuracy printed here is purely informational.
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels).item()
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))

if args.quantize_type == "int8":
    quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
    quant_cfg = mtq.FP8_DEFAULT_CFG
else:
    raise ValueError(f"Unsupported quantize type: {args.quantize_type}")

# PTQ with in-place replacement of modules by their quantized counterparts
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# The model now carries Q/DQ (quantize/dequantize) nodes for the chosen precision.
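
# Optionally inspect which modules were quantized; mtq.print_quant_summary is
# part of the ModelOpt quantization API per its docs (remove if your version lacks it).
mtq.print_quant_summary(model)
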
# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

# drop_last=True drops the last incomplete batch, keeping shapes static for
# the static-shape `torchtrt.dynamo.compile()` below
testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with the Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        # torch.export.export() fails here with "RuntimeError: Attempting to use
        # FunctionalTensor on its own. Instead, please use it with a corresponding
        # FunctionalTensorMode()", so fall back to the private _export API.
        from torch.export._trace import _export

        exp_program = _export(model, (input_tensor,))
        if args.quantize_type == "int8":
            enabled_precisions = {torch.int8}
        elif args.quantize_type == "fp8":
            enabled_precisions = {torch.float8_e4m3fn}
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
            debug=True,
        )
        # You can also use the torch.compile path to compile the model with Torch-TensorRT:
        # trt_model = torch.compile(model, backend="tensorrt")
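
        # Optionally persist the compiled module for later reuse; torchtrt.save is
        # available in recent Torch-TensorRT releases (the file name is illustrative).
        # torchtrt.save(trt_model, "trt_vgg16.ep", inputs=[input_tensor])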

        # Run inference with the compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels).item()
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))