机器学习:模型动态增量训练
介绍
通过前面的内容,相信你已经对于使用 scikit-learn 保存、部署模型非常熟悉。本次挑战中,你会了解到什么是增量训练,以及动态增量模型的部署及调用。
知识点
- 动态模型
- 增量训练
- 实时手写字符识别
静态模型和动态模型
上一个文章中,我们了解到了如何将机器学习模型部署到线上,并完成动态推理。实际上,除了推理过程有动态和静态之分,机器学习模型的训练过程也有动态和静态两类。
- 静态模型采用离线训练方式。一般只训练模型一次,然后长时间使用该模型。
- 动态模型采用在线训练方式。数据会不断进入系统,通过不断地更新系统将这些数据整合到模型中。
前面的文章中,我们都采用了离线训练并保存静态模型的方式。而实际上,当你将一个机器学习模型部署到线上时,你可能会想让该模型动态学习更多新的数据,并持续更新。
上面的过程可以这样理解。离线训练使用大量的本地数据来训练模型,此时如果输入增量数据,模型会在已优化的参数条件下继续学习。这样的好处在于,模型是持续学习的过程,而不是每次都从头再来。
当然,想法是非常好的。但是并不是每一种模型都支持在线(增量)训练,这需要根据模型的自身的特征和所使用机器学习框架来决定。
scikit-learn 中, 支持增量训练 的算法有:
- 分类算法
sklearn.naive_bayes.MultinomialNB
sklearn.naive_bayes.BernoulliNB
sklearn.linear_model.Perceptron
sklearn.linear_model.SGDClassifier
sklearn.linear_model.PassiveAggressiveClassifier
sklearn.neural_network.MLPClassifier
- 回归算法
sklearn.linear_model.SGDRegressor
sklearn.linear_model.PassiveAggressiveRegressor
sklearn.neural_network.MLPRegressor
下面,我们使用人工神经网络来完成模型动态增量训练及部署过程。这里同样选择前面用过的 DIGITS 手写字符数据集。为了需要,我们将手写字符矩阵中大于 0 的值全部替换为 1。
from sklearn.datasets import load_digits
digits = load_digits() # 加载数据集
digits.data.shape, digits.target.shape
然后,将数据集切分为训练集和测试数据集。
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
digits.data, digits.target, random_state=1, test_size=0.2)
X_train.shape, X_test.shape, y_train.shape, y_test.shape
接下来,使用 Train 数据训练模型,并使用测试数据评估。在 MLPClassifier 中添加 verbose=1
可以输出每一步迭代的损失值。
from sklearn.metrics import accuracy_score
from sklearn.neural_network import MLPClassifier
model = MLPClassifier(random_state=1, verbose=1, max_iter=50)
model.fit(X_train, y_train) # 训练模型
y_pred = model.predict(X_test) # 测试模型
accuracy_score(y_test, y_pred) # 准确度
可以看的,模型在测试集上得到约等于 98% 的准确度。下面,我们就找到那些被模型错误预测的样本。
n = 0
for i, (pred, test) in enumerate(zip(y_pred, y_test)):
if pred != test:
print('样本索引:', i, '被错误预测为: ', pred, '正确标签为: ', test)
n += 1
print("总计错误预测样本数量:", n)
现在,可以使用 Matplotlib 绘制出被错误预测的样本,看看是不是容易被混淆。
from matplotlib import pyplot as plt
%matplotlib inline
plt.imshow(X_test[108].reshape((8, 8)), cmap=plt.cm.gray_r)
随意挑选几个错误预测样本打印,你会发现的确连人眼都不容易分清楚。
动态增量训练
既然,当然训练的模型存在错误预测结果,那么如果我们让模型来学习这些样本,并人为告诉它正确结果,模型不就完成了增量训练了吗?
scikit-learn 中,增量训练的方法是 model.partial_fit(X, y)
,其使用方法与 model.fit(X, y)
别无二致。
接下来,我们就利用上面已经训练好的模型,对错误预测样本进行增量学习。
import numpy as np
addition_index = []
for i, (pred, test) in enumerate(zip(y_pred, y_test)):
if pred != test:
addition_index.append(i)
addition_X = X_test[addition_index] # 错误预测样本特征
addition_y = y_test[addition_index] # 错误预测样本正确标签
# 增量训练模型
model.partial_fit(addition_X, addition_y)
model
下面,我们重新使用模型来对测试数据进行预测,并重新打印出错误预测的样本。
y_pred = model.predict(X_test) # 测试模型
accuracy_score(y_test, y_pred) # 准确度
# 打印错误预测样本
n = 0
for i, (pred, test) in enumerate(zip(y_pred, y_test)):
if pred != test:
print('样本索引:', i, '被错误预测为: ', pred, '正确标签为: ', test)
n += 1
print("总计错误预测样本数量:", n)
可以看的,错误预测样本的总数减少了。不过,部分样本依旧无法正确预测,且由于增量学习样本的输入,导致模型参数的整体变动,所以也可能发生之前正确预测的样本被错误预测的现象。
当然,如果错误预测样本总数并未减少,就可以多次重复执行上方两个单元格让模型不断学习错误样本,应该能看到更为直观的效果。
接下来,我们完成一个有意思的过程。文章打算构建一个可以部署到线上的手写字符识别系统,使之可以实现对用户绘制的字符进行预测。
这里预先实现一段代码,使你可以在 Jupyter Notebook 环境中手动绘制一个字符。直接运行下面单元格即可。
from IPython.display import HTML
input_form = """
<table>
<td style="border-style: none;">
<div style="border: solid 2px #666; width: 43px; height: 44px;">
<canvas width="40" height="40"></canvas>
</div></td>
<td style="border-style: none;">
<button onclick="clear_value()">重绘</button>
</td>
</table>
"""
javascript = '''
<script type="text/Javascript">
var pixels = [];
for (var i = 0; i < 8*8; i++) pixels[i] = 0;
var click = 0;
var canvas = document.querySelector("canvas");
canvas.addEventListener("mousemove", function(e){
if (e.buttons == 1) {
click = 1;
canvas.getContext("2d").fillStyle = "rgb(0,0,0)";
canvas.getContext("2d").fillRect(e.offsetX, e.offsetY, 5, 5);
x = Math.floor(e.offsetY * 0.2);
y = Math.floor(e.offsetX * 0.2) + 1;
for (var dy = 0; dy < 1; dy++){
for (var dx = 0; dx < 1; dx++){
if ((x + dx < 8) && (y + dy < 8)){
pixels[(y+dy)+(x+dx)*8] = 1;
}
}
}
} else {
if (click == 1) set_value();
click = 0;
}
});
function set_value(){
var result = ""
for (var i = 0; i < 8*8; i++) result += pixels[i] + ","
var kernel = IPython.notebook.kernel;
kernel.execute("image = [" + result + "]");
kernel.execute("f = open('digits.json', 'w')");
kernel.execute("f.write('{\\"inputs\\":%s}' % image)");
kernel.execute("f.close()");
}
function clear_value(){
canvas.getContext("2d").fillStyle = "rgb(255,255,255)";
canvas.getContext("2d").fillRect(0, 0, 40, 40);
for (var i = 0; i < 8*8; i++) pixels[i] = 0;
}
</script>
'''
randint = np.random.randint(0, 9)
print(f"请在下方图框中细心绘制手写字符 {randint}")
HTML(input_form + javascript)
由于输入框较小,你可以通过放大浏览器页面用鼠标进行书写。绘制的字符会自动保存存为 digits.json
文件到当前目录下方。然后,我们读取该文件,并将图像绘制出来。
import json
import numpy as np
with open("digits.json") as f:
inputs = f.readlines()[0]
inputs_array = np.array(json.loads(inputs)['inputs'])
plt.imshow(inputs_array.reshape((8, 8)), cmap=plt.cm.gray_r)
你会发现,由于 DIGITS 数据集的图像分辨率为 $8 \times 8$ 像素,处理之后的图像会与绘制图像稍有区别。于此同时,因为我们上方绘制的图像为 2 值图像,即黑色像素数值存为 1,白色像素存为 0。所以,下面我们需要重新训练 DIGITS 模型,使之匹配。我们将 digits.data
中大于 0 的值全部替换为 1,并使用全部数据用于训练。
# 重新训练神经网络
digits.data[digits.data > 0] = 1
model = MLPClassifier(tol=0.001, max_iter=50, verbose=1)
model.fit(digits.data, digits.target)
下面,就可以用刚刚训练好的模型来预测自行绘制的手写字符了。我们对每次预测结果进行增量训练来改善模型。如果预测正确,增量训练可以将此样本纳入模型中。如果预测错误,增量训练依据可以起到持续改善模型的效果。
inputs_array = np.atleast_2d(inputs_array) # 将其处理成 2 维数组
result = model.predict(inputs_array) # 预测
if result != randint:
print(f"预测错误|预测标签: {result}|真实标签: {randint}")
model.partial_fit(inputs_array, np.atleast_1d(randint))
print("完成增量训练")
else:
print(f"预测正确|预测标签: {result}|真实标签: {randint}")
model.partial_fit(inputs_array, np.atleast_1d(randint))
print("完成增量训练")
由于神经网络可以输出不同标签预测的概率,所以最后看一下网络对输入图像属于类别的评判依据。
# 输出神经网络对各类别的概率值
pred_proba = model.predict_proba(np.atleast_2d(inputs_array))
# 绘制柱形图
plt.xticks(range(10))
plt.bar(range(10), pred_proba[0], align='center')
上方柱形图值越大,即代表网络认为输入图像属于该类别的概率更高。
特别说明的是,你可以尝试多次重复执行上方两个单元格,即反复增量训练自定义手写字符,应该可以看到增量训练使得正确标签的概率越来越高,这就是通过增量训练来优化模型的直观效果。
小结
这篇文章中,我们了解了机器学习模型的静态训练和动态训练过程,特别对动态增量训练进行了学习。增量训练在机器学习工程领域有广泛应用,部署在线上的模型需要持续不断地改善才会越来越好。
实际上,你可以借助于前面模型部署的思路来实现一个线上实时手写字符识别应用。并收集每次识别的结果对模型进行增量训练。当然,这需要你对 Flask 等 Web 框架有熟悉的了解,有兴趣可以 学习此示例。
相关链接
系列文章
- 机器学习:综述及示例
- 机器学习:线性回归实现与应用
- 机器学习:多项式回归实现与应用
- 机器学习:岭回归和 LASSO 回归实现
- 机器学习:回归模型评价与检验
- 机器学习:逻辑回归实现与应用
- 机器学习:K-近邻算法实现与应用
- 机器学习:朴素贝叶斯实现及应用
- 机器学习:分类模型评价方法
- 机器学习:支持向量机实现与应用
- 机器学习:决策树实现与应用
- 机器学习:装袋和提升集成学习方法
- 机器学习:划分聚类方法实现与应用
- 机器学习:层次聚类方法实现与应用
- 机器学习:主成分分析原理及应用
- 机器学习:密度聚类方法实现与应用
- 机器学习:谱聚类及其他聚类方法应用
- 机器学习:自动化机器学习综述
- 机器学习:自动化机器学习实践应用
- 机器学习:模型动态增量训练
- 机器学习:模型推理与部署