I've recently been reading about convolutional neural networks (CNNs), and I wanted to reproduce some of the tutorials out there on the MNIST dataset so that I can better understand and use convolutional networks. After comparing TensorFlow and PyTorch, I personally prefer PyTorch, so it is what I will mostly be using from now on. PyTorch does ship with an API for fetching the MNIST dataset, which makes preparing the data trivial, but since my goal here is to learn PyTorch, I decided to prepare the data myself.

In an earlier post I used a generalized linear model (softmax regression) to classify these handwritten digits and got an error rate of 7.77% at best. Using a convolutional network, a standard and effective tool in image recognition, we will reach a much lower error rate, as we will see later in this post. There is not much explanatory text here; it is mostly code.

Preparing the Data

import sys
py = 'Python ' + '.'.join(map(str, sys.version_info[:3]))
print('Jupyter notebook with kernel: {}'.format(py))

import gzip
import time
from urllib.request import urlopen
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
print('Using Pytorch {}'.format(torch.__version__))

def get_data(images_url, labels_url):
    # download and decompress the raw IDX files
    images = gzip.decompress(urlopen(images_url).read())
    labels = gzip.decompress(urlopen(labels_url).read())
    # header fields are big-endian 32-bit integers
    img_num = int.from_bytes(images[4:8], byteorder='big')
    lbl_num = int.from_bytes(labels[4:8], byteorder='big')
    assert img_num == lbl_num
    row = int.from_bytes(images[8:12], byteorder='big')
    col = int.from_bytes(images[12:16], byteorder='big')
    img_size = row * col
    x, y = [], []
    for i in range(img_num):
        img_offset = 16 + img_size * i   # pixel data starts after the 16-byte header
        lbl_offset = 8 + i               # labels start after the 8-byte header
        img = torch.Tensor(list(images[img_offset:img_offset+img_size])).float()
        img = img.view(1, row, col)      # (channels, height, width)
        lbl = int(labels[lbl_offset])
        x.append(img)
        y.append(lbl)
    return x, y
Jupyter notebook with kernel: Python 3.6.5
Using Pytorch 0.3.1

We download the data from LeCun's website. The function get_data takes the URLs of an image file and its corresponding label file, fetches them, and returns two lists (x, y). Each element of the first list x is a torch.Tensor of shape (1, 28, 28); the second list y holds the corresponding labels as ints. During training we will use torch.utils.data.DataLoader, which requires a custom PyTorch dataset, defined below after a short note on the file format.
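
These files use LeCun's IDX format: a big-endian header (a magic number, an item count, and, for image files, the row and column counts) followed by the raw data bytes, which is what the offsets 4, 8, 12 and 16 inside get_data correspond to. As a quick, purely illustrative header check against the small label file (it re-downloads a file that the cells below fetch anyway):

raw = gzip.decompress(urlopen(
    "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz").read())
print(int.from_bytes(raw[0:4], byteorder='big'))   # 2049, the magic number of label files
print(int.from_bytes(raw[4:8], byteorder='big'))   # 60000, the number of labels
print(list(raw[8:13]))                             # the first five labels as plain integers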

class MnistData(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

    def __len__(self):
        return len(self.images)

A PyTorch dataset must subclass torch.utils.data.Dataset and implement Dataset.__getitem__ and Dataset.__len__; the PyTorch documentation has more details. Next we prepare the data.

batch_size = 100

# train data
images_url = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"
labels_url = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"
images, labels = get_data(images_url, labels_url)
train_data = MnistData(images, labels)
train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True)
print("There are {} training samples".format(len(images)))

# test data
images_url = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"
labels_url = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"
images, labels = get_data(images_url, labels_url)
test_data = MnistData(images, labels)
test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size,
                         shuffle=False)
print("There are {} testing samples".format(len(images)))
There are 60000 training samples
There are 10000 testing samples
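
Indexing the dataset gives a single (image, label) pair, while iterating the loader yields whole batches: the default collate function stacks the image tensors and turns the integer labels into a LongTensor. A quick, purely illustrative check:

img, lbl = train_data[0]
print(img.size(), lbl)                          # torch.Size([1, 28, 28]) and a plain int label

batch_imgs, batch_lbls = next(iter(train_loader))
print(batch_imgs.size(), batch_lbls.size())     # torch.Size([100, 1, 28, 28]) torch.Size([100])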

Building the Model

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=16,
                               kernel_size=5,
                               padding=2)
        self.conv2 = nn.Conv2d(in_channels=16,
                               out_channels=32,
                               kernel_size=5,
                               padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.activation = nn.ReLU()
        # two 2x2 max-poolings shrink the 28x28 input to 7x7,
        # so the flattened features have 32 * 7 * 7 = 32 * 49 elements
        self.fc = nn.Linear(32*49, 10)

    def forward(self, x):
        x = self.activation(self.pool(self.conv1(x)))
        x = self.activation(self.pool(self.conv2(x)))
        x = x.view(x.size(0), -1)   # flatten to (batch, 32*49) for the linear layer
        x = self.fc(x)
        return x

cnn = CNN()
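
The fully connected layer takes 32*49 inputs because each 5x5 convolution uses padding 2 and keeps the spatial size, while each 2x2 max-pooling halves it, so the feature maps go 28x28 -> 14x14 -> 7x7. A dummy forward pass confirms the shapes (illustrative only):

t = Variable(torch.randn(1, 1, 28, 28))        # one fake 28x28 grayscale image
t = cnn.activation(cnn.pool(cnn.conv1(t)))     # -> (1, 16, 14, 14)
t = cnn.activation(cnn.pool(cnn.conv2(t)))     # -> (1, 32, 7, 7)
print(t.size())                                # torch.Size([1, 32, 7, 7])
print(cnn(Variable(torch.randn(1, 1, 28, 28))).size())   # torch.Size([1, 10])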

Training the Model
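
One detail worth noting before training: nn.CrossEntropyLoss applies nn.LogSoftmax and nn.NLLLoss in a single step, which is why CNN.forward returns raw scores with no softmax layer at the end. A minimal sketch of this equivalence (illustrative only; torch.nn.functional is not imported in the cells above):

import torch.nn.functional as F

logits = Variable(torch.randn(4, 10))               # fake raw scores for a batch of 4
targets = Variable(torch.LongTensor([3, 1, 7, 0]))  # fake labels

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
print(ce.data[0], nll.data[0])                      # the two values agree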

# hyper parameters
num_epochs = 5
learning_rate = 0.001

# loss criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

# train the model
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images)
        labels = Variable(labels)

        outputs = cnn(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 200 == 0:
            # loss.data[0] reads the scalar loss on PyTorch 0.3;
            # newer versions use loss.item() instead
            print('epoch {}/{}, batch {}/{}, loss: {:.4f}'.format(
                epoch+1, num_epochs, i+1, len(train_loader), loss.data[0]))
epoch 1/5, batch 200/600, loss: 0.1690
epoch 1/5, batch 400/600, loss: 0.0361
epoch 1/5, batch 600/600, loss: 0.1510
epoch 2/5, batch 200/600, loss: 0.0792
epoch 2/5, batch 400/600, loss: 0.0621
epoch 2/5, batch 600/600, loss: 0.0403
epoch 3/5, batch 200/600, loss: 0.0572
epoch 3/5, batch 400/600, loss: 0.0068
epoch 3/5, batch 600/600, loss: 0.0078
epoch 4/5, batch 200/600, loss: 0.0147
epoch 4/5, batch 400/600, loss: 0.0400
epoch 4/5, batch 600/600, loss: 0.0359
epoch 5/5, batch 200/600, loss: 0.0009
epoch 5/5, batch 400/600, loss: 0.0543
epoch 5/5, batch 600/600, loss: 0.0381

Testing the Model

cnn.eval()   # switch the model to evaluation mode
wrong = 0
total = 0
for images, labels in test_loader:
    images = Variable(images)
    outputs = cnn(images)
    _, predicted = torch.max(outputs.data, 1)   # index of the highest score is the predicted digit
    total += labels.size(0)
    wrong += (predicted != labels).sum()

print('Error rate: {:.2f}%'.format(100 * wrong / total))
Error rate: 1.54%

Saving the Model

Once the model is trained, we can save the whole model to disk:

torch.save(cnn, 'mnist_full_cnn.pkl')
/usr/local/lib/python3.6/site-packages/torch/serialization.py:159: UserWarning: Couldn't retrieve source code for container of type CNN. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

The code above saves not just the trained weights but the entire model. (The warning above is harmless: because the whole object is pickled, PyTorch cannot verify the class's source code at load time, and the CNN class definition must still be available when the file is loaded.) Anyone can then use the trained model simply by running the following:

model = torch.load('mnist_full_cnn.pkl')
image, label = test_data[0]
plt.imshow(image.numpy().reshape((28, 28)), cmap='gray_r')
plt.axis('off')
plt.show()

output = model(Variable(image.view(1, 1, 28, 28)))
pred = torch.max(output.data, 1)[1][0]
print('Prediction of the above handwriting is: {}'.format(pred))
Prediction of the above handwriting is: 7

Of course, we can also save only the weights instead of the whole model. The following code shows how to save and load the weights.

# save weights only
torch.save(cnn.state_dict(), 'mnist_cnn_weights.pkl')

# load weights to model
cnn2 = CNN()
cnn2.load_state_dict(torch.load('mnist_cnn_weights.pkl'))

# use the model as before
output = cnn2(Variable(image.view(1, 1, 28, 28)))
pred = torch.max(output.data, 1)[1][0]
print('Prediction of the above handwriting is: {}'.format(pred))
Prediction of the above handwriting is: 7