Using TensorFlow with Apache Spark

1. Why Apache Spark with TensorFlow?

Deep neural networks have become the first choice for solving machine learning problems on unstructured data, and in recent years many libraries and frameworks have appeared for building them, such as Caffe, TensorFlow, Keras and BigDL. In this article, we will cover how to use TensorFlow with Apache Spark.

First of all, we need to understand why we need Spark here. Let's start by assuming that you are a developer on a large-scale machine learning application, and that your stakeholders have no clue about Python programming but, luckily, know how to write SQL queries.

Libraries like TensorFlow have a low-level interface, which makes them somewhat difficult to work with. You can certainly do it, but it will take your precious time. Imagine that you have to train your model on a huge dataset and you want to run hyperparameter tuning in a distributed environment to find a better model. The bad news is that TensorFlow does not, by itself, spread such a search over a cluster, and that is where Spark comes to help. Because each hyperparameter configuration can be trained independently, Spark can run the trials fully in parallel across the cluster and keep the model with the lowest error rate.
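
Hyperparameter search parallelizes well because every candidate configuration is trained and scored independently; only the final comparison needs all the results. The sketch below shows that shape under some assumptions: to stay runnable without a cluster it uses Python's standard concurrent.futures thread pool in place of Spark executors, and train_and_score is a hypothetical stand-in that returns a made-up error instead of training a real TensorFlow model.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def train_and_score(params):
    """Hypothetical stand-in for training one model and returning its error.

    A real version would build and fit a TensorFlow model with these
    hyperparameters; this toy error is lowest at (0.01, 128).
    """
    learning_rate, hidden_units = params
    error = abs(learning_rate - 0.01) * 10 + abs(hidden_units - 128) / 1000
    return params, error


def grid_search(learning_rates, hidden_units_options, max_workers=4):
    """Score every hyperparameter combination in parallel; return the best."""
    grid = list(product(learning_rates, hidden_units_options))
    # Trials are independent, so they can run concurrently.
    # A thread pool stands in for Spark executors in this sketch.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(train_and_score, grid))
    # Keep the (params, error) pair with the lowest error.
    return min(results, key=lambda result: result[1])


if __name__ == "__main__":
    best_params, best_error = grid_search([0.001, 0.01, 0.1], [64, 128, 256])
    print("best params:", best_params, "error:", best_error)
```

With PySpark, the parallel step would roughly become sc.parallelize(grid).map(train_and_score).collect(), with the final min() still running on the driver.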


This makes the result more reliable and reduces the time spent training a good model. And of course this is not limited to the training phase: you will be able to use all the power of Apache Spark.

Using SQL is just one of those capabilities. Once you have successfully created your model, you can import it into Apache Spark with just a few lines of code. Assume that we trained a model that performs simple classification over images and identifies whether an image contains a car. It could then be called directly from SQL:

SELECT image, is_a_car_model(image) as probability
    FROM image_examples
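
Behind a query like this, the model is wrapped in an ordinary function that is registered under a SQL name. In PySpark that registration would look roughly like spark.udf.register("is_a_car_model", car_probability, DoubleType()). The sketch below shows the same idea with Python's built-in sqlite3 module so that it runs without a cluster; car_probability is a hypothetical stand-in for model inference, and the image_examples rows are made-up byte strings rather than real images.

```python
import sqlite3


def car_probability(image_bytes):
    """Hypothetical stand-in for a trained car classifier.

    A real UDF would decode the image and run the TensorFlow graph;
    this toy version maps the payload length to a score in [0, 1].
    """
    if image_bytes is None:
        return None
    return min(len(image_bytes) / 100.0, 1.0)


conn = sqlite3.connect(":memory:")
# Expose the Python function to SQL as is_a_car_model, taking one argument.
conn.create_function("is_a_car_model", 1, car_probability)

conn.execute("CREATE TABLE image_examples (image BLOB)")
conn.executemany(
    "INSERT INTO image_examples (image) VALUES (?)",
    [(b"x" * 30,), (b"x" * 250,)],
)

# Same query shape as above; ORDER BY makes the row order deterministic.
rows = conn.execute(
    "SELECT image, is_a_car_model(image) AS probability "
    "FROM image_examples ORDER BY rowid"
).fetchall()
for image, probability in rows:
    print(len(image), probability)
```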

2. Restoring a TensorFlow Graph From a Protobuf File

The stakeholders who know SQL are now happy: they can easily use the power of deep neural networks just by writing a SQL query.

In the rest of this article, we will give a quick demonstration of how to import a pre-trained TensorFlow graph into Spark.

First, start Spark with the pyspark shell, making sure the spark-deep-learning (sparkdl) package is available, then run:

import tensorflow as tf
from tensorflow.python.platform import gfile
from sparkdl import readImages

image_df = readImages('/tmp/panda_photos')
graph_path = "/tmp/imagenet/classify_image_graph_def.pb"

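# Read the serialized GraphDef (protobuf) from disk
with gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the GraphDef into a new TensorFlow graph
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')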

3. Apply the TensorFlow Graph to an Apache Spark DataFrame

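We imported our graph_def into TensorFlow by reading it from the filesystem. Now it is time to apply a graph to an Apache Spark DataFrame. In this example, we add a step that resizes each image to 299×299 (the input size of Inception v3), freeze the graph, and run it over the image DataFrame with spark-deep-learning's TFImageTransformer.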

from sparkdl import TFImageTransformer
from sparkdl.transformers import utils

# Add the new ops to the graph restored above, not to the default graph
with graph.as_default():
    # Placeholder for a batch of input images
    image_arr = utils.imageInputPlaceholder()
    resized_images = tf.image.resize_images(image_arr, (299, 299))
    # Freezing turns variables into constants; if your graph has no variables, this step is not necessary
    frozen_graph = utils.stripAndFreezeGraph(graph.as_graph_def(add_shapes=True), tf.Session(graph=graph), [resized_images])

transformer = TFImageTransformer(inputCol="image", outputCol="transformed_img", graph=frozen_graph,
                                 inputTensor=image_arr, outputTensor=resized_images,
                                 outputMode="image")

tf_trans_df = transformer.transform(image_df)

Finally, we apply our TensorFlow graph to the image DataFrame on Spark. The result, tf_trans_df, is an ordinary Spark DataFrame, so we can query it with Spark SQL or analyze the results further.


With the help of spark-deep-learning, it is easy to integrate Apache Spark with deep learning libraries such as TensorFlow and Keras, and to combine the power of Apache Spark with DNNs/CNNs. I think this gives a lot of flexibility to companies that already have large-scale applications and want to add DNNs/CNNs to their technology stack.

Since Apache Spark has a very large community, it will be easy to find help with the problems you run into. I would recommend trying it, at least if you have large-scale projects.
