{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lab PyG\n", "\n", "### Andrea Passerini, Antonio Longa \n", "### andrea.passerini@unitn.it, antonio.longa@unitn.it" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Geometric (PyG)\n", "\n", "PyTorch Geometric (PyG) is a geometric deep learning extension library for PyTorch. It provides methods for deep learning on graphs and other irregular structures, implements many graph neural networks from the literature, and makes it easy to prototype new ones.\n", "\n", "Adapted from the tutorials and notebooks at https://github.com/rusty1s/pytorch_geometric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating Message Passing Networks\n", "\n", "\n", "![title](img/img1.png)\n", "\n", "\n", "Graph neural networks can be defined in terms of a *neighborhood aggregation* or *message passing* scheme.\n", "With $\\mathbf{x}_i^{(k-1)} \\in \\mathbb{R}^F$ denoting the features of node $i$ in layer $(k-1)$ and $\\mathbf{e}_{j,i} \\in \\mathbb{R}^D$ denoting the (optional) features of the edge from node $j$ to node $i$, message passing graph neural networks can be described as\n", "\n", "\n", "\n", "$$\n", " \\mathbf{x}_i^{(k)} = \\gamma^{(k)} \\left( \\mathbf{x}_i^{(k-1)}, \\square_{j \\in \\mathcal{N}(i)} \\, \\phi^{(k)}\\left(\\mathbf{x}_i^{(k-1)}, \\mathbf{x}_j^{(k-1)},\\mathbf{e}_{j,i}\\right) \\right),\n", "$$\n", "\n", "where $\\mathcal{N}(i)$ is the set of neighbors of node $i$, $\\square$ denotes a differentiable, permutation-invariant function, *e.g.*, sum, mean or max, and $\\gamma$ and $\\phi$ denote differentiable functions such as multi-layer perceptrons (MLPs)." 
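, "\n", "\n", "For example, one illustrative instance of this scheme uses sum aggregation and linear maps that ignore the edge features. This choice is an assumption made here for illustration, not a specific layer from this lab. With hypothetical weight matrices $\\mathbf{W}^{(k)}, \\mathbf{\\Theta}^{(k)} \\in \\mathbb{R}^{F' \\times F}$ and $\\mathbf{m}_i$ denoting the aggregated message of node $i$, set\n", "\n", "$$\n", " \\phi^{(k)}\\left(\\mathbf{x}_i^{(k-1)}, \\mathbf{x}_j^{(k-1)}, \\mathbf{e}_{j,i}\\right) = \\mathbf{W}^{(k)} \\mathbf{x}_j^{(k-1)}, \\qquad \\square = \\sum, \\qquad \\gamma^{(k)}\\left(\\mathbf{x}_i^{(k-1)}, \\mathbf{m}_i\\right) = \\mathrm{ReLU}\\left(\\mathbf{\\Theta}^{(k)} \\mathbf{x}_i^{(k-1)} + \\mathbf{m}_i\\right).\n", "$$\n", "\n", "Substituting into the general formula, each layer then computes\n", "\n", "$$\n", " \\mathbf{x}_i^{(k)} = \\mathrm{ReLU}\\left(\\mathbf{\\Theta}^{(k)} \\mathbf{x}_i^{(k-1)} + \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{W}^{(k)} \\mathbf{x}_j^{(k-1)}\\right).\n", "$$\n", "\n", "The sum does not depend on the order in which the neighbors $j \\in \\mathcal{N}(i)$ are enumerated, so the output for node $i$ is unaffected by how its neighborhood is listed."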
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The `MessagePassing` Base Class\n", "\n", "PyTorch Geometric provides the `MessagePassing` base class, which helps create such message passing graph neural networks by automatically taking care of message propagation.\n", "The user only has to define the functions $\\phi$, *i.e.* `message`, and $\\gamma$, *i.e.* `update`, as well as the aggregation scheme to use, *i.e.* `aggr=\"add\"`, `aggr=\"mean\"` or `aggr=\"max\"`.\n", "\n", "