一、Batch概念

什么是batch,准备了两种解释,看君喜欢哪种?

  1. 对于一个有 2000 个训练样本的数据集。将 2000 个样本分成大小为 500 的 batch,那么完成一个 epoch 需要 4 个 iteration。
  2. 如果把准备训练数据比喻成一块准备打火锅的牛肉,那么epoch就是整块牛肉,batch就是切片后的牛肉片,iteration就是涮一块牛肉片(饿了吗?)。

     

二、Batch用来干什么

不是给人吃,是喂给模型吃。在搭建了“模型-策略-算法”三大步之后,要开始利用数据跑(训练)这个框架,训练出最佳参数。

  1. 理想状态,就是把所有数据都喂给框架,求出最小化损失,再更新参数,重复这个过程,但是就像煮一整块牛肉那样,不知道什么时候才有得吃。----全量数据的梯度下降算法
  2. 另一个极端的状态,就是每次只给模型喂一条数据,立马就熟了,快是够快了,但是一个不小心也会直接化掉,吃都没得吃(可能无法得到局部最优)----随机梯度下降算法(stochastic gradient descent)
  3. 平衡方案,综合考虑又要快,又要有得吃,那么选用切片涮牛肉的方法,把数据切成batch大小的一块,每次(iteration)只吃一块。每次只计算一小部分数据的损失函数,并更改参数。

三、Batch的实现

再次提供两种方法

1. yield→generator

具体的语法知识,请点链接


 
  1. # --------------函数说明-----------------
  2. # sourceData_feature :训练集的feature部分
  3. # sourceData_label :训练集的label部分
  4. # batch_ size : 牛肉片的厚度
  5. # num_epochs : 牛肉翻煮多少次
  6. # shuffle : 是否打乱数据
  7. def batch_iter(sourceData_feature,sourceData_label, batch_ size, num_epochs, shuffle = True):
  8. data_ size = len(sourceData_feature)
  9. num_batches_per_epoch = int( data_ size / batch_ size) # 样本数 /batch块大小,多出来的“尾数”,不要了
  10. for epoch in range(num_epochs):
  11. # Shuffle the data at each epoch
  12. if shuffle:
  13. shuffle_indices = np. random.permutation(np.arange( data_ size))
  14. shuffled_ data_feature = sourceData_feature[shuffle_indices]
  15. shuffled_ data_label = sourceData_label[shuffle_indices]
  16. else:
  17. shuffled_ data_feature = sourceData_feature
  18. shuffled_ data_label = sourceData_label
  19. for batch_num in range(num_batches_per_epoch): # batch_num取值 0到num_batches_per_epoch- 1
  20. start_ index = batch_num * batch_ size
  21. end_ index = min((batch_num + 1) * batch_ size, data_ size)
  22. yield (shuffled_ data_feature[ start_ index: end_ index] , shuffled_ data_label[ start_ index: end_ index])


 
  1. batchSize = 100 # 定义具体的牛肉厚度
  2. Iterations = 0 # 记录迭代的次数
  3. # sess
  4. sess = tf.Session()
  5. sess. run(tf. global_variables_initializer())
  6. # 迭代 必须注意batch_iter是yield→generator,所以 for语句有特别
  7. for (batchInput, batchLabels) in batch_iter(mnist.train.images, mnist.train.labels, batchSize, 30, shuffle = True):
  8. trainingLoss = sess. run([opt,loss], feed_dict = {X: batchInput, y:batchLabels})
  9. if Iterations% 1000 = = 0: # 每迭代一千次,输出一次效果
  10. train_accuracy = sess. run(accuracy, feed_dict ={X:batchInput, y:batchLabels})
  11. print( "step %d, training accuracy %g"%(Iterations,train_accuracy))
  12. Iterations =Iterations + 1

2. slice_input_producer + batch

又涉及到一些背景知识,这篇文章这篇文章。以下是图解slice_input_producer。


 
  1. def get_batch_ data(images, label, batch_ Size):
  2. input_queue = tf.train.slice_ input_producer([images, label], shuffle = True, num_epochs = 20) # 见图解
  3. image_batch, label_batch = tf.train.batch( input_queue, batch_ size =batch_ Size, num_threads = 2,allow_smaller_ final_batch = True)
  4. return image_batch,label_batch
  5. batchSize = 100 # 记录迭代的次数
  6. batchInput, batchLabels = get_batch_ data(mnist.train.images, mnist.train.labels, batchSize)


 
  1. Iterations = 0 # 定义具体的牛肉厚度
  2. # sess
  3. sess = tf.Session()
  4. sess. run(tf. global_variables_initializer())
  5. sess. run(tf.local_variables_initializer())#就是这一行
  6. coord = tf.train.Coordinator()
  7. # 真正将文件放入文件名队列,还需要调用tf.train. start_queue_runners 函数来启动执行文件名队列填充的线程,
  8. # 之后计算单元才可以把数据读出来,否则文件名队列为空的,
  9. threads = tf.train. start_queue_runners(sess,coord)
  10. try:
  11. while not coord.should_ stop():
  12. BatchInput,BatchLabels = sess. run([batchInput, batchLabels])
  13. trainingLoss = sess. run([opt,loss], feed_dict = {X:BatchInput, y:BatchLabels})
  14. if Iterations% 1000 = = 0:
  15. train_accuracy = accuracy.eval(session = sess, feed_dict ={X:BatchInput, y:BatchLabels})
  16. print( "step %d, training accuracy %g"%(Iterations,train_accuracy))
  17. Iterations = Iterations + 1
  18. except tf.errors.OutOfRangeError:
  19. train_accuracy = accuracy.eval(session = sess, feed_dict ={X:BatchInput, y:BatchLabels})
  20. print( "step %d, training accuracy %g"%(Iterations,train_accuracy))
  21. print( 'Done training')
  22. finally:
  23. coord.request_ stop()
  24. coord.join(threads)
  25. # sess. close()
四、两种方式的对比

方式1: yield→generator 30个epoch
试验效果,开始前python.exe进程占了402M内存。
试验中,内存基本维持在865M左右
试验后,30个epoch耗时需要49.8s

方式2: slice_input_producer + batch
进行slice_input_producer这步,占用内存由410M提升到了583M
训练的时候,内存占用比较飘忽,有时1G多。
20个epoch耗时需要199s

小结:方式1的效率暂时比方式2快不少。



作者:StarsOcean
链接:https://www.jianshu.com/p/71f31c105879
来源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

From:https://blog.csdn.net/aha_Yali/article/details/128173662?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522D02CE88E-0D0D-4AC5-8CE8-786B88CB8195%2522%252C%2522scm%2522%253A%252220140713.130102334…%2522%257D&request_id=D02CE88E-0D0D-4AC5-8CE8-786B88CB8195&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allsobaiduend~default-1-128173662-null-null.142v100pc_search_result_base9&utm_term=%E8%AE%AD%E7%BB%83%E4%B8%AD%E7%9A%84batch&spm=1018.2226.3001.4187

Logo

一站式 AI 云服务平台

更多推荐