JobPlus知识库 IT 工业智能4.0 文章
tensoflow之手写数字识别项目

首先,使用lenet-5框架:

[python] 

  1. #!/usr/bin/env python3  
  2. # -*- coding: utf-8 -*-  
  3. """ 
  4. Created on Thu May 24 10:12:33 2018 
  5.  
  6.  
  7. @author: john 
  8. """  
  9. ########Training a DNN Using Plain TensorFlow  
  10.   
  11.   
  12. ####Construction Phase  
  13.   
  14.   
  15. ##define parameter  
  16. import tensorflow as tf  
  17. import numpy as np  
  18. import matplotlib.pyplot as plt  
  19. import pandas as pd  
  20. import  tensorflow.contrib.slim as slim  
  21.   
  22.   
  23. #load the data  
  24. train=pd.read_csv('train.csv')  
  25. y=train['label'].values  
  26. del train['label']        
  27. data=train.values  
  28. data=np.reshape(data,(data.shape[0],28,28,1))  
  29.   
  30.   
  31. from sklearn.cross_validation import train_test_split  
  32. X_train,X_test,y_train,y_test=train_test_split(data,y,test_size=0.2)  
  33.       
  34. ########construct phase  
  35. X=tf.placeholder(tf.float32,shape=(None,28,28,1))  
  36. y=tf.placeholder(tf.int32,shape=(None))  
  37.       
  38. ##net struct  
  39. conv1=slim.conv2d(X,6,[5,5],stride=1,padding='SAME',activation_fn=tf.nn.tanh,scope='conv1')  
  40. avg_pool1=slim.avg_pool2d(conv1,[2,2],[2,2],padding='SAME')      
  41. conv2=slim.conv2d(avg_pool1,16,[5,5],stride=1,padding='SAME',activation_fn=tf.nn.tanh,scope='conv2')      
  42. avg_pool2=slim.avg_pool2d(conv2,[2,2],[2,2],padding='SAME')      
  43. conv3=slim.conv2d(avg_pool2,120,[5,5],stride=1,padding='SAME',activation_fn=tf.nn.tanh,scope='conv3')  
  44. flatten=slim.flatten(conv3)  
  45. f1=slim.fully_connected(flatten,84,activation_fn=tf.nn.tanh,scope='f1')  
  46. logits=slim.fully_connected(f1,10,activation_fn=None,scope='f2')  
  47.   
  48.   
  49. ##loss  
  50. xentropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=logits)  
  51. loss=tf.reduce_mean(xentropy,name='loss')  
  52.   
  53.   
  54. ##train  
  55. learning_rate=0.01  
  56. optimizer=tf.train.GradientDescentOptimizer(learning_rate)  
  57. training_op=optimizer.minimize(loss)  
  58.   
  59.   
  60. ##eval  
  61. correct=tf.nn.in_top_k(logits,y,1)  
  62. accuracy=tf.reduce_mean(tf.cast(correct,tf.float32))  
  63.   
  64.   
  65. init=tf.global_variables_initializer()  
  66. saver=tf.train.Saver()  
  67.   
  68.   
  69. ########execution phase  
  70. n_input=28*28  
  71. n_epochs=10  
  72. batch_size=125  
  73.   
  74.   
  75. def fetch_batch(epoch,batch_index,batch_size):  
  76.     np.random.seed(epoch*n_input+batch_index)  
  77.     indicies=np.random.randint(X_train.shape[0],size=batch_size)  
  78.     X_batch=X_train[indicies]  
  79.     y_batch=y_train[indicies]  
  80.     return X_batch,y_batch  
  81. with tf.Session() as sess:  
  82.     init.run()  
  83.     for epoch in range(n_epochs):  
  84.         for batch_index in range(X_train.shape[0]//batch_size):  
  85.             X_batch,y_batch=fetch_batch(epoch,batch_index,batch_size)  
  86.             sess.run(training_op,feed_dict={X:X_batch,y:y_batch})  
  87.         acc_train=accuracy.eval(feed_dict={X:X_train,y:y_train})  
  88.         acc_test=accuracy.eval(feed_dict={X:X_test,y:y_test})  
  89.           
  90.         print(epoch,' train accuracy:',acc_train,  
  91.               ' test accuracy:',acc_test)  
  92.     save_path=saver.save(sess,'./my_model.ckpt')  

部分结果如下所示:

[python] 

  1. 0  train accuracy: 0.916042  test accuracy: 0.915357  
  2. 1  train accuracy: 0.940149  test accuracy: 0.937024  
  3. 2  train accuracy: 0.951369  test accuracy: 0.946786  
  4. 3  train accuracy: 0.956756  test accuracy: 0.954167  
  5. 4  train accuracy: 0.96131  test accuracy: 0.958809  
  6. 5  train accuracy: 0.965952  test accuracy: 0.965  
  7. 6  train accuracy: 0.967351  test accuracy: 0.964405  
  8. 7  train accuracy: 0.968869  test accuracy: 0.967262  
  9. 8  train accuracy: 0.972321  test accuracy: 0.969643  


如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!

¥ 打赏支持
28人赞 举报
分享到
用户评价(0)

暂无评价,你也可以发布评价哦:)

扫码APP

扫描使用APP

扫码使用

扫描使用小程序