MNIST数据集已经是一个被”嚼烂”了的数据集,很多教程都会对它”下手”,几乎成为一个“典范”.不过有些人可能对它还不是很了解,下面来介绍一下.
MNIST数据集来自美国国家标准与技术研究所,NationalInstituteofStandardsandTechnology(NIST).训练集(trainingset)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(theCensusBureau)的工作人员.测试集(testset)也是同样比例的手写数字数据.
不妨新建一个文件夹–mnist,将数据集下载到mnist以后,解压即可:
图片是以字节的形式进行存储,我们需要把它们读取到NumPyarray中,以便训练和测试算法.
importosimportstructimportnumpyasnpdefload_mnist(path,kind='train'):"""LoadMNISTdatafrom`path`"""labels_path=os.path.join(path,'%s-labels-idx1-ubyte'%kind)images_path=os.path.join(path,'%s-images-idx3-ubyte'%kind)withopen(labels_path,'rb')aslbpath:magic,n=struct.unpack('>II',lbpath.read(8))labels=np.fromfile(lbpath,dtype=np.uint8)withopen(images_path,'rb')asimgpath:magic,num,rows,cols=struct.unpack('>IIII',imgpath.read(16))images=np.fromfile(imgpath,dtype=np.uint8).reshape(len(labels),784)returnimages,labelsload_mnist函数返回两个数组,第一个是一个nxm维的NumPyarray(images),这里的n是样本数(行数),m是特征数(列数).训练数据集包含60,000个样本,测试数据集包含10,000样本.在MNIST数据集中的每张图片由28x28个像素点构成,每个像素点用一个灰度值表示.在这里,我们将28x28的像素展开为一个一维的行向量,这些行向量就是图片数组里的行(每行784个值,或者说每行就是代表了一张图片).load_mnist函数返回的第二个数组(labels)包含了相应的目标变量,也就是手写数字的类标签(整数0-9).
第一次见的话,可能会觉得我们读取图片的方式有点奇怪:
magic,n=struct.unpack('>II',lbpath.read(8))labels=np.fromfile(lbpath,dtype=np.uint8)为了理解这两行代码,我们先来看一下MNIST网站上对数据集的介绍:
TRAININGSETLABELFILE(train-labels-idx1-ubyte):[offset][type][value][description]000032bitinteger0x00000801(2049)magicnumber(MSBfirst)000432bitinteger60000numberofitems0008unsignedbytelabel0009unsignedbytelabel........xxxxunsignedbytelabelThelabelsvaluesare0to9.通过使用上面两行代码,我们首先读入magicnumber,它是一个文件协议的描述,也是在我们调用fromfile方法将字节读入NumPyarray之前在文件缓冲中的item数(n).作为参数值传入struct.unpack的>II有两个部分:
通过执行下面的代码,我们将会从刚刚解压MNIST数据集后的mnist目录下加载60,000个训练样本和10,000个测试样本.
为了了解MNIST中的图片看起来到底是个啥,让我们来对它们进行可视化处理.从featurematrix中将784-像素值的向量reshape为之前的28*28的形状,然后通过matplotlib的imshow函数进行绘制:
importmatplotlib.pyplotaspltfig,ax=plt.subplots(nrows=2,ncols=5,sharex=True,sharey=True,)ax=ax.flatten()foriinrange(10):img=X_train[y_train==i][0].reshape(28,28)ax[i].imshow(img,cmap='Greys',interpolation='nearest')ax[0].set_xticks([])ax[0].set_yticks([])plt.tight_layout()plt.show()我们现在应该可以看到一个2*5的图片,里面分别是0-9单个数字的图片.
此外,我们还可以绘制某一数字的多个样本图片,来看一下这些手写样本到底有多不同:
fig,ax=plt.subplots(nrows=5,ncols=5,sharex=True,sharey=True,)ax=ax.flatten()foriinrange(25):img=X_train[y_train==7][i].reshape(28,28)ax[i].imshow(img,cmap='Greys',interpolation='nearest')ax[0].set_xticks([])ax[0].set_yticks([])plt.tight_layout()plt.show()执行上面的代码后,我们应该看到数字7的25个不同形态:
另外,我们也可以选择将MNIST图片数据和标签保存为CSV文件,这样就可以在不支持特殊的字节格式的程序中打开数据集.但是,有一点要说明,CSV的文件格式将会占用更多的磁盘空间,如下所示:
如果我们打算保存这些CSV文件,在将MNIST数据集加载入NumPyarray以后,我们应该执行下列代码:
np.savetxt('train_img.csv',X_train,fmt='%i',delimiter=',')np.savetxt('train_labels.csv',y_train,fmt='%i',delimiter=',')np.savetxt('test_img.csv',X_test,fmt='%i',delimiter=',')np.savetxt('test_labels.csv',y_test,fmt='%i',delimiter=',')一旦将数据集保存为CSV文件,我们也可以用NumPy的genfromtxt函数重新将它们加载入程序中: