1. 前言
在使用Tensorflow框架进行模型训练的时候,我们通常需要将训练好的模型保存为.pb格式的文件,方便后续的模型使用和部署。但是在将pb格式的模型转换为tflite格式的模型时,经常会发现模型精度下降的情况。本文将详细介绍在tensorflow pb转tflite过程中可能遇到的问题,以及解决方案。
2. tensorflow pb to tflite 过程
在使用tensorflow进行模型训练后,我们通常会将训练好的模型保存为.pb格式的文件,示例代码如下:
import tensorflow as tf
#定义网络模型
def my_net(...):
...
#读取训练好的模型
input_tensor = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
output_tensor = my_net(input_tensor)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'model.ckpt')
#保存模型为.pb文件
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output_tensor'])
with tf.gfile.FastGFile('frozen_model.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
上述代码中,我们先定义了一个网络模型,然后读取之前训练好的模型,并保存为.pb文件。下面是将.pb文件转换为tflite文件的示例代码:
import tensorflow as tf
#将.pb文件转换为tflite文件
converter = tf.contrib.lite.TocoConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])
converter.post_training_quantize = True #开启量化
tflite_model = converter.convert()
open('quantize_model.tflite', 'wb').write(tflite_model)
3. 量化对模型精度的影响
3.1 量化介绍
在将.pb模型转换为tflite模型时,我们经常会使用量化来减小模型体积以及提高模型推理速度。量化是指将浮点数转换为定点数或者整型数,以减小模型的存储空间和计算量。
3.2 量化对模型精度的影响
量化是一种降低模型精度的方式,通常会对模型的预测精度产生一定的影响。在进行量化操作时,需要注意以下几点:
量化的精度:量化的精度越高,模型的预测精度越好,但是模型的体积和计算量也越大。
量化的范围:量化的范围越大,模型的预测精度越好,但是模型的体积和计算量也越大。
量化的方法:不同的量化方法会对模型的预测精度产生不同的影响,需要根据具体的场景选择合适的量化方法。
4. 解决方案
4.1 增加temperature参数
在进行量化操作时,我们可以通过设置temperature参数来减小量化对模型精度的影响。temperature参数通常取值为0.5或者0.6,可以在一定程度上减小量化造成的精度下降。
import tensorflow as tf
#将.pb文件转换为tflite文件
converter = tf.contrib.lite.TocoConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])
converter.post_training_quantize = True #开启量化
converter.inference_input_type = tf.int8 #设置输入数据类型为整型
converter.inference_output_type = tf.int8 #设置输出数据类型为整型
converter.quantized_input_stats = {'input_tensor': (0., 1.)} #设置输入数据的范围
converter.default_ranges_stats = (-6., 6.) #设置默认的量化范围
converter.inference_type = tf.float32 #设置推理的数据类型为浮点型
converter.quantization_steps = 256 #设置量化的精度为256
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] #设置量化方法
converter.inference_type = tf.uint8 #设置量化后的数据类型为整型(uint8)
converter.inference_input_type = tf.uint8 #设置输入数据类型为整型
converter.inference_output_type = tf.uint8 #设置输出数据类型为整型
converter.experimental_new_converter = True #开启新的实验性量化过程
converter.change_concat_input_ranges = True #更改拼接输入的范围
converter.allow_custom_ops = True #允许使用自定义操作
converter.temperature = 0.6 # 设置转换的温度
tflite_model = converter.convert()
open('quantize_model.tflite', 'wb').write(tflite_model)
4.2 选择合适的量化方法
在进行量化操作时,选择合适的量化方法也可以减小量化对模型精度的影响。Tensorflow提供了多种量化方法,包括对权重进行量化、对激活函数进行量化和对权重和激活函数同时进行量化等。需要根据具体的场景选择合适的量化方法。
import tensorflow as tf
#将.pb文件转换为tflite文件
converter = tf.lite.TFLiteConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] #选择量化方法为TFLITE_BUILTINS_INT8
tflite_quant_model = converter.convert()
open("quantize_model.tflite", "wb").write(tflite_quant_model)
5. 总结
在实际项目中,将模型从pb格式转换为tflite格式是一项常见的任务。在进行转换时,我们要注意量化的精度、量化的范围和量化的方法等问题,以避免量化对模型精度造成不必要的影响。如果在进行量化时出现了精度下降等问题,可以通过设置temperature参数和选择合适的量化方法来解决。