Graph Neural Networks
This tutorial will demonstrate how one can use WLPlan in a Graph Neural Network (GNN) pipeline.
We will assume that you have gone through and run the code for the introductory tutorial, which sets up the environment and constructs the dataset required for this tutorial. The corresponding notebook for this tutorial is available here.
We begin by installing PyTorch and PyTorch Geometric (PyG) for the GNN components, importing packages, and loading the data.
[ ]:
# You can change to the correct CUDA version for PyTorch if required
# by following instructions from https://pytorch.org/get-started/locally/
%pip install torch --index-url https://download.pytorch.org/whl/cpu
%pip install torch_geometric
import pickle
import time
from dataclasses import dataclass
from typing import Any, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential, BCEWithLogitsLoss, Module, MSELoss
from torch.nn.modules.module import Module
from torch.optim import Optimizer
from torch_geometric.nn import MessagePassing, global_add_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from wlplan.data import DomainDataset
from wlplan.graph_generator import Graph, GraphGenerator, init_graph_generator
# Load the (domain, dataset, labels) triple that the introductory tutorial
# pickled to disk. Only unpickle files you created yourself: unpickling
# untrusted data can execute arbitrary code.
with open("wlplan-blocks.pkl", mode="rb") as f:
    domain, dataset, y = pickle.load(f)
Converting from WLPlan to PyTorch Geometric
The following code is used to help convert WLPlan classes into components suitable for input for PyTorch Geometric (PyG) models. We begin with a custom class for specifying inputs to PyG models.
[ ]:
@dataclass
class PyGGraph:
    """Node features plus one edge-index tensor per edge label, ready for PyG models."""

    # Node feature matrix.
    x: Tensor
    # Per-relation edge indices: entry i holds the edges whose label is i.
    edge_index: list[Tensor]

    def to(self, device: torch.device) -> "PyGGraph":
        """Return a new PyGGraph whose tensors have all been moved to ``device``."""
        moved_x = self.x.to(device)
        moved_edges = [rel_edges.to(device) for rel_edges in self.edge_index]
        return PyGGraph(x=moved_x, edge_index=moved_edges)
Next, we define a function that takes as input a WLPlan graph generator, which holds useful information about the underlying graph representation such as the number of node features and edge labels, together with a WLPlan graph, and outputs the corresponding PyG graph.
[ ]:
def wlplan_graph_to_pyg(graph_generator: GraphGenerator, graph: Graph) -> PyGGraph:
    """Convert a WLPlan graph into a PyGGraph.

    Each node colour is one-hot encoded into a dense feature matrix, and edges
    are bucketed by label into one (2, e_r) index tensor per relation.

    Args:
        graph_generator: provides the feature width (``get_n_features``) and
            the number of edge labels (``get_n_relations``).
        graph: WLPlan graph; ``graph.node_colours[u]`` is used as a column index
            for node ``u`` and ``graph.edges[u]`` is iterated as ``(r, v)`` pairs.

    Returns:
        PyGGraph with ``x`` of shape (n_nodes, n_features) and ``edge_index[r]``
        a long tensor of shape (2, e_r) holding (source, target) rows.
    """
    node_colours = list(graph.node_colours)
    n_nodes = len(node_colours)

    # x is (n, d): one-hot row per node. An explicit long index tensor keeps
    # the indexing well-defined even for a graph with no nodes.
    x = torch.zeros(n_nodes, graph_generator.get_n_features())
    colour_index = torch.tensor(node_colours, dtype=torch.long)
    x[torch.arange(n_nodes), colour_index] = 1

    # Bucket edges by relation label: sources[r][k] -> targets[r][k].
    n_relations = graph_generator.get_n_relations()
    sources = [[] for _ in range(n_relations)]
    targets = [[] for _ in range(n_relations)]
    for u, neighbours in enumerate(graph.edges):
        for r, v in neighbours:
            sources[r].append(u)
            targets[r].append(v)

    # edge_index is a list of (2, e_r) long tensors; relations with no edges
    # yield an empty (2, 0) tensor.
    edge_index = [
        torch.tensor([src, dst], dtype=torch.long)
        for src, dst in zip(sources, targets)
    ]
    return PyGGraph(x=x, edge_index=edge_index)
Lastly, we have a function for generating training and validation loaders for GNN training.
[ ]:
def get_data_loaders(
    domain_dataset: DomainDataset,
    labels: Any,
    graph_generator: GraphGenerator,
    batch_size: int = 128,
    test_size: float = 0.15,
    random_state: int = 4550,
) -> tuple[DataLoader, DataLoader]:
    """Convert a WLPlan dataset into PyG training and validation DataLoaders.

    Args:
        domain_dataset: dataset passed to ``graph_generator.to_graphs``.
        labels: one target per generated graph, in the same order.
        graph_generator: generator used to build and convert the graphs.
        batch_size: number of graphs per batch in both loaders.
        test_size: held-out portion, forwarded to sklearn's
            ``train_test_split`` (a float is the validation fraction).
        random_state: seed forwarded to ``train_test_split`` so the split is
            reproducible.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if the number of labels differs from the number of graphs
            (``zip`` would otherwise silently drop the surplus).
    """
    graphs = list(graph_generator.to_graphs(domain_dataset))
    labels = list(labels)
    if len(graphs) != len(labels):
        raise ValueError(
            f"got {len(graphs)} graphs but {len(labels)} labels; "
            "expected one label per graph"
        )

    pyg_dataset = []
    for graph, label in zip(graphs, labels):
        pyg_graph = wlplan_graph_to_pyg(graph_generator=graph_generator, graph=graph)
        # The attribute must be named `edge_index` so that PyG batching offsets
        # every per-relation (2, e) tensor by the running node count.
        data = Data(x=pyg_graph.x, edge_index=pyg_graph.edge_index, y=label)
        pyg_dataset.append(data)

    train_set, val_set = train_test_split(
        pyg_dataset, test_size=test_size, random_state=random_state
    )
    print(f"{len(train_set)=}")
    print(f"{len(val_set)=}")
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
Graph Neural Networks
Next, we implement a Relational Graph Convolutional Network (R-GCN), as it supports graphs with edge labels, which is the case for our WLPlan graphs. One minor difference from the standard R-GCN is that we use a `max` aggregator instead of a `mean` aggregator for combining neighbouring node features.
[ ]:
class RGNNLayer(Module):
    """One relational message-passing layer.

    The output is a root (self) linear transform of the node features plus,
    for each edge label, a LinearConv over that label's edges, all summed.
    """

    def __init__(self, n_relations: int, in_feat: int, out_feat: int) -> None:
        super().__init__()
        # One convolution per edge label; convs[i] consumes edge_indices_list[i].
        self.convs = torch.nn.ModuleList(
            [LinearConv(in_feat, out_feat) for _ in range(n_relations)]
        )
        # Self-connection, applied to every node regardless of its edges.
        self.root = Linear(in_feat, out_feat, bias=True)

    def forward(self, x: Tensor, edge_indices_list: List[Tensor]) -> Tensor:
        """Return ``root(x) + sum_r convs[r](x, edge_indices_list[r])``.

        Entries of ``edge_indices_list`` beyond ``len(self.convs)`` are ignored;
        a shorter list raises IndexError.
        """
        out = self.root(x)
        for rel, rel_conv in enumerate(self.convs):
            out += rel_conv(x, edge_indices_list[rel])
        return out
class LinearConv(MessagePassing):
    """Linearly project node features, then aggregate them over ``edge_index``.

    Features pass through a bias-free Linear layer before propagation, and
    neighbour messages are combined with ``aggr`` (``"max"`` by default).
    """

    def __init__(self, in_feat: int, out_feat: int, aggr: str = "max") -> None:
        super().__init__(aggr=aggr)
        # Bias-free projection applied to every node before messages are sent.
        self.f = Linear(in_feat, out_feat, bias=False)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Project ``x`` and aggregate the projected features along ``edge_index``."""
        projected = self.f(x)
        return self.propagate(edge_index=edge_index, x=projected, size=None)
class RGNN(nn.Module):
    """Relational GNN that maps a batch of graphs to one prediction per graph.

    Pipeline: linear node embedding -> ``n_layers`` x (RGNNLayer + ReLU) ->
    sum pooling over each graph's nodes -> two-layer MLP head.
    """

    def __init__(
        self,
        n_relations: int,
        in_feat: int,
        out_feat: int,
        n_hid: int = 64,
        n_layers: int = 4,
    ) -> None:
        super().__init__()
        # Keep the hyperparameters around for inspection.
        self.n_relations = n_relations
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.n_hid = n_hid
        self.n_layers = n_layers

        # Submodules are created in the order emb -> layers -> mlp.
        self.emb = torch.nn.Linear(in_feat, n_hid)
        self.layers = torch.nn.ModuleList(
            [
                RGNNLayer(n_relations=n_relations, in_feat=n_hid, out_feat=n_hid)
                for _ in range(n_layers)
            ]
        )
        self.mlp = Sequential(
            Linear(n_hid, n_hid),
            ReLU(),
            Linear(n_hid, out_feat),
        )

    @property
    def num_parameters(self) -> int:
        """Total element count of all parameters with ``requires_grad`` set."""
        return sum(param.numel() for param in self.parameters() if param.requires_grad)

    def forward(
        self, x: Tensor, edge_indices_list: List[Tensor], batch: Optional[Tensor]
    ) -> Tensor:
        """Compute one output per graph.

        Args:
            x: node features whose last dimension is ``in_feat``.
            edge_indices_list: per-relation edge indices, one per RGNNLayer conv.
            batch: node-to-graph assignment passed to ``global_add_pool``.

        Returns:
            The MLP output with dimension 1 squeezed (one value per graph when
            ``out_feat == 1``).
        """
        hidden = self.emb(x)
        for layer in self.layers:
            hidden = F.relu(layer(hidden, edge_indices_list))
        # Readout: sum node embeddings within each graph.
        pooled = global_add_pool(hidden, batch)
        return self.mlp(pooled).squeeze(1)
Training Pipeline
Now with the data preprocessing functions implemented, we get to the main part of the tutorial: training a GNN. We begin by calling our preprocessing functions and initialising the model and optimiser.
[ ]:
# Build ILG graphs from the dataset and split them into train/validation loaders.
graph_generator = init_graph_generator(graph_representation="ilg", domain=domain)
train_loader, val_loader = get_data_loaders(
    domain_dataset=dataset,
    labels=y,
    graph_generator=graph_generator,
)

# Create the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RGNN(
    n_relations=graph_generator.get_n_relations(),
    in_feat=graph_generator.get_n_features(),
    out_feat=1,  # a single scalar per graph: the predicted heuristic value
).to(device)
print(f"{model.num_parameters=}")

# Adam optimiser; the scheduler multiplies the learning rate by 0.1 once the
# validation loss has not improved for more than `patience=10` epochs.
optimiser = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimiser,
    mode="min",
    factor=0.1,
    patience=10,
)
Now we implement our main training loop. It minimises the mean squared error with gradient descent (Adam) and uses a scheduler that reduces the learning rate whenever performance on the validation set plateaus. Training stops early once the learning rate falls below 1e-5.
[ ]:
criterion = MSELoss()
epochs = 1024
for epoch in range(epochs):
    t = time.time()

    # Train step
    model.train()
    train_loss = 0.0
    n_train = 0
    for data in train_loader:
        data = data.to(device)  # moves the whole batch, including data.y
        y_true = data.y.float()
        optimiser.zero_grad(set_to_none=True)
        # Call the module itself (not .forward) so nn.Module hooks still run.
        y_pred = model(x=data.x, edge_indices_list=data.edge_index, batch=data.batch)
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimiser.step()
        # MSELoss returns the batch mean; weight it by the number of graphs so
        # the epoch loss is the mean over all samples, even when the final
        # batch is smaller than the others.
        train_loss += loss.item() * data.num_graphs
        n_train += data.num_graphs
    train_loss /= max(n_train, 1)

    # Validation step: no gradients are needed, so skip building the autograd
    # graph to save memory and time.
    model.eval()
    val_loss = 0.0
    n_val = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            h_true = data.y.float()
            h_pred = model(
                x=data.x, edge_indices_list=data.edge_index, batch=data.batch
            )
            loss = criterion(h_pred, h_true)
            val_loss += loss.item() * data.num_graphs
            n_val += data.num_graphs
    val_loss /= max(n_val, 1)
    scheduler.step(val_loss)

    t = time.time() - t
    lr = optimiser.param_groups[0]["lr"]
    print(
        ", ".join(
            [
                f"{epoch=}",
                f"{t=:.8f}",
                f"{train_loss=:.8f}",
                f"{val_loss=:.8f}",
                f"{lr=:.1e}",
            ]
        )
    )
    # The scheduler only ever shrinks the learning rate, so a tiny lr means the
    # validation loss has repeatedly stopped improving.
    if lr < 1e-5:
        print(f"Early stopping due to small {lr=:.1e}")
        break
Depending on the seed, the model should have converged with a training and validation loss less than 1. This is not bad at all.
However, recall the amount of code required to implement a classical ML model in the introductory tutorial, the time it took to train, and the mean squared error it achieved. It seems that training a classical ML model is much more efficient and effective than training a GNN for learning heuristic functions for PDDL planning.
The paper Return to Tradition: Learning Reliable Heuristics with Classical Machine Learning shows that training a classical ML model on WLPlan features also leads to much more efficient planning than using GNNs!