Graph neural networks have revolutionized the performance of neural networks on graph data. Companies such as Pinterest[1], Google[2], and Uber[3] have implemented graph neural network algorithms to dramatically improve the performance of large-scale data-driven tasks.

Introduction to Graphs

A graph is any dataset that contains nodes and edges. Nodes are entities (e.g. people, organizations), and edges represent the connections between nodes. For example, a social network is a graph in which people in the network are considered nodes. Edges exist when two people are connected in some way (e.g. friends, sharing one’s posts).

In retail applications, customers and products can be viewed as nodes. An edge shows the relationship between a customer and a purchased product, so the graph can be used to represent the spending habits of each customer. Nodes can also have features or attributes. People have attributes such as age and height, and products have attributes such as price and size. Pinterest has used graph neural networks in this fashion to improve the performance of its recommendation system by 150%[1].
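As a rough illustration of the retail example, the sketch below stores a tiny customer-product graph as a dictionary of node attributes plus a list of edges. The customer names, product names, attribute values, and helper functions are all invented for this example and do not come from any real dataset or library.

```python
# Hypothetical retail graph: every name and attribute value here is invented.
nodes = {
    "alice": {"type": "customer", "age": 34},
    "bob": {"type": "customer", "age": 27},
    "lamp": {"type": "product", "price": 40.0},
    "desk": {"type": "product", "price": 180.0},
}

# Each edge links a customer to a product that customer purchased.
edges = [("alice", "lamp"), ("alice", "desk"), ("bob", "lamp")]


def build_adjacency(nodes, edges):
    """Return an undirected adjacency list: node name -> set of neighbor names."""
    adjacency = {name: set() for name in nodes}
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)
    return adjacency


adjacency = build_adjacency(nodes, edges)


def total_spend(customer):
    """Sum the prices of the products a customer is connected to."""
    return sum(nodes[product]["price"] for product in adjacency[customer])
```

Here `total_spend("alice")` walks Alice's edges and adds up the price attribute of each neighboring product, which is a simple way of reading spending habits off the graph.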

The Advent of Graph Neural Networks

Until the development of graph neural networks, deep learning methods could not be applied to edges to extract knowledge and make predictions, and instead could only operate on node features. Applying deep learning to graph data allows us to approach tasks such as link prediction, community detection, and recommendation generation.

Deep learning can also be applied to node classification, or predicting the label of an unlabelled node. This takes place in a semi-supervised setting, where the labels of some nodes are known, but others are unknown. A survey of deep learning node classification methods shows a history of advances in state-of-the-art performance while illustrating the range of use cases and applications.

The methods discussed in this blog post were evaluated on the benchmark CoRA dataset. CoRA consists of publications in machine learning. Each publication is a node, and an edge exists between two nodes if one of the papers cites the other. The dataset comprises 2708 publications with 5429 edges.

DeepWalk

DeepWalk[4], released in 2014, was the first significant deep learning-based method to approach node classification. DeepWalk’s approach was similar to the one taken in natural language processing (NLP) to derive embeddings. An embedding is a vector representation of an object, such as a word in NLP or a node in a graph. To create its embeddings, DeepWalk takes truncated random walks over the graph and uses them to learn latent representations of nodes. On the CoRA dataset, DeepWalk achieved 67.2% accuracy on the benchmark node classification experiment[5]. At the time, this was 10% better than competing methods while using 60% less training data.
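The random-walk step can be sketched in a few lines of plain Python. This is a minimal illustration under assumed inputs, not the authors' implementation: the toy graph, the function names, and the parameter values are all hypothetical. In the DeepWalk paper the resulting walks are then treated like sentences and passed to a skip-gram model to learn the embeddings, and that training step is omitted here.

```python
import random


def truncated_random_walk(adjacency, start, walk_length, rng):
    """Visit up to walk_length nodes, moving to a uniformly random neighbor each step."""
    walk = [start]
    while len(walk) < walk_length:
        neighbors = adjacency[walk[-1]]
        if not neighbors:
            break  # dead end: stop the walk early
        walk.append(rng.choice(neighbors))
    return walk


def generate_walks(adjacency, walks_per_node, walk_length, seed=0):
    """Start walks_per_node walks from every node, shuffling the node order on each pass."""
    rng = random.Random(seed)
    nodes = list(adjacency)
    walks = []
    for _ in range(walks_per_node):
        rng.shuffle(nodes)
        for node in nodes:
            walks.append(truncated_random_walk(adjacency, node, walk_length, rng))
    return walks


# Hypothetical toy graph stored as node -> list of neighbors.
toy_graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
walks = generate_walks(toy_graph, walks_per_node=2, walk_length=5)
```

Each walk is a short sequence of node ids in which consecutive entries are always connected by an edge, so nodes that sit close together in the graph tend to appear in the same walks.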

To show that the embeddings generated by DeepWalk capture information about graph structure, the figure below applies DeepWalk to Zachary’s Karate Network. The Karate Network is a small network of connections among members of a karate club, where an edge joins two members if they interact outside of karate. Nodes are colored according to the sub-communities within the club. The left image is the input graph of the social network, and the right image is the two-dimensional output generated by the DeepWalk algorithm operating on the graph data.

[Figure: Zachary’s Karate Network input graph (left) and the two-dimensional DeepWalk embedding of its nodes (right)]

Graph Convolutional Networks

In 2016, Thomas N. Kipf and Max Welling introduced graph convolutional networks (GCNs)[6], which improved the state-of-the-art CoRA benchmark to 81.5%. A GCN is a network that consists of stacked linear layers with an activation function. Kipf and Welling introduced a new propagation function that operates layer-wise and works directly on graph data.

The number of layers in a GCN determines the size of the neighborhood around the target node that is considered when making the classification prediction. For example, a single graph convolutional layer implies that the network only examines a node’s immediate neighbors when making a classification decision, while each additional layer extends the reach by one more hop.
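To make the link between depth and neighborhood size concrete, the sketch below computes the set of nodes within k hops of a target node, which is the region that a stack of k graph convolutional layers can draw on. The function name and the toy path graph are assumptions made for illustration.

```python
def k_hop_neighborhood(adjacency, target, k):
    """Return every node reachable from target in at most k hops, including target."""
    visited = {target}
    frontier = {target}
    for _ in range(k):
        # Expand one hop outward, keeping only nodes that have not been seen yet.
        frontier = {n for node in frontier for n in adjacency[node]} - visited
        visited |= frontier
    return visited


# Hypothetical path graph 0 - 1 - 2 - 3 - 4, stored as node -> list of neighbors.
path_graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

one_hop = k_hop_neighborhood(path_graph, 0, 1)
two_hop = k_hop_neighborhood(path_graph, 0, 2)
```

On this path graph the neighborhood of node 0 grows by exactly one node per hop, mirroring how each added layer widens what the network can see.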

A graph convolutional network takes two inputs. The first is an adjacency matrix, which represents the graph itself. The second is the feature vector of each node. These features can be as simple as a one-hot encoding of each node’s attributes, while richer representations can be used to capture more complex node features.
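The sketch below shows how these two inputs might flow through a single GCN layer, using the propagation rule from Kipf and Welling: add self-loops to the adjacency matrix, normalize it symmetrically by node degree, multiply by the feature matrix and a weight matrix, and apply a ReLU. It is written in plain Python for readability rather than speed. The three-node graph, the one-hot features, and the hand-picked weights are assumptions for illustration, and a real model would learn the weights during training.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalize_adjacency(adj):
    """Compute D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I."""
    n = len(adj)
    with_self_loops = [
        [adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)
    ]
    degree = [sum(row) for row in with_self_loops]
    return [
        [with_self_loops[i][j] / math.sqrt(degree[i] * degree[j]) for j in range(n)]
        for i in range(n)
    ]


def gcn_layer(adj_norm, features, weights):
    """One propagation step: ReLU(adj_norm @ features @ weights)."""
    propagated = matmul(matmul(adj_norm, features), weights)
    return [[max(0.0, value) for value in row] for row in propagated]


# Assumed toy inputs: path graph 0 - 1 - 2, one-hot node features, fixed weights.
adjacency_matrix = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
weights = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]

adj_norm = normalize_adjacency(adjacency_matrix)
hidden = gcn_layer(adj_norm, features, weights)
```

Each row of `hidden` is the new two-dimensional representation of one node, built from its own features and those of its immediate neighbors.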

Graph-BERT

GCNs were the leading architecture for years, and many variations of them were subsequently released. Then, in January 2020, Graph-BERT[7] removed the dependency on links and changed the way that graph networks are usually represented. This change is important for scalability, and Graph-BERT also showed improved accuracy and efficiency relative to other types of graph neural networks. We are currently exploring how Graph-BERT will impact the use cases we have been addressing with graph neural networks.

Conclusion

Graph neural networks are an evolving field in the study of neural networks. Their ability to use graph data has made difficult problems such as node classification more tractable.