跳转至

第13章 模型的推理与部署⚓︎

13.0 手工测试训练效果⚓︎

在辛辛苦苦经过短则数分钟,长则几天几个月训练之后,我们得到了一系列的能够满足我们需求的权重矩阵。将输入数据按一定的顺序和与权重矩阵进行运算,就可以得到对应的输出。这个过程就是推理的过程。

如果没有保存这些权重矩阵,那么每次使用之前,需要重新训练,这在计算资源有限的情况下是难以接受的。比方说,每次打开一个网站等待需要等待一天,因为这个时候后台在训练一个很大的模型,需要一天的时间结束训练,然后花10秒钟给出我们想要的结果。这样子的应用是可以接受的嘛?

另一个选择就是将训练好的权重矩阵保存下来,需要使用的时候重新加载权重矩阵。花费10秒钟加载,再花10秒钟把结果输出给用户!只需要20秒,这个应用运行结束了!

13.0.1 训练结果的保存与加载⚓︎

在上一章的三层神经网络例子中,我们在训练完成之后,有一行代码如下:

self.SaveResult()

它的功能就是把训练好的网络各层的权重和偏移参数保存到文件中,这样我们以后可以在使用这个网络推理时,方便地把这些参数重新加载到网络中,而不需要重新训练。前提是训练和推理的网络结构要一模一样。

再次加载训练结果代码如下:

def LoadNet():
    n_input = 784
    n_hidden1 = 64
    n_hidden2 = 16
    n_output = 10
    eta = 0.2
    eps = 0.01
    batch_size = 128
    max_epoch = 40

    hp = HyperParameters3(
        n_input, n_hidden1, n_hidden2, n_output, 
        eta, max_epoch, batch_size, eps, 
        NetType.MultipleClassifier, 
        InitialMethod.Xavier)
    net = NeuralNet3(hp, "MNIST_64_16")
    net.LoadResult()

其中,net.LoadResult()完成了关键动作,其它代码都是复现训练时的网络结构。推理时,只需要调用net.inference()方法即可:

def Inference(img_array):
    output = net.inference(img_array)
    n = np.argmax(output)
    print("------recognize result is: {0} -----", n)

13.0.2 搭建应用⚓︎

这一节用Python搭建一个交互式界面来实现能手写输入的程序,从而验证MNIST的训练结果。需要用到pillow图像工具包和matplotlib绘图工具,可以使用如下命令安装:

pip install pillow matplotlib

搭建交互式应用的实现步骤有:

  1. 重现神经网络结构,加载权重和偏移数据
  2. 显示界面,让用户可以用鼠标或者手指(触摸屏)写一个数字
  3. 写好一个数字后,收集数据,转换数据,并触发推理过程
  4. 得到推理结果
  5. 清除界面,做下一次测试

加载权重数据的过程已在上一节说过了,下一步是用matplotlib提供的功能显示一个绘图面板,如图13-1所示。

图13-1 用matplotlib显示的绘图面板

这是一个方形的空面板,坐标为[0,1],但实际上它的初始尺寸是640\times 480,在使用之前,我们先把它拉成一个正方形(宽高大致近似即可),因为在训练时,我们使用的训练集是28\times 28的正方形,所以要求推理用的数据也是正方形的。

然后需要在这个面板上注册事件,以响应鼠标和键盘输入:

# 加载权重和偏移数据
net = LoadNet()
# 注册事件
fig, ax = plt.subplots()
# 键盘事件
fig.canvas.mpl_connect('key_press_event', on_key_press)
# 鼠标释放
fig.canvas.mpl_connect('button_release_event', on_mouse_release)
# 鼠标按下
fig.canvas.mpl_connect('button_press_event', on_mouse_press)
# 鼠标移动
fig.canvas.mpl_connect('motion_notify_event', on_mouse_move)
# 设置固定的绘图尺寸
plt.axis([0,1,0,1])
plt.show()

事件响应逻辑:

  • 在鼠标按下事件中,启动绘图功能
  • 在鼠标移动事件中,检查如果绘图功能开启,就在面板上显示鼠标轨迹
  • 在鼠标释放事件中,关闭绘图功能
  • 在键盘事件中,如果收到回车键,就触发推理过程;如果收到回退键,就清空画板

13.0.3 交互过程⚓︎

在面板上写个"2",不要担心,那些毛刺并不会影响识别结果,但是注意要写大一些,充满画板,如图13-2所示。

图13-2 在面板上写数字

此时会看到图13-2中,左侧控制台窗口中会显示一些辅助信息,如鼠标按下和释放的坐标。写完后,按回车键,会触发数据处理和推理过程。数据处理过程如下:

  1. 应用程序会先把绘图区域保存为一个文件
  2. 然后再把此文件读入内存,转换成灰度图
  3. 缩放尺寸到28\times 28(和训练数据一致)
  4. 用255减去所有像素值,得到黑底色白前景色的数据(和训练数据一致)
  5. 归一化到[0,1](和训练数据一致)
  6. 变成1\times 784的数组,调用前向计算方法
  7. 得到Output后,做一个argmax,取到最终结果

图13-3 模型推理的结果

最终结果如图13-3所示,在图中可以看到,A3的结果如下,经过softmax计算后,“2”的概率为0.74,argmax方法会把0.74在向量中的位置2返回:

[[1.30227675e-05]
 [2.37117962e-01]
 [7.40282026e-01]
 [5.33239953e-04]
 [1.10064252e-06]
 [1.42939242e-04]
 [1.00293210e-03]
 [1.07402351e-03]
 [1.98273955e-02]
 [5.35748168e-06]]

------recognize result is: {0} ----- 2

图13-3的画板中显示的图片是经过一些列数据处理后的图片(忽略它的彩色,那是matplotlib的装饰色),可以看到和原始图片的差别还是比较大的,尤其是经过缩小处理后,像素点的损失信息很多,这就要求我们在绘图时,要使用足够宽的笔迹,比如在此例中,我们使用了30像素的宽度。

代码位置

ch13, Level1