tensorflow pb to tflite 精度下降详解

1. 前言

在使用Tensorflow框架进行模型训练的时候,我们通常需要将训练好的模型保存为.pb格式的文件,方便后续的模型使用和部署。但是在将pb格式的模型转换为tflite格式的模型时,经常会发现模型精度下降的情况。本文将详细介绍在tensorflow pb转tflite过程中可能遇到的问题,以及解决方案。

2. tensorflow pb to tflite 过程

在使用tensorflow进行模型训练后,我们通常会将训练好的模型保存为.pb格式的文件,示例代码如下:

import tensorflow as tf

#定义网络模型

def my_net(...):

...

#读取训练好的模型

input_tensor = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])

output_tensor = my_net(input_tensor)

saver = tf.train.Saver()

with tf.Session() as sess:

saver.restore(sess, 'model.ckpt')

#保存模型为.pb文件

output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['output_tensor'])

with tf.gfile.FastGFile('frozen_model.pb', mode='wb') as f:

f.write(output_graph_def.SerializeToString())

上述代码中,我们先定义了一个网络模型,然后读取之前训练好的模型,并保存为.pb文件。下面是将.pb文件转换为tflite文件的示例代码:

import tensorflow as tf

#将.pb文件转换为tflite文件

converter = tf.contrib.lite.TocoConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])

converter.post_training_quantize = True #开启量化

tflite_model = converter.convert()

open('quantize_model.tflite', 'wb').write(tflite_model)

3. 量化对模型精度的影响

3.1 量化介绍

在将.pb模型转换为tflite模型时,我们经常会使用量化来减小模型体积以及提高模型推理速度。量化是指将浮点数转换为定点数或者整型数,以减小模型的存储空间和计算量。

3.2 量化对模型精度的影响

量化是一种降低模型精度的方式,通常会对模型的预测精度产生一定的影响。在进行量化操作时,需要注意以下几点:

量化的精度:量化的精度越高,模型的预测精度越好,但是模型的体积和计算量也越大。

量化的范围:量化的范围越大,模型的预测精度越好,但是模型的体积和计算量也越大。

量化的方法:不同的量化方法会对模型的预测精度产生不同的影响,需要根据具体的场景选择合适的量化方法。

4. 解决方案

4.1 增加temperature参数

在进行量化操作时,我们可以通过设置temperature参数来减小量化对模型精度的影响。temperature参数通常取值为0.5或者0.6,可以在一定程度上减小量化造成的精度下降。

import tensorflow as tf

#将.pb文件转换为tflite文件

converter = tf.contrib.lite.TocoConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])

converter.post_training_quantize = True #开启量化

converter.inference_input_type = tf.int8 #设置输入数据类型为整型

converter.inference_output_type = tf.int8 #设置输出数据类型为整型

converter.quantized_input_stats = {'input_tensor': (0., 1.)} #设置输入数据的范围

converter.default_ranges_stats = (-6., 6.) #设置默认的量化范围

converter.inference_type = tf.float32 #设置推理的数据类型为浮点型

converter.quantization_steps = 256 #设置量化的精度为256

converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] #设置量化方法

converter.inference_type = tf.uint8 #设置量化后的数据类型为整型(uint8)

converter.inference_input_type = tf.uint8 #设置输入数据类型为整型

converter.inference_output_type = tf.uint8 #设置输出数据类型为整型

converter.experimental_new_converter = True #开启新的实验性量化过程

converter.change_concat_input_ranges = True #更改拼接输入的范围

converter.allow_custom_ops = True #允许使用自定义操作

converter.temperature = 0.6 # 设置转换的温度

tflite_model = converter.convert()

open('quantize_model.tflite', 'wb').write(tflite_model)

4.2 选择合适的量化方法

在进行量化操作时,选择合适的量化方法也可以减小量化对模型精度的影响。Tensorflow提供了多种量化方法,包括对权重进行量化、对激活函数进行量化和对权重和激活函数同时进行量化等。需要根据具体的场景选择合适的量化方法。

import tensorflow as tf

#将.pb文件转换为tflite文件

converter = tf.lite.TFLiteConverter.from_frozen_graph('frozen_model.pb', ['input_tensor'], ['output_tensor'])

converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] #选择量化方法为TFLITE_BUILTINS_INT8

tflite_quant_model = converter.convert()

open("quantize_model.tflite", "wb").write(tflite_quant_model)

5. 总结

在实际项目中,将模型从pb格式转换为tflite格式是一项常见的任务。在进行转换时,我们要注意量化的精度、量化的范围和量化的方法等问题,以避免量化对模型精度造成不必要的影响。如果在进行量化时出现了精度下降等问题,可以通过设置temperature参数和选择合适的量化方法来解决。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签