
在以前,PyTorch只有一种量化方法,叫做“eager mode quantization”,用它量化自定义模型时经常会产生奇怪的错误,并且很难解决。但是最近,PyTorch发布了一种称为“fx-graph-mode-quantization”的新方法。在本文中我们将研究这个“fx-graph-mode-quantization”,看看它能不能让我们的量化操作更容易、更稳定。

本文将使用CIFAR 10和一个自定义AlexNet模型,我对这个模型进行了小的修改以提高效率。另外,由于模型和数据集都很小,所以在CPU上也可以跑起来。

  1. import os
  2. import cv2
  3. import time
  4. import torch
  5. import numpy as np
  6. import torchvision
  7. from PIL import Image
  8. import torch.nn as nn
  9. import matplotlib.pyplot as plt
  10. from torchvision import transforms
  11. from torchvision import datasets, models, transforms
  # Target device used by the rest of the script.
  device = "cpu"
  print(device)

  # Preprocessing: resize to 224, convert to a tensor, then normalise each
  # channel with the mean / std constants below.
  transform = transforms.Compose(
      [
          transforms.Resize(224),
          transforms.ToTensor(),
          transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
      ]
  )

  batch_size = 8

  # CIFAR-10 train / test splits, downloaded into ./data when missing.
  trainset = torchvision.datasets.CIFAR10(
      root="./data", train=True, download=True, transform=transform
  )
  testset = torchvision.datasets.CIFAR10(
      root="./data", train=False, download=True, transform=transform
  )

  # Only the training loader shuffles; both use two worker processes.
  trainloader = torch.utils.data.DataLoader(
      trainset, batch_size=batch_size, shuffle=True, num_workers=2
  )
  testloader = torch.utils.data.DataLoader(
      testset, batch_size=batch_size, shuffle=False, num_workers=2
  )
  def print_model_size(mdl):
      """Print the on-disk size of ``mdl``'s ``state_dict`` in megabytes.

      The state dict is serialized with ``torch.save`` into a throw-away
      temporary directory, the file size is printed as ``"<size> MB"``
      (size in units of 1e6 bytes, two decimals), and the directory is
      removed again.

      Args:
          mdl: any object exposing ``state_dict()`` (e.g. an ``nn.Module``).

      Returns:
          None. The size is only printed.
      """
      # Local import keeps this helper self-contained.
      import tempfile

      # Writing inside a private temporary directory avoids clobbering an
      # existing ``tmp.pt`` in the working directory, and the context manager
      # cleans up even if saving or stat-ing the file raises.
      with tempfile.TemporaryDirectory() as scratch_dir:
          tmp_path = os.path.join(scratch_dir, "tmp.pt")
          torch.save(mdl.state_dict(), tmp_path)
          size_mb = os.path.getsize(tmp_path) / 1e6
      print("%.2f MB" % size_mb)


  1. from torch.nn import init
  class mAlexNet(nn.Module):
      """Small AlexNet-style CNN: three conv stages followed by two linear layers.

      Each conv stage is ``Conv2d -> ReLU -> MaxPool2d(kernel_size=3, stride=2)``.
      The first linear layer is fixed at ``30 * 3 * 3`` input features, which
      is what a 3x224x224 input produces after the three stages; other
      spatial sizes will not match that layer.

      Args:
          num_classes: number of output logits (default ``2``).

      Attributes:
          input_channel: number of input image channels (always 3).
          num_output: number of output logits, equal to ``num_classes``.
      """

      def __init__(self, num_classes=2):
          super().__init__()
          self.input_channel = 3
          self.num_output = num_classes

          # Stages are built (and initialised) strictly in order so that the
          # random-number stream is consumed exactly as layer1, layer2, ...
          self.layer1 = self._conv_stage(self.input_channel, 16, kernel_size=11, stride=4)
          self.layer2 = self._conv_stage(16, 20, kernel_size=5, stride=1)
          self.layer3 = self._conv_stage(20, 30, kernel_size=3, stride=1)

          self.layer4 = nn.Sequential(
              nn.Linear(30 * 3 * 3, out_features=48),
              nn.ReLU(inplace=True),
          )
          nn.init.kaiming_normal_(self.layer4[0].weight, mode='fan_in', nonlinearity='relu')

          self.layer5 = nn.Sequential(
              nn.Linear(in_features=48, out_features=self.num_output),
          )
          nn.init.kaiming_normal_(self.layer5[0].weight, mode='fan_in', nonlinearity='relu')

      @staticmethod
      def _conv_stage(in_channels, out_channels, kernel_size, stride):
          """Build one Conv2d/ReLU/MaxPool stage with a Xavier-initialised conv weight."""
          stage = nn.Sequential(
              nn.Conv2d(
                  in_channels=in_channels,
                  out_channels=out_channels,
                  kernel_size=kernel_size,
                  stride=stride,
              ),
              nn.ReLU(inplace=True),
              nn.MaxPool2d(kernel_size=3, stride=2),
          )
          nn.init.xavier_uniform_(stage[0].weight, gain=nn.init.calculate_gain('conv2d'))
          return stage

      def forward(self, x):
          """Return the class logits for a batch of images ``x``."""
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          # Flatten everything except the batch dimension.
          x = x.reshape(x.size(0), -1)
          x = self.layer4(x)
          logits = self.layer5(x)
          return logits
  # Instantiate the 10-class network and place it on the configured device.
  model = mAlexNet(num_classes=10).to(device)


  1. import torch.optim as optim
  def train_model(model, *, loader=None, epochs=2, lr=0.001, momentum=0.9, log_every=1000):
      """Train ``model`` in place with SGD and cross-entropy loss, then return it.

      Args:
          model: the network to optimise; it is updated in place.
          loader: iterable of ``(inputs, labels)`` batches. Defaults to the
              module-level ``trainloader``.
          epochs: number of passes over ``loader``.
          lr: SGD learning rate.
          momentum: SGD momentum.
          log_every: print the mean loss every ``log_every`` steps.

      Returns:
          The same ``model`` object that was passed in.

      Raises:
          ValueError: if ``log_every`` is smaller than 1.
      """
      if log_every < 1:
          raise ValueError("log_every must be >= 1")

      batches = trainloader if loader is None else loader

      # Feed batches to wherever the model's weights live; fall back to the
      # CPU for a model that exposes no parameters.
      first_param = next(model.parameters(), None)
      target = first_param.device if first_param is not None else torch.device("cpu")

      criterion = nn.CrossEntropyLoss()
      optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

      for epoch in range(epochs):
          running_loss = 0.0
          for step, (inputs, labels) in enumerate(batches, start=1):
              inputs, labels = inputs.to(target), labels.to(target)

              optimizer.zero_grad()
              loss = criterion(model(inputs), labels)
              loss.backward()
              optimizer.step()

              running_loss += loss.item()
              if step % log_every == 0:
                  # Average over exactly the number of steps accumulated since
                  # the last report (the divisor must match the interval).
                  print(f'[Ep: {epoch + 1}, Step: {step:5d}] loss: {running_loss / log_every:.3f}')
                  running_loss = 0.0
      return model
  # Train the float model, then write its weights to disk.
  model = train_model(model)

  PATH = './float_model.pth'
  torch.save(model.state_dict(), PATH)



  1. 动态量化 Dynamic quantization:使权重为整数(训练后)
  2. 静态量化 Static quantization:使权值和激活值为整数(训练后)
  3. 量化感知训练 Quantization aware training:以整数精度对模型进行训练


  1. import torch
  2. from torch.ao.quantization import (
  3. get_default_qconfig_mapping,
  4. get_default_qat_qconfig_mapping,
  5. QConfigMapping,
  6. )
  7. import torch.ao.quantization.quantize_fx as quantize_fx
  8. import copy
  # Rebuild the float model and load the weights saved in ./float_model.pth.
  model_fp = mAlexNet(num_classes=10).to(device)
  model_fp.load_state_dict(torch.load("./float_model.pth", map_location=device))

  # Quantize a deep copy so that the float model itself is left untouched.
  model_to_quantize = copy.deepcopy(model_fp).to(device)
  model_to_quantize.eval()
  qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)

  # prepare_fx takes a *tuple* of example inputs, so the image batch (element
  # 0 of the first training batch) is wrapped in a 1-tuple.
  example_inputs = (next(iter(trainloader))[0],)

  # Prepare: trace the model and insert what the qconfig mapping requires.
  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

  # No calibration is needed when only dynamic / weight-only quantization is
  # used, so the prepared model is converted straight away.
  model_quantized_dynamic = quantize_fx.convert_fx(model_prepared)


  # Serialized size of the float model, then of the dynamically quantized one.
  for _mdl in (model, model_quantized_dynamic):
      print_model_size(_mdl)

可以看到,模型大小减少了0.03 MB,也就是变为了原来的75%。我们可以通过静态量化使其更小:

  from itertools import islice

  # Static post-training quantization: start again from a copy of the float model.
  model_to_quantize = copy.deepcopy(model_fp)
  qconfig_mapping = get_default_qconfig_mapping("qnnpack")
  # NOTE(review): the qconfig targets "qnnpack"; confirm that
  # torch.backends.quantized.engine matches on the machine that runs the model.
  model_to_quantize.eval()

  # Prepare: trace the model and insert observers.
  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)

  # Calibrate on 20 training batches. The loader is iterated once; calling
  # next(iter(trainloader)) per step would rebuild the DataLoader iterator
  # (and its worker processes) on every one of the 20 iterations.
  with torch.no_grad():
      for batch in islice(trainloader, 20):
          model_prepared(batch[0].to(device))




  # Convert the calibrated, prepared model into its quantized form.
  model_quantized_static = quantize_fx.convert_fx(model_prepared)



  # Serialized sizes: float, dynamically quantized, statically quantized.
  for _mdl in (model, model_quantized_dynamic, model_quantized_static):
      print_model_size(_mdl)



  # Quantization-aware training starts from a freshly constructed network.
  model_to_quantize = mAlexNet(num_classes=10).to(device)
  qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
  model_to_quantize.train()

  # Prepare the model for QAT.
  model_prepared = quantize_fx.prepare_qat_fx(
      model_to_quantize, qconfig_mapping, example_inputs
  )

  # Run the regular training loop on the prepared model.
  model_trained_prepared = train_model(model_prepared)

  # Convert the trained, prepared model into its quantized form.
  model_quantized_trained = quantize_fx.convert_fx(model_trained_prepared)


  # Print a heading followed by the serialized size for each model variant.
  for _title, _mdl in (
      ("Regular floating point model: ", model_fp),
      ("Weights only qunatization: ", model_quantized_dynamic),
      ("Weights/Activations only qunatization: ", model_quantized_static),
      ("Qunatization aware trained: ", model_quantized_trained),
  ):
      print(_title)
      print_model_size(_mdl)


  def get_accuracy(model, *, loader=None):
      """Return the top-1 accuracy of ``model`` as a percentage.

      The model's train/eval mode is not changed here; gradients are
      disabled for the duration of the evaluation.

      Args:
          model: callable mapping a batch of images to per-class scores of
              shape ``(batch, num_classes)``.
          loader: iterable of ``(images, labels)`` batches. Defaults to the
              module-level ``testloader``.

      Returns:
          ``100 * correct / total`` as a float, or ``0.0`` when ``loader``
          yields no samples.
      """
      batches = testloader if loader is None else loader

      # Evaluate on the device that holds the model's weights; a model that
      # exposes no parameters is fed CPU tensors.
      first_param = next(model.parameters(), None)
      target = first_param.device if first_param is not None else torch.device("cpu")

      correct = 0
      total = 0
      with torch.no_grad():
          for images, labels in batches:
              # The original "images, labels = images, labels" was a no-op;
              # the batch is now actually moved next to the model.
              images, labels = images.to(target), labels.to(target)
              outputs = model(images)
              _, predicted = torch.max(outputs, 1)
              total += labels.size(0)
              correct += (predicted == labels).sum().item()

      if total == 0:
          # Nothing was evaluated; avoid dividing by zero.
          return 0.0
      return 100 * correct / total
  # Evaluate each model variant on the test set.
  fp_model_acc = get_accuracy(model)
  dy_model_acc = get_accuracy(model_quantized_dynamic)
  static_model_acc = get_accuracy(model_quantized_static)
  q_trained_model_acc = get_accuracy(model_quantized_trained)

  # Report the accuracies side by side.
  print("Acc on fp_model:", fp_model_acc)
  print("Acc weigths only quantization:", dy_model_acc)
  print("Acc weigths/activations quantization", static_model_acc)
  print("Acc on qunatization awere trained model:", q_trained_model_acc)









