博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
knn算法之数字识别
阅读量:4101 次
发布时间:2019-05-25

本文共 5740 字,大约阅读时间需要 19 分钟。

目录

 


数字识别所用文件来源:

(一)训练集文件列表

           32行32列

(二)拿文件,读原数据 

这一步的目标是:读取每一个txt,得到原数据

(1)读取小提示

 输出:

 默认读出的是科学计数法的数字,我们要得到原来的数据,则在加载文件时加上dtype为str

结果(原txt文件里每行数据时隔以为列表):

 (三)数据处理

这一步的目标是:每一个txt的数据(即上图的一维数组)处理为一行1024,即32x32个数字,并且转为int方便knn计算

(1)将上图变成二维数组再展平就是一行了

问题:1、这里一串字符串要变成二维数组的一行,怎么转?

          解法:拿到这个红框,用list(map(int,红框))   ----》[0,0,0,0,0,0,0,0,0]

          2、二维数组的第二行以及往下多行怎么赋值?

          解法:用上图列表的下标作为二维数组的行坐标

实践:

(四)将上一步的一行1024个0,1赋值给最终的数组(knn计算的数组)

 问题:一个txt文件就是最终数组的一行,这个行的下标怎么定呢?

解决:第一个for就应该取个枚举,所以修改第一个for的代码(还要在第一个for之前定义一个最终数组)

加上存储到最终数组的代码后,目前的完整代码如下:(不再拿两个文件,而是拿全部)

import osimport pandas as pdimport numpy as nppath = 'C:\\Users\\Administrator\\Desktop\\数字识别\\trainingDigits'file_list = os.listdir(path)#列出trainingDigits文件夹下的所有文件final_arr = np.zeros((len(file_list),1025))#1025的原因是最后一列最为标签for file_index,file_name in enumerate(file_list):#拿全部文件    tag = file_name[0] #取文件名第一个数字作为标签,0~9的数字    txt = np.loadtxt(path+'\\'+file_name,dtype=str) #读取文件    # print(txt)    arr = np.zeros((32,32)) #每个txt赋值到这里,之后展平为1024个数    for index,num_str in enumerate(txt):        arr[index] = np.array(list(map(int,num_str)),dtype=np.int8)    # print(arr)    arr_flatten = arr.flatten()#展平为一行1024个1,0    # print(arr_flatten.size)    #赋值到最终数组,file_index代表了第几个txt也就是赋值都最终数组的第几行    final_arr[file_index,:-1] = arr_flatten#最后一列没赋值,保留做tag    # 最后一列进行赋值,代表此行数据就是0或1,也就是训练集结果    final_arr[file_index, -1] =tag

(五)将最终数组进行保存为训练集的训练结果(分类器)

(六)训练集完成(分类器)

(训练集制作的完整代码)

import osimport pandas as pdimport numpy as nppath = 'C:\\Users\\Administrator\\Desktop\\数字识别\\trainingDigits'file_list = os.listdir(path)#列出trainingDigits文件夹下的所有文件final_arr = np.zeros((len(file_list),1025))#1025的原因是最后一列最为标签for file_index,file_name in enumerate(file_list[:2]):#拿两个文件实验    tag = file_name[0] #取文件名第一个数字作为标签,0~9的数字    txt = np.loadtxt(path+'\\'+file_name,dtype=str) #读取文件    # print(txt)    arr = np.zeros((32,32)) #每个txt赋值到这里,之后展平为1024个数    for index,num_str in enumerate(txt):        arr[index] = np.array(list(map(int,num_str)),dtype=np.int8)    # print(arr)    arr_flatten = arr.flatten()#二维数组展平为一行    # print(arr_flatten.size)    #赋值到最终数组,file_index代表了第几个txt也就是赋值都最终数组的第几行    final_arr[file_index,:-1] = arr_flatten#最后一列没赋值,保留做tag    # 最后一列进行赋值,代表此行数据就是0或1,也就是训练集结果    final_arr[file_index, -1] =tag    name = path.split('\\')[-1]    print(name)    np.savetxt(f'{name}.csv',final_arr,fmt='%d')#保存格式为十进制数字

(七)测试集制作

     32x32

测试集制作与训练集制作流程一样,训练集产生testDigits.csv文件,第1025列也是tag,

现在我们要做的就是:拿testDigits.csv的前1024列对训练集进行测试,测试结果与testDigits.csv的第1025列比对,判断准确率

既然制作流程一致,那就封装个函数:

import osimport pandas as pdimport numpy as npdef data_trans(path):    file_list = os.listdir(path)#列出trainingDigits文件夹下的所有文件    if path == 'C:\\Users\\Administrator\\Desktop\\数字识别\\trainingDigits':        file_list.pop() #trainingDigit最后一个文件有问题就删了    final_arr = np.zeros((len(file_list),1025))#1025的原因是最后一列最为标签    for file_index,file_name in enumerate(file_list):#拿两个文件实验        tag = file_name[0] #取文件名第一个数字作为标签,0~9的数字        txt = np.loadtxt(path+'\\'+file_name,dtype=str) #读取文件        # print(txt)        arr = np.zeros((32,32)) #每个txt赋值到这里,之后展平为1024个数        for index,num_str in enumerate(txt):            arr[index] = np.array(list(map(int,num_str)),dtype=np.int8)        # print(arr)        arr_flatten = arr.flatten()        # print(arr_flatten.size)        #赋值到最终数组,file_index代表了第几个txt也就是赋值都最终数组的第几行        final_arr[file_index,:-1] = arr_flatten#最后一列没赋值,保留做tag        # 最后一列进行赋值,代表此行数据就是0或1,也就是训练集结果        final_arr[file_index, -1] =tag    name = path.split('\\')[-1]    print(name)    np.savetxt(f'{name}.csv',final_arr,fmt='%d')#保存格式为十进制数字if __name__ == '__main__':    path1 = 'C:\\Users\\Administrator\\Desktop\\数字识别\\trainingDigits'    path2 = 'C:\\Users\\Administrator\\Desktop\\数字识别\\testDigits'    data_trans(path1)    data_trans(path2)

以上代码生成两个文件(其实他俩一样,只是一个用来测试,一个用来被测试罢了):

(八)进行测试(knn算法)

import numpy as npimport pandas as pdimport matplotlib.pyplot as pltdef knn(final_arr,test_arr,k):    """    在test_arr中取出每一行用分类器final_arr进行分类给出标签    计算出的标签与test_arr里的标签对比来计算准确率    :param final_arr: 训练出来的分类器    :param test_arr: 测试集    :k: knn中的k,排序后取样本的数量    :return: 准确率    """    test_arr_rows = test_arr.shape[0] #测试集有多少行    correct_num = 0 # 分类器得到正确tag的个数,用于计算准确率    for i in range(test_arr_rows):        row_to_test = test_arr[i,:-1] #待测试的行        #计算出与分类器每一行的距离(用到数组的广播性质)        distance = np.sqrt(((row_to_test-final_arr[:,:-1])**2).sum(axis=1))        #根据knn算法将距离排序后去前K个做判断标准,        # K个样本中取出标签的众数作为判断结果        #argsort函数返回的是数组值从小到大的索引值        #这个索引值也是final_arr的对应行索引值        distance_sorted_index = distance.argsort()[:k]        #在final_arr中拿到离测试点最近的前k行的tag        k_tags_arr = final_arr[distance_sorted_index,-1]        #取k个tag中的众数,众数函数是在DataFrame中,所以转一下        predict_tag = pd.DataFrame(k_tags_arr).mode()        #mode()返回的是DataFrame,所以取值要predict_tag[0][0]        print('预测值:',predict_tag[0][0])        print('真实值:',test_arr[i,-1])        if predict_tag[0][0] == test_arr[i,-1]:            correct_num += 1    #因为这个准确度要反映到绘图上,所以不能是字符串,还是在绘图时转成%形式吧    # precision = '%.4f%%'%(correct_num/test_arr.shape[0])    precision =correct_num/test_arr.shape[0]    # print('准确度为:',precision)    return precisionif __name__ == '__main__':    final_arr = np.loadtxt('testDigits.csv',dtype=int)#[0 0 0 ... 0 0 0]    test_arr = np.loadtxt('trainingDigits.csv',dtype=int)#不加dtype返回的是浮点数[0. 0. 0. ... 0. 0. 0.]    #测试不同k的准确度    x = list(range(5,15))    y = [] #y用来接收不同k返回的准确度,用于绘图    for k in range(5,15):        precision = knn(final_arr,test_arr,k)        y.append(precision)    print(x)#[5]    print(y)#[0.967408173823073]    #绘图    plt.figure()    plt.plot(x,y,marker='*',markersize=12,ha='center')    for a,b in zip(x,y):        plt.text(a,b,'%.3f%%'%(b*100))    plt.xlabel('k')    plt.ylabel('precision')    plt.show()

测试结果图:可以看出k取5,7时准确率很高

转载地址:http://rdwsi.baihongyu.com/

你可能感兴趣的文章
JS中各种数组遍历方式的性能对比
查看>>
Mysql复制表以及复制数据库
查看>>
进程管理(一)
查看>>
linux 内核—进程的地址空间(1)
查看>>
存储器管理(二)
查看>>
开局一张图,学一学项目管理神器Maven!
查看>>
Android中的Binder(二)
查看>>
Framework之View的工作原理(一)
查看>>
Web应用架构
查看>>
设计模式之策略模式
查看>>
深究Java中的RMI底层原理
查看>>
用idea创建一个maven web项目
查看>>
Kafka
查看>>
9.1 为我们的角色划分权限
查看>>
维吉尼亚之加解密及破解
查看>>
DES加解密
查看>>
TCP/IP协议三次握手与四次握手流程解析
查看>>
PHP 扩展开发 : 编写一个hello world !
查看>>
inet_ntoa、 inet_aton、inet_addr
查看>>
用模板写单链表
查看>>