Unverified commit 901a7441 authored by Varuna Jayasiri, committed by GitHub

GAT (#67)

Parent af6b99a5
This diff is collapsed.
This diff is collapsed.
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Graph Attention Networks (GAT)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/graphs/gat/readme.html"/>
<meta property="og:title" content="Graph Attention Networks (GAT)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Graph Attention Networks (GAT)"/>
<meta property="og:description" content=""/>
<title>Graph Attention Networks (GAT)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/graphs/gat/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">graphs</a>
<a class="parent" href="index.html">gat</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/graphs/gat/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/graph/gat/index.html">Graph Attention Networks (GAT)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://arxiv.org/abs/1710.10903">Graph Attention Networks</a>.</p>
<p>GATs work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.</p>
<p>GAT uses masked self-attention, similar to <a href="https://nn.labml.ai/transformers/mha.html">transformers</a>.
GAT consists of graph attention layers stacked on top of each other.
Each graph attention layer gets node embeddings as inputs and outputs transformed embeddings.
Each node embedding attends to the embeddings of the other nodes it&rsquo;s connected to.
The details of graph attention layers are included alongside the implementation.</p>
<p>Here is <a href="https://nn.labml.ai/graph/gat/experiment.html">the training code</a> for training
a two-layer GAT on the Cora dataset.</p>
<p><a href="https://app.labml.ai/run/d6c636cadf3511eba2f1e707f612f95d"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials related to graph neural networks"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Graph Neural Networks"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials related to graph neural networks"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/graphs/index.html"/>
<meta property="og:title" content="Graph Neural Networks"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Graph Neural Networks"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials related to graph neural networks"/>
<title>Graph Neural Networks</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/graphs/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">graphs</a>
</p>
<p>
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/graphs/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Graph Neural Networks</h1>
<ul>
<li><a href="gat/index.html">Graph Attention Networks (GAT)</a></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>
\ No newline at end of file
@@ -110,6 +110,10 @@ implementations.</p>
<li><a href="gan/stylegan/index.html">StyleGAN 2</a></li>
</ul>
<h4><a href="sketch_rnn/index.html">Sketch RNN</a></h4>
<h4>✨ Graph Neural Networks</h4>
<ul>
<li><a href="graphs/gat/index.html">Graph Attention Networks (GAT)</a></li>
</ul>
<h4><a href="cfr/index.html">Counterfactual Regret Minimization (CFR)</a></h4>
<p>Solving games with incomplete information such as poker with CFR.</p>
<ul>
......
@@ -190,7 +190,7 @@
<url>
<loc>https://nn.labml.ai/normalization/weight_standardization/index.html</loc>
<lastmod>2021-04-28T16:30:00+00:00</lastmod>
<lastmod>2021-07-04T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@@ -239,7 +239,7 @@
<url>
<loc>https://nn.labml.ai/normalization/batch_channel_norm/index.html</loc>
<lastmod>2021-04-28T16:30:00+00:00</lastmod>
<lastmod>2021-07-04T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
@@ -818,6 +818,27 @@
</url>
<url>
<loc>https://nn.labml.ai/graphs/gat/index.html</loc>
<lastmod>2021-07-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/graphs/gat/experiment.html</loc>
<lastmod>2021-07-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/graphs/index.html</loc>
<lastmod>2021-07-07T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/sketch_rnn/index.html</loc>
<lastmod>2021-03-04T16:30:00+00:00</lastmod>
......
@@ -50,6 +50,10 @@ implementations.
#### ✨ [Sketch RNN](sketch_rnn/index.html)
#### ✨ Graph Neural Networks
* [Graph Attention Networks (GAT)](graphs/gat/index.html)
#### ✨ [Counterfactual Regret Minimization (CFR)](cfr/index.html)
Solving games with incomplete information such as poker with CFR.
......
"""
---
title: Graph Neural Networks
summary: >
A set of PyTorch implementations/tutorials related to graph neural networks
---
# Graph Neural Networks
* [Graph Attention Networks (GAT)](gat/index.html)
"""
"""
---
title: Graph Attention Networks (GAT)
summary: >
A PyTorch implementation/tutorial of Graph Attention Networks.
---
# Graph Attention Networks (GAT)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Graph Attention Networks](https://arxiv.org/abs/1710.10903).
GATs work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.
GAT uses masked self-attention, similar to [transformers](../../transformers/mha.html).
GAT consists of graph attention layers stacked on top of each other.
Each graph attention layer gets node embeddings as inputs and outputs transformed embeddings.
Each node embedding attends to the embeddings of the other nodes it's connected to.
The details of graph attention layers are included alongside the implementation.
Here is [the training code](experiment.html) for training
a two-layer GAT on the Cora dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/d6c636cadf3511eba2f1e707f612f95d)
"""
import torch
from torch import nn
from labml_helpers.module import Module
class GraphAttentionLayer(Module):
"""
## Graph attention layer
This is a single graph attention layer.
A GAT is made up of multiple such layers.
It takes
$$\mathbf{h} = \{ \overrightarrow{h_1}, \overrightarrow{h_2}, \dots, \overrightarrow{h_N} \}$$,
where $\overrightarrow{h_i} \in \mathbb{R}^F$ as input
and outputs
$$\mathbf{h'} = \{ \overrightarrow{h'_1}, \overrightarrow{h'_2}, \dots, \overrightarrow{h'_N} \}$$,
where $\overrightarrow{h'_i} \in \mathbb{R}^{F'}$.
"""
def __init__(self, in_features: int, out_features: int, n_heads: int,
is_concat: bool = True,
dropout: float = 0.6,
leaky_relu_negative_slope: float = 0.2):
"""
* `in_features`, $F$, is the number of input features per node
* `out_features`, $F'$, is the number of output features per node
* `n_heads`, $K$, is the number of attention heads
* `is_concat` whether the multi-head results should be concatenated or averaged
* `dropout` is the dropout probability
* `leaky_relu_negative_slope` is the negative slope for leaky relu activation
"""
super().__init__()
self.is_concat = is_concat
self.n_heads = n_heads
# Calculate the number of dimensions per head
if is_concat:
assert out_features % n_heads == 0
# If we are concatenating the multiple heads
self.n_hidden = out_features // n_heads
else:
# If we are averaging the multiple heads
self.n_hidden = out_features
# Linear layer for initial transformation;
# i.e. to transform the node embeddings before self-attention
self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
# Linear layer to compute attention score $e_{ij}$
self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
# The activation for attention score $e_{ij}$
self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
# Softmax to compute attention $\alpha_{ij}$
self.softmax = nn.Softmax(dim=1)
# Dropout layer to be applied for attention
self.dropout = nn.Dropout(dropout)
def __call__(self, h: torch.Tensor, adj_mat: torch.Tensor):
"""
* `h`, $\mathbf{h}$ is the input node embeddings of shape `[n_nodes, in_features]`.
* `adj_mat` is the adjacency matrix of shape `[n_nodes, n_nodes, n_heads]`.
We use shape `[n_nodes, n_nodes, 1]` since the adjacency is the same for each head.
The adjacency matrix represents the edges (or connections) among nodes.
`adj_mat[i][j]` is `True` if there is an edge from node `i` to node `j`.
"""
# Number of nodes
n_nodes = h.shape[0]
# The initial transformation,
# $$\overrightarrow{g^k_i} = \mathbf{W}^k \overrightarrow{h_i}$$
# for each head.
# We do a single linear transformation and then split it up for each head.
g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
# #### Calculate attention score
#
# We calculate these for each head $k$. *We have omitted $\cdot^k$ for simplicity*.
#
# $$e_{ij} = a(\mathbf{W} \overrightarrow{h_i}, \mathbf{W} \overrightarrow{h_j}) =
# a(\overrightarrow{g_i}, \overrightarrow{g_j})$$
#
# $e_{ij}$ is the attention score (importance) from node $j$ to node $i$.
# We calculate this for each head.
#
# $a$ is the attention mechanism, that calculates the attention score.
# The paper concatenates
# $\overrightarrow{g_i}$, $\overrightarrow{g_j}$
# and does a linear transformation with a weight vector $\mathbf{a} \in \mathbb{R}^{2 F'}$
# followed by a $\text{LeakyReLU}$.
#
# $$e_{ij} = \text{LeakyReLU} \Big(
# \mathbf{a}^\top \Big[
# \overrightarrow{g_i} \Vert \overrightarrow{g_j}
# \Big] \Big)$$
# First we calculate
# $\Big[\overrightarrow{g_i} \Vert \overrightarrow{g_j} \Big]$
# for all pairs of $i, j$.
#
# `g_repeat` gets
# $$\{\overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N},
# \overrightarrow{g_1}, \overrightarrow{g_2}, \dots, \overrightarrow{g_N}, ...\}$$
# where each node embedding is repeated `n_nodes` times.
g_repeat = g.repeat(n_nodes, 1, 1)
# `g_repeat_interleave` gets
# $$\{\overrightarrow{g_1}, \overrightarrow{g_1}, \dots, \overrightarrow{g_1},
# \overrightarrow{g_2}, \overrightarrow{g_2}, \dots, \overrightarrow{g_2}, ...\}$$
# where each node embedding is repeated `n_nodes` times.
g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
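# For instance, with `n_nodes = 3`,
# `g_repeat` is $[g_1, g_2, g_3, g_1, g_2, g_3, g_1, g_2, g_3]$ and
# `g_repeat_interleave` is $[g_1, g_1, g_1, g_2, g_2, g_2, g_3, g_3, g_3]$,
# so concatenating them element-wise pairs every $\overrightarrow{g_i}$ with every $\overrightarrow{g_j}$.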
# Now we concatenate to get
# $$\{\overrightarrow{g_1} \Vert \overrightarrow{g_1},
# \overrightarrow{g_1} \Vert \overrightarrow{g_2},
# \dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N},
# \overrightarrow{g_2} \Vert \overrightarrow{g_1},
# \overrightarrow{g_2} \Vert \overrightarrow{g_2},
# \dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, ...\}$$
g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
# Reshape so that `g_concat[i, j]` is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$
g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
# Calculate
# $$e_{ij} = \text{LeakyReLU} \Big(
# \mathbf{a}^\top \Big[
# \overrightarrow{g_i} \Vert \overrightarrow{g_j}
# \Big] \Big)$$
# `e` is of shape `[n_nodes, n_nodes, n_heads, 1]`
e = self.activation(self.attn(g_concat))
# Remove the last dimension of size `1`
e = e.squeeze(-1)
# The adjacency matrix should have shape
# `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads
# Mask $e_{ij}$ based on adjacency matrix.
# $e_{ij}$ is set to $- \infty$ if there is no edge from $i$ to $j$.
e = e.masked_fill(adj_mat == 0, float('-inf'))
# We then normalize attention scores (or coefficients)
# $$\alpha_{ij} = \text{softmax}_j(e_{ij}) =
# \frac{\exp(e_{ij})}{\sum_{j \in \mathcal{N}_i} \exp(e_{ij})}$$
#
# where $\mathcal{N}_i$ is the set of nodes connected to $i$.
#
# We do this by setting unconnected $e_{ij}$ to $- \infty$ which
# makes $\exp(e_{ij}) \sim 0$ for unconnected pairs.
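# Since `e` has shape `[n_nodes, n_nodes, n_heads]`, the softmax with `dim=1`
# normalizes over the source nodes $j$ for each target node $i$ and each head.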
a = self.softmax(e)
# Apply dropout regularization
a = self.dropout(a)
# Calculate final output for each head
# $$\overrightarrow{h'^k_i} = \sum_{j \in \mathcal{N}_i} \alpha^k_{ij} \overrightarrow{g^k_j}$$
#
# *Note:* The paper includes the final activation $\sigma$ in $\overrightarrow{h'^k_i}$.
# We have omitted this from the Graph Attention Layer implementation
# and use it in the GAT model to match how other PyTorch modules are defined -
# activation as a separate layer.
attn_res = torch.einsum('ijh,jhf->ihf', a, g)
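# In the einsum `'ijh,jhf->ihf'`, `i` indexes the target node, `j` the source node,
# `h` the head, and `f` the hidden feature, so this is the attention-weighted sum over `j`;
# `attn_res` has shape `[n_nodes, n_heads, n_hidden]`.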
# Concatenate the heads
if self.is_concat:
# $$\overrightarrow{h'_i} = \Bigg\Vert_{k=1}^{K} \overrightarrow{h'^k_i}$$
return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
# Take the mean of the heads
else:
# $$\overrightarrow{h'_i} = \frac{1}{K} \sum_{k=1}^{K} \overrightarrow{h'^k_i}$$
return attn_res.mean(dim=1)
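# A minimal smoke test for this layer; the sizes and the random graph below are
# illustrative assumptions, not values from the paper or the experiment.
if __name__ == '__main__':
    # 10 nodes with 16 input features; 32 output features across 4 heads (8 per head when concatenating)
    _layer = GraphAttentionLayer(in_features=16, out_features=32, n_heads=4, is_concat=True)
    # Random node embeddings
    _h = torch.rand(10, 16)
    # A random symmetric adjacency matrix with self-loops,
    # shaped `[n_nodes, n_nodes, 1]` so that it is shared across heads
    _adj = torch.rand(10, 10) > 0.5
    _adj = (_adj | _adj.t() | torch.eye(10, dtype=torch.bool)).unsqueeze(-1)
    # The output should have shape `[10, 32]`
    print(_layer(_h, _adj).shape)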
"""
---
title: Train a Graph Attention Network (GAT) on the Cora dataset
summary: >
  This trains a Graph Attention Network (GAT) on the Cora dataset
---
# Train a Graph Attention Network (GAT) on the Cora dataset
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/d6c636cadf3511eba2f1e707f612f95d)
"""
from typing import Dict
import numpy as np
import torch
from torch import nn
from labml import lab, monit, tracker, experiment
from labml.configs import BaseConfigs
from labml.utils import download
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_nn.graphs.gat import GraphAttentionLayer
from labml_nn.optimizers.configs import OptimizerConfigs
class CoraDataset:
"""
## [Cora Dataset](https://linqs.soe.ucsc.edu/data)
The Cora dataset is a dataset of research papers.
For each paper we are given a binary feature vector that indicates the presence of words.
Each paper is classified into one of 7 classes.
The dataset also has the citation network.
The papers are the nodes of the graph and the edges are the citations.
The task is to classify the nodes into the 7 classes, using the feature vectors and the
citation network as input.
"""
# Labels for each node
labels: torch.Tensor
# Set of class names and a unique integer index
classes: Dict[str, int]
# Feature vectors for all nodes
features: torch.Tensor
# Adjacency matrix with the edge information.
# `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
adj_mat: torch.Tensor
@staticmethod
def _download():
"""
Download the dataset
"""
if not (lab.get_data_path() / 'cora').exists():
download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
lab.get_data_path() / 'cora.tgz')
download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
def __init__(self, include_edges: bool = True):
"""
Load the dataset
"""
# Whether to include edges.
# This is to test how much accuracy is lost if we ignore the citation network.
self.include_edges = include_edges
# Download dataset
self._download()
# Read the paper ids, feature vectors, and labels
with monit.section('Read content file'):
content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
# Load the citations; it's a list of pairs of integers.
with monit.section('Read citations file'):
citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
# Get the feature vectors
features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
# Normalize the feature vectors
self.features = features / features.sum(dim=1, keepdim=True)
# Get the class names and assign a unique integer to each of them
self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
# Get the labels as those integers
self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
# Get the paper ids
paper_ids = np.array(content[:, 0], dtype=np.int32)
# Map of paper id to index
ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
# Empty adjacency matrix - an identity matrix
self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
# Mark the citations in the adjacency matrix
if self.include_edges:
for e in citations:
# The pair of paper indexes
e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
# We build a symmetric graph, where if paper $i$ referenced
# paper $j$, we place an edge from $i$ to $j$ as well as an edge
# from $j$ to $i$.
self.adj_mat[e1][e2] = True
self.adj_mat[e2][e1] = True
class GAT(Module):
"""
## Graph Attention Network (GAT)
This graph attention network has two [graph attention layers](index.html).
"""
def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
"""
* `in_features` is the number of features per node
* `n_hidden` is the number of features in the first graph attention layer
* `n_classes` is the number of classes
* `n_heads` is the number of heads in the graph attention layers
* `dropout` is the dropout probability
"""
super().__init__()
# First graph attention layer where we concatenate the heads
self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
# Activation function after first graph attention layer
self.activation = nn.ELU()
# Final graph attention layer where we average the heads
self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
# Dropout
self.dropout = nn.Dropout(dropout)
def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor):
"""
* `x` is the feature vectors of shape `[n_nodes, in_features]`
* `adj_mat` is the adjacency matrix of the form
`[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`
"""
# Apply dropout to the input
x = self.dropout(x)
# First graph attention layer
x = self.layer1(x, adj_mat)
# Activation function
x = self.activation(x)
# Dropout
x = self.dropout(x)
# Output layer (without activation) for logits
return self.output(x, adj_mat)
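# A hypothetical shape check (the sizes below are illustrative; Cora itself has
# 1433 binary word features and 7 classes): `GAT(1433, 64, 7, 8, 0.6)` maps `x` of shape
# `[n_nodes, 1433]` and an adjacency matrix of shape `[n_nodes, n_nodes, 1]`
# to logits of shape `[n_nodes, 7]`.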
def accuracy(output: torch.Tensor, labels: torch.Tensor):
"""
A simple function to calculate the accuracy
"""
return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
class Configs(BaseConfigs):
"""
## Configurations
"""
# Model
model: GAT
# Number of nodes to train on
training_samples: int = 500
# Number of features per node in the input
in_features: int
# Number of features in the first graph attention layer
n_hidden: int = 64
# Number of heads
n_heads: int = 8
# Number of classes for classification
n_classes: int
# Dropout probability
dropout: float = 0.6
# Whether to include the citation network
include_edges: bool = True
# Dataset
dataset: CoraDataset
# Number of training iterations
epochs: int = 1_000
# Loss function
loss_func = nn.CrossEntropyLoss()
# Device to train on
#
# This creates configs for device, so that
# we can change the device by passing a config value
device: torch.device = DeviceConfigs()
# Optimizer
optimizer: torch.optim.Adam
def initialize(self):
"""
Initialize
"""
# Create the dataset
self.dataset = CoraDataset(self.include_edges)
# Get the number of classes
self.n_classes = len(self.dataset.classes)
# Number of features in the input
self.in_features = self.dataset.features.shape[1]
# Create the model
self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
# Move the model to the device
self.model.to(self.device)
# Configurable optimizer, so that we can set the configurations
# such as learning rate by passing the dictionary later.
optimizer_conf = OptimizerConfigs()
optimizer_conf.parameters = self.model.parameters()
self.optimizer = optimizer_conf
def run(self):
"""
### Training loop
We do full batch training since the dataset is small.
If we were to sample and train, we would have to sample a set of
nodes for each training step, along with the edges that span
across those selected nodes.
"""
# Move the feature vectors to the device
features = self.dataset.features.to(self.device)
# Move the labels to the device
labels = self.dataset.labels.to(self.device)
# Move the adjacency matrix to the device
edges_adj = self.dataset.adj_mat.to(self.device)
# Add an empty third dimension for the heads
edges_adj = edges_adj.unsqueeze(-1)
# Random indexes
idx_rand = torch.randperm(len(labels))
# Nodes for training
idx_train = idx_rand[:self.training_samples]
# Nodes for validation
idx_valid = idx_rand[self.training_samples:]
# Training loop
for epoch in monit.loop(self.epochs):
# Set the model to training mode
self.model.train()
# Make all the gradients zero
self.optimizer.zero_grad()
# Evaluate the model
output = self.model(features, edges_adj)
# Get the loss for training nodes
loss = self.loss_func(output[idx_train], labels[idx_train])
# Calculate gradients
loss.backward()
# Take optimization step
self.optimizer.step()
# Log the loss
tracker.add('loss.train', loss)
# Log the accuracy
tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
# Set mode to evaluation mode for validation
self.model.eval()
# No need to compute gradients
with torch.no_grad():
# Evaluate the model again
output = self.model(features, edges_adj)
# Calculate the loss for validation nodes
loss = self.loss_func(output[idx_valid], labels[idx_valid])
# Log the loss
tracker.add('loss.valid', loss)
# Log the accuracy
tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
# Save logs
tracker.save()
def main():
# Create configurations
conf = Configs()
# Create an experiment
experiment.create(name='gat')
# Calculate configurations.
experiment.configs(conf, {
# Adam optimizer
'optimizer.optimizer': 'Adam',
'optimizer.learning_rate': 5e-3,
'optimizer.weight_decay': 5e-4,
})
# Initialize
conf.initialize()
# Start and watch the experiment
with experiment.start():
# Run the training
conf.run()
#
if __name__ == '__main__':
main()
# [Graph Attention Networks (GAT)](https://nn.labml.ai/graph/gat/index.html)
This is a [PyTorch](https://pytorch.org) implementation of the paper
[Graph Attention Networks](https://arxiv.org/abs/1710.10903).
GATs work on graph data.
A graph consists of nodes and edges connecting nodes.
For example, in the Cora dataset the nodes are research papers and the edges are citations that
connect the papers.
GAT uses masked self-attention, similar to [transformers](https://nn.labml.ai/transformers/mha.html).
GAT consists of graph attention layers stacked on top of each other.
Each graph attention layer gets node embeddings as inputs and outputs transformed embeddings.
Each node embedding attends to the embeddings of the other nodes it's connected to.
The details of graph attention layers are included alongside the implementation.
Here is [the training code](https://nn.labml.ai/graph/gat/experiment.html) for training
a two-layer GAT on the Cora dataset.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/d6c636cadf3511eba2f1e707f612f95d)
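As a rough sketch of how these layers compose (a minimal example with made-up sizes and a toy graph, not the actual experiment setup):

```python
import torch
from labml_nn.graphs.gat import GraphAttentionLayer

# A toy graph with 4 nodes and 8 features per node
h = torch.rand(4, 8)
# A symmetric adjacency matrix with self-loops,
# shaped `[n_nodes, n_nodes, 1]` so it is shared across attention heads
adj = torch.tensor([[1, 1, 0, 0],
                    [1, 1, 1, 0],
                    [0, 1, 1, 1],
                    [0, 0, 1, 1]], dtype=torch.bool).unsqueeze(-1)

# Two stacked graph attention layers, similar to the two-layer GAT trained in the experiment
layer1 = GraphAttentionLayer(in_features=8, out_features=16, n_heads=4, is_concat=True)
layer2 = GraphAttentionLayer(in_features=16, out_features=7, n_heads=1, is_concat=False)

x = torch.nn.functional.elu(layer1(h, adj))  # shape `[4, 16]`
logits = layer2(x, adj)                      # shape `[4, 7]`
```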
@@ -56,6 +56,10 @@ implementations almost weekly.
#### ✨ [Sketch RNN](https://nn.labml.ai/sketch_rnn/index.html)
#### ✨ Graph Neural Networks
* [Graph Attention Networks (GAT)](https://nn.labml.ai/graphs/gat/index.html)
#### ✨ [Counterfactual Regret Minimization (CFR)](https://nn.labml.ai/cfr/index.html)
Solving games with incomplete information such as poker with CFR.
......