JobPlus知识库 IT 工业智能4.0 文章
KNN实现手写字体的识别

1、KNN思想

KNN就是K最近邻,是一种分类算法,意思是选k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。在k个样本中,比重最大的那一类即可把目标归为这一类。

优点:不用训练

缺点:该算法在分类时有个主要的不足是,当样本不平衡时,K临近占比概率影响结果、计算量大

2、KNN如何实现手写字体的识别

①数据处理(图片处理为数字文本)

②待测图片与训练集每一张图片的向量做欧氏距离

③排序,选取最优K,使得结果最好

3、数据处理

①01文本转换图像


[python] 

  1. #coding:utf-8  
  2. import os  
  3. from PIL import Image  
  4. ''''' 
  5. ①两个文件夹,一个存储数字文件、一个存储图像文件 
  6. ②读取src文件夹下的文件名,统计个数 
  7. ③src+文件名=每个数字txt的路径,读取其内容 
  8. ④写入图片,保存图片 
  9. 注意: 
  10. ①目标文件夹只能是最后一层不存在,才能创建 
  11. ②python于java/c的区别,路径分割/刚好相反 
  12. ③数字文件为01序列 
  13. ④putpixel((列,行)) 
  14. '''  
  15. def fun(src,dst):  
  16.     #判断源文件夹是否存在,不存在结束  
  17.     if not os.path.exists(src):  
  18.         return  
  19.     # 判断目标文件夹是否存在,不存在创建一个  
  20.     if not os.path.exists(dst):  
  21.         os.mkdir(dst)  
  22.     #读取src文件中的文件名  
  23.     list=os.listdir(src)  
  24.     length=len(list)  
  25.     for i in range(length):  
  26.         #文件路径  
  27.         path=src+"/"+list[i];  
  28.         #读取文件内容  
  29.         read=open(path)  
  30.         #保存路径  
  31.         SavePath=dst+"/"+list[i][:-4]+".png"  
  32.         #写入图片,并保存,图片是32*32  
  33.         image =Image.new("L",(32,32))  
  34.         for j in range(32):  
  35.             line=read.readline()  
  36.             for k in range(32):  
  37.                 bit=int(line[k])  
  38.                 if bit ==1:  
  39.                     bit=255  
  40.                 image.putpixel((k,j),bit)  
  41.         image.save(SavePath)  
  42. srcPath="C:/Users/Administrator/Desktop/src"  
  43. dstPath="C:/Users/Administrator/Desktop/dst"  
  44. fun(srcPath,dstPath)  

②图片转换成01文本


[python]

  1. #coding:utf-8  
  2. import os  
  3. from PIL import Image  
  4. import numpy as np  
  5. ''''' 
  6. Python图像处理库PIL的基本概念介绍”,我们知道PIL中有九种不同模式。 
  7. 分别为1,L,P,RGB,RGBA,CMYK,YCbCr,I,F。 
  8. 模式“1”为二值图像,非黑即白。但是它每个像素用8个bit表示,0表示黑,255表示白 
  9. 模式“L”为灰色图像,它的每个像素用8个bit表示,0表示黑,255表示白,其他数字表示不同的灰度。 
  10. 模式“P”为8位彩色图像,它的每个像素用8个bit表示,其对应的彩色值是按照调色板查询出来的。 
  11. '''  
  12. def fun(src,dst):  
  13.     if not os.path.exists(src):  
  14.         return  
  15.     if not os.path.exists(dst):  
  16.         os.mkdir(dst)  
  17.     list=os.listdir(src)  
  18.     length=len(list)  
  19.     for i in range(length):  
  20.         path=src+'/'+list[i]  
  21.         SavePath=dst+'/'+list[i][:-4]+".txt"  
  22.         read=Image.open(path).convert("1")            
  23.         arr=np.asarray(read)  
  24.         np.savetxt(SavePath,arr,fmt="%d",delimiter='')    #保存格式为整数,没有间隔  
  25.         #np.savetxt(SavePath, arr,fmt="%d")  
  26. src="C:/Users/Administrator/Desktop/src"  
  27. dst="C:/Users/Administrator/Desktop/dst"  
  28. fun(src,dst)  

4、具体实现

数据已经由图片处理为文本,且文本中像素点之间没有间隔。把每张图片处理为一个向量,32*32=1024。计算欧氏距离的时候也有技巧,不用遍历训练集一个一个与待测图片计算,可以利用np.title(),复制待测图片达到和训练集个数,直接矩阵相减。排序也有技巧,为了是每张图片和标签一一对应,排序的时候使用argsort(),统计众数的时候使用了字典,排序时候用sorted().


[python]

  1. #coding:utf-8  
  2. import os  
  3. import numpy as np  
  4. import operator  
  5. ''''' 
  6. ①难点:计算欧氏距离并排序,确定k值,这里k=3最优 
  7. ②图片都处理为数字文本,文本中没有空格 
  8. ③字典排序,排序后变为[(),()]形式 
  9. '''  
  10. #价值数据  
  11. def Load(src):  
  12.     if not os.path.exists(src):  
  13.         return  
  14.     list=os.listdir(src)  
  15.     length=len(list)  
  16.     label=[]  
  17.     train=[]  
  18.     for i in range(length):  
  19.         path=src+"/"+list[i]  
  20.         read=open(path)  
  21.         temp = []  
  22.         for j in range(32):  
  23.             line=read.readline()  
  24.             for k in range(32):  
  25.                 bit=int(line[k])  
  26.                 temp.append(bit)  
  27.         train.append(temp)  
  28.         label.append(int(list[i][0]))  
  29.     train=np.array(train)  
  30.     return train,label  
  31. def Classifier(train,laber,testPath,KK):  
  32.     list=os.listdir(testPath)  
  33.     length=len(list)  
  34.     errorCount=0  
  35.     for i in range(length):  
  36.         #数据处理  
  37.         path=testPath+"/"+list[i]  
  38.         #实际值  
  39.         ok=int(list[i][0])  
  40.         read=open(path)  
  41.         test=[]  
  42.         for j in range(32):  
  43.             line=read.readline()  
  44.             for k in range(32):  
  45.                 bit=int(line[k])  
  46.                 test.append(bit)  
  47.         #计算欧氏距离,不需要遍历,技巧  
  48.         m=train.shape[0]  
  49.         test=np.tile(test,(m,1))  
  50.         sum=train-test    #对应相减  
  51.         sum=sum**2       #平方  
  52.         sum=np.sum(sum,axis=1)  #行求和  
  53.         sum=sum**0.5     #开方  
  54.         # 排序,返回下标  
  55.         sum=np.argsort(sum)  
  56.         #前k个,取最大类  
  57.         ans={}  
  58.         for j in range(KK):  
  59.             lab=label[sum[j]]    #下标对应的标签  
  60.             if lab in ans.keys():  
  61.                 ans[lab]=ans[lab]+1  
  62.             else:  
  63.                 ans[lab] = 1  
  64.         ans=sorted(ans.items(),key=operator.itemgetter(1),reverse=True)  
  65.         print ("实际值=",ok,"预测值=",ans[0][0])  
  66.         if ok != ans[0][0]:  
  67.             errorCount += 1.0  
  68.     print("错误总数:%d" % errorCount)  
  69.     print("错误率:%f" % (errorCount / length))  
  70. trainPath="C:/Users/Administrator/Desktop/src"  
  71. testPath="C:/Users/Administrator/Desktop/dst"  
  72. #训练集处理  
  73. train,label=Load(trainPath)  
  74. #测试集处理  
  75. Classifier(train,label,testPath,3)  

总结:

①文件夹判断是否存在?文件夹创建?文件夹下所有文件名的读取?文本读取?

②图像的创建?图像像素点的填充?图像的保存?图像读取?文本的写入?

③欧氏距离的计算?title()的使用?argsort()的使用?字典的排序sorted()?

④如何确定最优K?遍历K,针对每个k计算错误率。


[python]

  1. def selectK():  
  2.     x = list()  
  3.     y = list()  
  4.     for i in range(1, 5):  
  5.         x.append(int(i))  
  6.         y.append(错误数)  
  7.     plt.plot(x, y)  
  8.     plt.show()  


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
311人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序