TF session in predict()
When doing inferences with TensorFlow using the Realtime API Python Predictor or Batch API Python Predictor, it should be noted that your Python Predictor's __init__() constructor is only called on one thread, whereas its predict() method can run on any of the available threads (which is configured via the threads_per_process field in the API's predictor configuration). If threads_per_process is set to 1 (the default value), then there is no concern, since __init__() and predict() will run on the same thread. However, if threads_per_process is greater than 1, then only one of the inference threads will have executed the __init__() function. This can cause issues with TensorFlow because the default graph is a property of the current thread, so if __init__() initializes the TensorFlow graph, only the thread that executed __init__() will have the default graph set.
The error you may see if the default graph is not set (as a consequence of __init__() and predict() running in separate threads) is:
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(1, ?), dtype=int32) is not an element of this graph.
To avoid this error, you can set the default graph before running the prediction in the predict() method:
def predict(self, payload):
with self.sess.graph.as_default():
# perform your inference here
Last modified 10mo ago
Copy link