机器学习:scikit-learn 实现手写数字识别

1.1 数据集简介

  • 来源:http://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits
  • 类别:0-9 共10个数字
  • 样本数:1797
  • 特征数:64
  • 特征含义:8x8像素,每个像素由0到16之间的整数表示
import numpy as np
from sklearn import datasets
digits = datasets.load_digits()
# 输出数据集的样本数与特征数
print digits.data.shape
# 输出所有目标类别
print np.unique(digits.target)
# 输出数据集
print digits.data
(1797, 64)
[0 1 2 3 4 5 6 7 8 9]
[[  0.   0.   5. ...,   0.   0.   0.]
 [  0.   0.   0. ...,  10.   0.   0.]
 [  0.   0.   0. ...,  16.   9.   0.]
 ...,
 [  0.   0.   1. ...,   6.   0.   0.]
 [  0.   0.   2. ...,  12.   0.   0.]
 [  0.   0.  10. ...,  12.   1.   0.]]

1.2 数据集可视化

import matplotlib.pyplot as plt
# 导入字体管理器,用于提供中文支持
import matplotlib.font_manager as fm
font_set= fm.FontProperties(fname='C:/Windows/Fonts/msyh.ttc', size=14)

# 将图像和目标标签合并到一个列表中
images_and_labels = list(zip(digits.images, digits.target))

# 打印数据集的前8个图像
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r,interpolation='nearest')
    plt.title(u'训练样本:' + str(label), fontproperties=font_set)

plt.show()

png

# 样本图片效果
plt.figure(figsize=(6, 6))
plt.imshow(digits.images[0], cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

png

1.3 用 PCA 降维

由于该数据集有 64 个特征值,也就是说有 64 个维度,因此没办法直观地看到数据的分布及其之间的关系。但是,实际起作用的维度可能比特征值的个数要少得多,我们可以通过主成分分析来降低数据集的维度,从而观察样本点之间的关系。

主成分分析(PCA):找到两个变量的线性组合,尽可能保留大部分的信息,这个新的变量(主成分)就可以替代原来的变量。也就是说,PCA就是通过线性变换来产生新的变量,并最大化保留了数据的差异。

from sklearn.decomposition import *

# 创建一个 PCA 模型
pca = PCA(n_components=2)

# 将数据应用到模型上
reduced_data_pca = pca.fit_transform(digits.data)

# 查看维度
print reduced_data_pca.shape
(1797, 2)

1.4 绘制散点图

colors = ['black', 'blue', 'purple', 'yellow', 'white', 'red', 'lime', 'cyan', 'orange', 'gray']
plt.figure(figsize=(8, 6))
for i in range(len(colors)):
    x = reduced_data_pca[:, 0][digits.target == i]
    y = reduced_data_pca[:, 1][digits.target == i]
    plt.scatter(x, y, c=colors[i])
plt.legend(digits.target_names, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.xlabel(u'第一个主成分', fontproperties=font_set)
plt.ylabel(u'第二个主成分', fontproperties=font_set)
plt.title(u"PCA 散点图", fontproperties=font_set)
plt.show()

png

2.1 数据归一化

from sklearn.preprocessing import scale

data = scale(digits.data)

print data
[[ 0.         -0.33501649 -0.04308102 ..., -1.14664746 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -1.09493684 ...,  0.54856067 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -1.09493684 ...,  1.56568555  1.6951369
  -0.19600752]
 ...,
 [ 0.         -0.33501649 -0.88456568 ..., -0.12952258 -0.5056698
  -0.19600752]
 [ 0.         -0.33501649 -0.67419451 ...,  0.8876023  -0.5056698
  -0.19600752]
 [ 0.         -0.33501649  1.00877481 ...,  0.8876023  -0.26113572
  -0.19600752]]

2.2 拆分数据集

将数据集拆分成训练集和测试集

from sklearn.cross_validation import train_test_split

X_train, X_test, y_train, y_test, images_train, images_test = train_test_split(data, digits.target, digits.images, test_size=0.25, random_state=42)

print "训练集", X_train.shape
print "测试集", X_test.shape
训练集 (1347, 64)
测试集 (450, 64)

2.3 使用 SVM 分类器

from sklearn import svm

# 创建 SVC 模型
svc_model = svm.SVC(gamma=0.001, C=100, kernel='linear')

# 将训练集应用到 SVC 模型上
svc_model.fit(X_train, y_train)

# 评估模型的预测效果
print svc_model.score(X_test, y_test)
0.97777777777777775

2.4 优化参数

svc_model = svm.SVC(gamma=0.001, C=10, kernel='rbf')

svc_model.fit(X_train, y_train)

print svc_model.score(X_test, y_test)
0.98222222222222222

3.1 预测结果

import matplotlib.pyplot as plt

# 使用创建的 SVC 模型对测试集进行预测
predicted = svc_model.predict(X_test)

# 将测试集的图像与预测的标签合并到一个列表中
images_and_predictions = list(zip(images_test, predicted))

# 打印前 4 个预测的图像和结果
plt.figure(figsize=(8, 2))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(1, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title(u'预测结果: ' + str(prediction), fontproperties=font_set)

plt.show()

png

3.2 分析结果的准确性

X = np.arange(len(y_test))
# 生成比较列表,如果预测的结果正确,则对应位置为0,错误则为1
comp = [0 if y1 == y2 else 1 for y1, y2 in zip(y_test, predicted)]
plt.figure(figsize=(8, 6))
# 图像发生波动的地方,说明预测的结果有误
plt.plot(X, comp)
plt.ylim(-1, 2)
plt.yticks([])
plt.show()

print "测试集数量:", len(y_test)
print "错误识别数:", sum(comp)
print "识别准确率:", 1 - float(sum(comp)) / len(y_test)

png

测试集数量: 450
错误识别数: 8
识别准确率: 0.982222222222

3.3 错误识别样本分析

# 收集错误识别的样本下标
wrong_index = []
for i, value in enumerate(comp):
    if value: wrong_index.append(i)

# 输出错误识别的样本图像
plt.figure(figsize=(8, 6))
for plot_index, image_index in enumerate(wrong_index):
    image = images_test[image_index]
    plt.subplot(2, 4, plot_index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r,interpolation='nearest')
    # 图像说明,8->9 表示正确值为8,被错误地识别成了9
    info = "{right}->{wrong}".format(right=y_test[image_index], wrong=predicted[image_index])
    plt.title(info, fontsize=16)

plt.show()

png

参考文章:Python Machine Learning: Scikit-Learn Tutorial (Article)

相关文章

发表评论

电子邮件地址不会被公开。 必填项已用*标注