0%

基于卷积神经网络的手写数字识别App

手写数字识别:深度学习领域的 Hello World

按照“数据准备-算法开发-系统部署”的工作流程展开,这里首先给出最终效果:


数据准备

MNIST 数据集是机器学习领域的一个经典数据集,包含 60000 张训练图像和 10000 张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即 MNIST 中 的 NIST)在20世纪80年代收集得到。

MNIST 网站提供四个下载文件,解压后得到:

1
2
3
4
train-images.idx3-ubyte  % 60000张图片
train-labels.idx1-ubyte % 60000个标签
t10k-images.idx3-ubyte % 10000张图片
t10k-labels.idx1-ubyte % 10000个标签

文本选用t10k中的10000张图片来训练神经网络

路径fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset')下,MATLAB也提供了10000张训练图片,不过图片存在随机的旋转角度

提取上面的二进制文件中图片,并进行分类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
% 需在 t10k 文件目录下运行
str1='t10k-images.idx3-ubyte';% 10000张图片
str2='t10k-labels.idx1-ubyte';% 10000个标签
fid1=fopen(str1);fid2=fopen(str2);
images=fread(fid1);images=images(17:end);
labels=fread(fid2);labels=labels(9:end);
fclose(fid1);fclose(fid2);

workdir=pwd;% 当前工作路径
mkdir('DigitDataset');% 在当前目录下创建存放图片的文件夹
for i=0:9
% 在DigitDataset文件夹下创建10个子文件夹, 用于分类存放图片
mkdir(fullfile('DigitDataset',num2str(i)));
end

num=length(labels);
data=cell(num,2);% [图片数字, 像素矩阵]
tab1=zeros(1,10);
for i=1:num
data{i,1}=labels(i);
data{i,2}=reshape(images(784*i-783:784*i),[28,28])';
index=data{i,1};
tab1(index+1)=tab1(index+1)+1;
str3=fullfile(workdir,num2str(index), ...
sprintf('image%04d.png',tab1(index+1)));

% 按标签存入对应文件夹
% imwrite(uint8(data{i,2}),str3);% 完全保存, uint8: 0~255
% imwrite(data{i,2}/255,str3);% 完全保存, double: 0~1
imwrite(data{i,2},str3);% 二值化简化保存, double: 0~1
end

算法开发

MATLAB的Deep Learning Toolbox提供了一个实现深度神经网络的框架。

卷积神经网络(Convolutional Neural Network)是一种适用于视觉任务的深度神经网络架构。

这里给出一个15层的卷积神经网络:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
% 准备训练数据
imds=imageDatastore([pwd,'\DigitDataset'], ...
'IncludeSubfolders',true,'LabelSource','foldernames');% 按子文件夹分类录入
numTrainFiles=800;% 每类800张图片用于训练, 共8000张
[imdsTrain,imdsValidation]=splitEachLabel(imds,numTrainFiles,'randomize');% 随机分类

% 神经网络连接
layers=[ % 共15层
imageInputLayer([28 28 1]) % 输入层
convolution2dLayer(3,8,'Padding','same') % 卷积层
batchNormalizationLayer % 批量归一化层
reluLayer % 激活函数层
maxPooling2dLayer(2,'Stride',2) % 池化层
convolution2dLayer(3,16,'Padding','same') % 卷积层
batchNormalizationLayer % 批量归一化层
reluLayer % 激活函数层
maxPooling2dLayer(2,'Stride',2) % 池化层
convolution2dLayer(3,32,'Padding','same') % 卷积层
batchNormalizationLayer % 批量归一化层
reluLayer % 激活函数层
fullyConnectedLayer(10) % 全连接层
softmaxLayer % softmax分类器
classificationLayer]; % 分类器

% 训练设置
options=trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',10, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');

% 开始训练
CNN_Script_Classification=trainNetwork(imdsTrain,layers,options);

保存训练完毕的卷积神经网络:

1
save input_CNN.mat CNN_Script_Classification;

卷积神经网络的调用方式:

1
2
3
4
5
load input_CNN.mat CNN_Script_Classification;
z0=imread([pwd,'\DigitDataset\8\image0123.png']);
z1=z0+uint8(rands(28,28)*40);% 加入噪声
figure(1);imshow(z0);figure(2);imshow(z1);% uint8格式
classify(CNN_Script_Classification,z1)

系统部署

将训练好的卷积神经网络部署到 MATLAB App。前面视频中的App主要有三个功能:

  1. 利用鼠标完成手写输入
  2. 对手写输入的图片进行预处理
  3. 对预处理后的图片进行识别

卷积神经网络对应第3点,即完成图片识别,下面给出第1、2点的程序。

利用鼠标完成手写输入

参考博客园中的程序,并按照我的需求进行修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
% https://www.cnblogs.com/xfzhang/archive/2010/12/27/1918393.html
function MouseDraw_app(action)

global InitialX InitialY FigHandle
imSize = 200;
if nargin == 0
action = 'start';
end

switch(action)
%%开启图形视窗
case 'start'
index = 0;InputImage = ones(imSize);
save('input.mat','index','InputImage');
FigHandle = figure('WindowButtonDownFcn','MouseDraw_app down;');

set(0,'units','centimeters')
cm_ss=get(0,'screensize');
W=cm_ss(3);H=cm_ss(4);L=10;
set(gcf,'units','normalized','position',...
[(W-L)/2/W,(H-L)/2/H,L/W,L/H]);
set(gca,'XTick',zeros(1,0),'YTick',zeros(1,0));
axis([1 imSize 1 imSize]);% 设定图轴范围
axis on;grid off;box on;
title('\rm{鼠标左键}\bf{输入} \rm{鼠标右键}\bf{确认}',...
'FontSize',14);
dlmwrite('IXT.txt', -10, 'delimiter', '\t', 'precision', 6);
dlmwrite('IYT.txt', -10, 'delimiter', '\t', 'precision', 6);

%%滑鼠按钮被按下时的反应指令
case 'down'
if strcmp(get(FigHandle, 'SelectionType'), 'normal') % 如果是左键
set(FigHandle,'pointer','hand');
CurPiont = get(gca, 'CurrentPoint');
InitialX = CurPiont(1,1);
InitialY = CurPiont(1,2);
dlmwrite('IXT.txt', InitialX, '-append', ...
'delimiter', '\t', 'precision', 6);
dlmwrite('IYT.txt', InitialY, '-append', ...
'delimiter', '\t', 'precision', 6);
% 设定滑鼠移动时的反应指令为「MouseDraw_app move」
set(gcf, 'WindowButtonMotionFcn', 'MouseDraw_app move;');
set(gcf, 'WindowButtonUpFcn', 'MouseDraw_app up;');
elseif strcmp(get(FigHandle, 'SelectionType'), 'alt') % 如果是右键
set(FigHandle, 'Pointer', 'arrow');
set(FigHandle, 'WindowButtonMotionFcn', '');
set(FigHandle, 'WindowButtonUpFcn', '');
close(FigHandle);
ImageX = importdata('IXT.txt');
ImageY = importdata('IYT.txt');
InputImage = ones(imSize);
roundX = round(ImageX);
roundY = round(ImageY);
for k = 1:size(ImageX,1)
if 0<roundX(k) && roundX(k)<imSize && ...
0<roundY(k) && roundY(k)<imSize
InputImage(roundX(k)-1:roundX(k)+2,...
roundY(k)-1:roundY(k)+2) = 0;
end
end
InputImage = imrotate(InputImage,90);% 图像旋转90
index = 1;
save('input.mat','index','InputImage');
end

%%滑鼠移动时的反应指令
case 'move'
CurPiont = get(gca, 'CurrentPoint');
X = CurPiont(1,1);
Y = CurPiont(1,2);
% 当鼠标移动较快时, 不会出现离散点
% 利用y=kx+b直线方程实现
x_gap = 0.1;% 定义x方向增量
y_gap = 0.1;% 定义y方向增量
if X > InitialX
step_x = x_gap;
else
step_x = -x_gap;
end
if Y > InitialY
step_y = y_gap;
else
step_y = -y_gap;
end
% 定义x, y的变化范围和步长
if abs(X-InitialX) < 0.01 % 线平行于y轴, 即斜率不存在时
iy = InitialY:step_y:Y;
ix = X.*ones(1,size(iy,2));
else
ix = InitialX:step_x:X ; % 定义x的变化范围和步长
% 当斜率存在, 即k = (Y-InitialY)/(X-InitialX) ~= 0
iy = (Y-InitialY)/(X-InitialX).*(ix-InitialX)+InitialY;
end
ImageX = [ix, X];
ImageY = cat(2, iy, Y);
line(ImageX,ImageY, 'marker', '.', 'markerSize',28, ...
'LineStyle', '-', 'LineWidth', 4, 'Color', 'Blue');
dlmwrite('IXT.txt', ImageX, '-append', ...
'delimiter', '\t', 'precision', 6);
dlmwrite('IYT.txt', ImageY, '-append', ...
'delimiter', '\t', 'precision', 6);
InitialX = X;% 记住当前点坐标
InitialY = Y;% 记住当前点坐标

%%滑鼠按钮被释放时的反应指令
case 'up'
set(gcf, 'WindowButtonMotionFcn', '');% 清除滑鼠移动时的反应指令
set(gcf, 'WindowButtonUpFcn', '');% 清除滑鼠按钮被释放时的反应指令
end

end

对手写输入的图片进行预处理

程序MouseDraw_app将得到手写数字保存到矩阵InputImage中,再将该矩阵保存到input.mat文件中。卷积神经网络CNN_Script_Classification的输入为28x28的矩阵,需要将InputImage作进一步地处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 锁定图片中数字区域, 并降低分辨率
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% INPUT
% InputImage : - 矩阵 200 x 200, 手写输入
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% OUTPUT
% image200 : - 矩阵 200 x 200, 锁定后的数字, 用于APP中预处理展示
% image28 : - 矩阵 28 x 28 ,
% 将image200中数字加粗, 并降低分辨率, 作为神经网络输入
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [image28,image200]=high2low(InputImage)

num=28;% 目标分辨率: num1 * num1
[num2,~]=size(InputImage);
num3=2;% 加粗偏移量

stats=regionprops('table',1-InputImage,'BoundingBox');% 数字边界
bbox=floor(stats.BoundingBox);

r=bbox(2);dr=bbox(4);ddr=floor(dr/3);% 行
c=bbox(1);dc=bbox(3);ddc=floor(dc/2);% 列

temp0=ones(dr+2*ddr,dc+2*ddc);
temp0(ddr+1:dr+ddr,ddc+1:dc+ddc)=InputImage(r+1:r+dr,c+1:c+dc);

image200=imresize(temp0,[200,200]);

temp1=ones(num2);
temp1(num3+1:end,num3+1:end)=image200(1:end-num3,1:end-num3);

temp2=ones(num2);
temp2(1:end-num3,1:end-num3)=image200(num3+1:end,num3+1:end);

temp3=ones(num2);
temp3(num3+1:end,1:end-num3)=image200(1:end-num3,num3+1:end);

temp4=ones(num2);
temp4(1:end-num3,num3+1:end)=image200(num3+1:end,1:end-num3);

a=min(temp1,temp2);b=min(temp3,temp4);temp=min(a,b);% 偏移加粗

image28=imresize(temp,[num,num]);
image28(image28<0.65)=0;
image28(image28>0.80)=1;

end

在App Designer中,根据:

  • 手写输入函数MouseDraw_app
  • 图片预处理函数high2low
  • 卷积神经网络CNN_Script_Classification

完成App设计,并部署到MATLAB中。