Graph processing is an important part of data analysis in many domains. But it can be tricky, because general-purpose distributed computing tools are not well suited to graph processing.

It is not surprising that an important advance in distributed graph processing came from Google, which has to process one of the biggest graphs there is: the web graph. Google engineers wrote a seminal paper describing a new system for distributed graph processing, which they called Pregel.

In this article, I will explain how Pregel works and demonstrate how to implement Pregel algorithms using the API from Apache Flink.

If you are not familiar with Gelly, the graph API in Apache Flink, you can read about it in my previous article.

How Pregel works

The basic idea of Pregel is that we implement an algorithm that is executed on every vertex of a graph. The algorithm works in iterations: on every iteration it processes the messages sent to a vertex, and it can update the vertex's value and send messages to other vertices.

Pregel stops executing the algorithm when no vertex sends a message during an iteration.

To implement an algorithm using Pregel, we need to implement two functions:

  • Compute function – this function is executed once on every vertex on every iteration. It receives all messages sent to the vertex by other vertices and can optionally send messages to other vertices or update the vertex value. Messages sent by this function will be received on the next iteration.
  • Combiner function – this function receives all messages sent to a particular vertex and combines them into fewer messages. It is used to reduce the amount of data that has to be processed on every iteration.
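
To make the iteration model concrete, here is a minimal single-threaded sketch in plain Java. It is not Flink code and not how Pregel is implemented internally; the class and method names are made up for illustration. Every vertex keeps the largest value it has seen, messages sent in one iteration are only visible in the next one, and the loop stops as soon as no vertex sends a message.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy model of Pregel iterations (hypothetical names, no Flink involved).
// Every vertex ends up holding the largest value reachable along the edges.
class MaxValuePregelSketch {

    /** edges[v] lists the vertices that vertex v sends its messages to. */
    static int[] run(int[] initialValues, int[][] edges) {
        int[] values = initialValues.clone();
        Map<Integer, List<Integer>> inbox = new HashMap<>();
        int superstep = 1;

        while (true) {
            Map<Integer, List<Integer>> outbox = new HashMap<>();

            // "Compute function": runs once per vertex on every iteration.
            for (int v = 0; v < values.length; v++) {
                int best = values[v];
                for (int message : inbox.getOrDefault(v, Collections.emptyList())) {
                    best = Math.max(best, message);
                }
                // Send on the first iteration, and later only if the value improved.
                if (superstep == 1 || best > values[v]) {
                    values[v] = best;
                    for (int target : edges[v]) {
                        outbox.computeIfAbsent(target, k -> new ArrayList<>()).add(best);
                    }
                }
            }

            // Termination rule: stop when no vertex sent a message.
            if (outbox.isEmpty()) {
                return values;
            }
            inbox = outbox; // messages become visible on the next iteration
            superstep++;
        }
    }
}
```

For example, on a ring of four vertices with values 3, 6, 2 and 1, calling run(new int[]{3, 6, 2, 1}, new int[][]{{1}, {2}, {3}, {0}}) leaves every vertex with the value 6.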

Now let’s take a look at how we can implement algorithms in Pregel using Flink’s API.

Pregel algorithm in Flink

The main class that we need to implement is called ComputeFunction. Here are the main methods of this class that we will use in this article.

public abstract class ComputeFunction<K, VV, EV, Message> implements Serializable {
    public abstract void compute(Vertex<K, VV> vertex, MessageIterator<Message> messages) throws Exception;

    public final Iterable<Edge<K, EV>> getEdges() {...}

    public final void sendMessageToAllNeighbors(Message m) {...}

    public final void sendMessageTo(K target, Message m) {...}

    public final int getSuperstepNumber() {...}

    public final void setNewVertexValue(VV newValue) {...}
}

The only method that we should implement is the compute method. It receives a vertex on which it should operate and an iterator of messages that were sent to this vertex during the previous iteration.

The compute method does not return anything and should only change the state of the graph through other methods of this class. The three most important methods here are sendMessageToAllNeighbors, which sends a message to all neighbour vertices, sendMessageTo, which sends a message to a single vertex, and setNewVertexValue, which updates the value of the current vertex. The getEdges method returns all outgoing edges of the current vertex.

The last method that we will cover here is getSuperstepNumber, which simply returns the number of the current Pregel iteration (called a superstep). We will use it later to send an initial batch of messages on the first iteration, which starts the graph processing.

Gelly also provides the MessageCombiner class, which we need to extend to implement custom combiners. In this class, we only need to implement the combineMessages method, which receives the messages that were sent to a particular target vertex. To define which messages should be delivered, we call sendCombinedMessage to send the combined message to the target vertex.

public abstract class MessageCombiner<K, Message> implements Serializable {
    ...
    public abstract void combineMessages(MessageIterator<Message> messages) throws Exception;

    public final void sendCombinedMessage(Message combinedMessage) {...}
}

Compute functions and message combiners are not supposed to be used directly by users of our algorithm. Instead, in Flink we need to implement the GraphAlgorithm interface, whose run method receives a Graph instance, processes it, and returns the algorithm's output.

public interface GraphAlgorithm<K, VV, EV, T> {
	T run(Graph<K, VV, EV> input) throws Exception;
}

This interface is used because algorithms based on Pregel usually also need some pre- and post-processing of the graph, and this logic is placed in implementations of the GraphAlgorithm interface.

Implement Single Source Shortest Path Algorithm

Now we will apply our knowledge to implement a Single Source Shortest Path algorithm using Pregel. This algorithm receives a source vertex in a graph and calculates the length of the shortest path from it to every other vertex in the graph.

Algorithm design

Before we start implementing the algorithm, let's discuss how it is going to work. Let's say we have a graph with the following intermediate state (every vertex contains the length of the shortest path to it found so far, and every edge has a length associated with it):

Initial graph state

On every iteration, a vertex receives messages from other vertices, each carrying the length of a newly found path to it:

Processing input messages

A vertex finds the minimum of the values received on this iteration and, if it is lower than the shortest path recorded so far, sends messages to its neighbours containing the new shortest path lengths to these vertices:

Sending output messages

In a nutshell, the algorithm sends messages whenever it discovers a new, shorter path to one of the vertices. When all shortest paths have been calculated, no more messages are sent, and the algorithm terminates.
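
The message flow described above can be traced with a small plain-Java sketch. This is only an illustration under simplifying assumptions: it runs in a single thread, does not use Flink, and the class name, method names, and graph encoding are made up. Vertices that receive no messages are skipped, which has the same effect as running the compute step on them with nothing to process.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Plain-Java walk-through of the shortest path message flow (hypothetical names).
class ShortestPathTrace {

    /** Each edge is {source, target, length}; vertex ids are 0..vertexCount-1. */
    static double[] shortestPaths(int vertexCount, double[][] edges, int source) {
        double[] distance = new double[vertexCount];
        Arrays.fill(distance, Double.MAX_VALUE);
        distance[source] = 0;

        // Messages waiting to be processed, keyed by target vertex.
        Map<Integer, List<Double>> inbox = new HashMap<>();
        send(inbox, edges, source, 0); // the source vertex starts the computation

        // One loop pass corresponds to one iteration; stop when no messages were sent.
        while (!inbox.isEmpty()) {
            Map<Integer, List<Double>> outbox = new HashMap<>();
            for (Map.Entry<Integer, List<Double>> entry : inbox.entrySet()) {
                int vertex = entry.getKey();
                double min = Collections.min(entry.getValue());
                if (min < distance[vertex]) {
                    distance[vertex] = min;           // record the shorter path
                    send(outbox, edges, vertex, min); // and tell the neighbours
                }
            }
            inbox = outbox;
        }
        return distance;
    }

    // Sends "distance to 'from' plus edge length" along every outgoing edge of 'from'.
    private static void send(Map<Integer, List<Double>> outbox, double[][] edges,
                             int from, double distanceToFrom) {
        for (double[] edge : edges) {
            if ((int) edge[0] == from) {
                outbox.computeIfAbsent((int) edge[1], k -> new ArrayList<>())
                      .add(distanceToFrom + edge[2]);
            }
        }
    }
}
```

With the edges 0→1 (length 1), 0→2 (length 5), 1→2 (length 2) and 2→3 (length 1), vertex 2 first records a distance of 5 and then improves it to 3 on the next iteration, once the shorter path through vertex 1 arrives. A vertex that is never reached keeps Double.MAX_VALUE.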

Algorithm implementation

Let’s start with the main part of the algorithm: the implementation of the ComputeFunction. Its implementation closely follows the general structure presented in the previous section. On every iteration, the function receives messages containing new distances to the vertex through its neighbours and computes the minimum of them. If this minimum is less than the minimum recorded so far, the function updates the shortest path using setNewVertexValue and sends the new minimum distance to all its neighbours.

class ShortestPathComputeFunction<K> extends ComputeFunction<K, Double, Double, NewMinDistance> {

    private final K sourceVertex;

    public ShortestPathComputeFunction(K sourceVertex) {
        this.sourceVertex = sourceVertex;
    }

    @Override
    public void compute(Vertex<K, Double> vertex, MessageIterator<NewMinDistance> messageIterator) throws Exception {
        // Send initial group of messages from the source vertex
        if (vertex.getId().equals(sourceVertex) && getSuperstepNumber() == 1) {
            sendNewDistanceToAll(0);
        }

        // Calculate new min distance from source node
        double minDistance = minDistance(messageIterator);

        // Send new min distance to neighbour vertices if new min distance is less
        if (minDistance < vertex.getValue()) {
            setNewVertexValue(minDistance);
            sendNewDistanceToAll(minDistance);
        }
    }

    private double minDistance(MessageIterator<NewMinDistance> messageIterator) {
        double minDistance = Double.MAX_VALUE;
        for (NewMinDistance message : messageIterator) {
            minDistance = Math.min(message.getDistance(), minDistance);
        }
        return minDistance;
    }

    private void sendNewDistanceToAll(double newDistance) {
        for (Edge<K, Double> edge : getEdges()) {
            sendMessageTo(edge.getTarget(), new NewMinDistance(edge.getValue() + newDistance));
        }
    }
}

To transfer distance values between vertices we use a custom NewMinDistance class. It is a simple POJO class with a single field distance.

class NewMinDistance {
    private final double distance;

    public NewMinDistance(double distance) {
        this.distance = distance;
    }

    public double getDistance() {
        return distance;
    }
}

All that is left to do is to implement a combiner that reduces the number of messages delivered to every vertex. During one iteration, several vertices can each find a new path to the same neighbour, and this generates multiple messages for the same target vertex.

To somewhat optimize this, we can provide a combiner that finds the shortest distance sent to each target vertex and combines all messages for it into a single one:

class ShortestPathCombiner<K> extends MessageCombiner<K, NewMinDistance> {
    @Override
    public void combineMessages(MessageIterator<NewMinDistance> messageIterator) throws Exception {
        double minDistance = Double.MAX_VALUE;
        for (NewMinDistance message : messageIterator) {
            minDistance = Math.min(message.getDistance(), minDistance);
        }

        sendCombinedMessage(new NewMinDistance(minDistance));
    }
}
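
As a rough illustration of what the combiner achieves, the following plain-Java sketch groups a batch of messages by target vertex and keeps only the smallest distance for each target. It does not use the Flink API; the class name and the message encoding are assumptions made for this example.

```java
import java.util.HashMap;
import java.util.Map;

// Illustration of min-combining outside of Flink (hypothetical names).
class MinCombinerSketch {

    /** Each message is {targetVertexId, distance}; returns one distance per target. */
    static Map<Integer, Double> combine(double[][] messages) {
        Map<Integer, Double> combined = new HashMap<>();
        for (double[] message : messages) {
            // Keep the smaller of the stored distance and the new one.
            combined.merge((int) message[0], message[1], Double::min);
        }
        return combined;
    }
}
```

Three messages, {2, 5.0}, {2, 3.0} and {3, 6.0}, are reduced to two: vertex 2 receives 3.0 and vertex 3 receives 6.0. Dropping the larger distance is safe here because the compute function only looks at the minimum anyway.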

This is all we need to implement an algorithm using the Pregel framework, but to make our algorithm more convenient to use, we should wrap it in an implementation of the GraphAlgorithm interface.

class ShortestPath<K, VV> implements GraphAlgorithm<K, VV, Double, DataSet<Vertex<K, Double>>> {

    private final K sourceVertex;
    private final int maxIterations;

    public ShortestPath(K sourceVertex, int maxIterations) {
        this.sourceVertex = sourceVertex;
        this.maxIterations = maxIterations;
    }

    @Override
    public DataSet<Vertex<K, Double>> run(Graph<K, VV, Double> graph) throws Exception {
        // Set the initial distances, then run the Pregel iterations
        Graph<K, Double, Double> resultGraph = graph.mapVertices(new ShortestPathInit<>(sourceVertex))
        .runVertexCentricIteration(new ShortestPathComputeFunction<>(sourceVertex),
                                   new ShortestPathCombiner<>(),
                                   maxIterations);
        return resultGraph.getVertices();
    }
    ...
}

First we need to set the initial path length for every vertex in the graph, which is done using the mapVertices method. The runVertexCentricIteration method then runs our algorithm on the input graph and produces a graph where vertex values contain the shortest path lengths from the sourceVertex. All that we need to do after this is to return a dataset of vertices, which contains vertex ids and path lengths.

Initializing the state of the graph is pretty straightforward. For the source vertex we set the value to 0, and for any other vertex we set the maximum Double value:

private static class ShortestPathInit<K, VV> implements MapFunction<Vertex<K,VV>, Double> {
    private final K sourceVertex;

    public ShortestPathInit(K sourceVertex) {
        this.sourceVertex = sourceVertex;
    }

    @Override
    public Double map(Vertex<K, VV> vertex) throws Exception {
        if (vertex.getId().equals(sourceVertex)) {
            return 0d;
        }
        return Double.MAX_VALUE;
    }
}

This is it! Now we can apply our algorithm to a Graph instance. To do this, we need to call the graph's run method and pass an instance of our algorithm to it:

Graph<Integer, String, Double> graph = ...

int sourceVertex = 1;
int maxIterations = 10;
graph.run(new ShortestPath<>(sourceVertex, maxIterations)).print();

You can find the full source code of the algorithm from this article in my GitHub repository, along with other Flink example projects.

More information

If you want to know more about Apache Flink, you can take a look at my Pluralsight course, where I cover Apache Flink in more detail: Understanding Apache Flink

Posted by Ivan Mushketyk

Principal Software engineer and life-long learner. Creating courses for Pluralsight. Writing for DZone, SitePoint, and SimpleProgrammer.