import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import (
    decode_predictions,
    preprocess_input,
)
from tensorflow.keras.preprocessing import image
from tensorflow.keras.utils import get_file

# Node names (without the ":0" tensor suffix) of the frozen graph's
# output and input.
output_names = ['Logits/Softmax']
input_names = ['input_1']

# NOTE(review): both URLs are empty in this copy of the script; they must be
# set to the TensorRT-optimized frozen graph (.pb) and a test image before
# running.
trt_model_url = ""
img_url = ""

if not trt_model_url or not img_url:
    # Fail early with a clear message instead of an obscure download error.
    raise ValueError("trt_model_url and img_url must be set to downloadable URLs.")

# Download (or reuse the cached copy of) the model and the test image.
trt_graph_file = get_file(os.path.basename(trt_model_url), trt_model_url,
                          cache_subdir='models')
img_file = get_file(os.path.basename(img_url), img_url,
                    cache_subdir='datasets')


def get_frozen_graph(graph_file):
    """Read a frozen graph file from disk.

    Args:
        graph_file: Path to a serialized ``GraphDef`` protobuf file.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    # GFile supersedes the deprecated FastGFile.
    with tf.gfile.GFile(graph_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


def _get_input_image_size(graph_def, input_prefix='input_'):
    """Return dims 1..3 of the shape attr of the first node named like an input.

    Dimension 0 is skipped (presumably the batch dimension). For a
    channels-last image input that yields [height, width, channels]; this
    is assumed, not checked.

    Raises:
        ValueError: if no node name contains ``input_prefix``.
    """
    for node in graph_def.node:
        if input_prefix in node.name:
            dims = node.attr['shape'].shape.dim
            return [dims[i].size for i in range(1, 4)]
    raise ValueError(
        "No node containing {!r} found in the graph.".format(input_prefix))


trt_graph = get_frozen_graph(trt_graph_file)

# Create a session with allow_growth enabled so GPU memory is allocated on
# demand, then import the graph into the default graph, which the session uses.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
tf_sess = tf.Session(config=tf_config)
tf.import_graph_def(trt_graph, name='')

# Get graph input size.
image_size = _get_input_image_size(trt_graph)
print("image_size: {}".format(image_size))

# Input and output tensor names.
input_tensor_name = input_names[0] + ":0"
output_tensor_name = output_names[0] + ":0"

print("input_tensor_name: {}\noutput_tensor_name: {}".format(
    input_tensor_name, output_tensor_name))

output_tensor = tf_sess.graph.get_tensor_by_name(output_tensor_name)

# Optional image to test model prediction.
# Load the test image at the model's input resolution and batch it.
img = image.load_img(img_file, target_size=image_size[:2])
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

feed_dict = {input_tensor_name: x}

preds = tf_sess.run(output_tensor, feed_dict)

# Decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch).
print('Predicted:', decode_predictions(preds, top=3)[0])

# Benchmark: time repeated runs on the same input. The prediction above
# acts as a warm-up, so one-time setup cost is excluded from these timings.
import time

num_runs = 20
times = []
for _ in range(num_runs):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    tf_sess.run(output_tensor, feed_dict)
    times.append(time.perf_counter() - start_time)

mean_delta = np.array(times).mean()
fps = 1 / mean_delta
# Four decimals so millisecond-scale latencies do not print as 0.00.
print('average(sec):{:.4f},fps:{:.2f}'.format(mean_delta, fps))