关键词

notMNIST 数据集pyTorch分类

简介

notMNIST数据集 是于2011公布的,可以认为是MNIST数据集地一个加强版本。数据集包含了从A到J十个字母,由large与small两个子集组成。其中samll数据集是经过手工清理的,包含19k个图片,误分类率越为0.5%,large数据集是未经过手工清理的,包含500k张图片,误分类率约为6.5%。

作者推荐在large数据集上训练网络,在small数据集上测试网络。可以将large数据集分为5/6和1/6,使用5/6做training,1/6做validation。

在该网站上网友做的正确率较高的再97%到98%,我自己使用resnet最高达到了98.04%。接下来就说一下我做的步骤。

分类

数据预处理

一步要解决的是数据集的加载。原始数据集是一些很小地图片,一个一个地从磁盘中加载无疑会拖慢模型训练的速度。最好的方式就是将所有数据都加载到内存中。因此,可以将数据加载到内存中,并将标准化之后的数据以二进制文件使用pickle保存到磁盘。这样,每次从磁盘中读取数据可以直接读取二进制文件,否则每次读取数据集中地图片都会耗时很久。

import os, cv2, pickle
import numpy as np
rootdir = 'D:/DataSet/notMNIST/notMNIST_large'
classlist = os.listdir(rootdir)
imgLabels = []
imgNames = []
for classes in classlist:
    imgFolder = os.path.join(rootdir, classes)
    imgnames = os.listdir(imgFolder)
    imgLabels.extend([idxName[classes]] * len(imgnames))
    imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
 
imgs = np.zeros((len(imgLabels), 28, 28), np.float)
idx = 0
print('loading training data......')
for imgname in imgNames:
    try:
        img = cv2.imread(imgname, 0).astype(np.float) / 255.0
        imgs[idx, :, :] = img
        idx += 1
    except AttributeError:
        np.delete(imgs, idx, axis=0)
print('loading training data finished, %d samples' % imgs.shape[0])

train_mean, train_std = np.mean(imgs), np.std(imgs)
print('%.6f, %6f', train_mean, train_std)
imgs = (imgs - train_mean) / train_std
data = {'images': imgs, 'labels': imgLabels}

with open('D:/DataSet/notMNIST/trainset', 'wb') as f:
    pickle.dump(data, f)
print('train set finished')


rootdir = 'D:/DataSet/notMNIST/notMNIST_small'
classlist = os.listdir(rootdir)
imgLabels = []
imgNames = []
for classes in classlist:
    imgFolder = os.path.join(rootdir, classes)
    imgnames = os.listdir(imgFolder)
    imgLabels.extend([idxName[classes]] * len(imgnames))
    imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])

imgs = np.zeros((len(imgLabels), 28, 28), np.float)
idx = 0
print('loading test data......')
for imgname in imgNames:
    try:
        img = cv2.imread(imgname, 0).astype(np.float) / 255.0
        imgs[idx, :, :] = img
        idx += 1
    except AttributeError:
        np.delete(imgs, idx, axis=0)
print('loading test data finished. % d samples' % imgs.shape[0])

train_mean, train_std = np.mean(imgs), np.std(imgs)
imgs = (imgs - train_mean) / train_std
data = {'images': imgs, 'labels': imgLabels}

with open('D:/DataSet/notMNIST/testset', 'wb') as f:
    pickle.dump(data, f)
print('test set finished')

使用try语句地原因是,在读取过程中可能出现一些错误。

本文链接:http://task.lmcjl.com/news/1549.html

展开阅读全文