{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Graph Neural Networks\n",
    "====================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to use WLPlan in a Graph Neural Network (GNN) pipeline.\n",
    "\n",
    "We assume that you have gone through and run the code of the [introductory tutorial](https://github.com/DillonZChen/wlplan/blob/main/docs/tutorials/1_introduction.ipynb), which sets up the environment and constructs the dataset required for this tutorial. The corresponding notebook for this tutorial is available [here](https://github.com/DillonZChen/wlplan/blob/main/docs/tutorials/3_gnns.ipynb).\n",
    "\n",
    "We begin by installing PyTorch and PyTorch Geometric for the GNN components, importing packages, and loading the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can change to the correct CUDA version for PyTorch if required\n",
    "# by following the instructions at https://pytorch.org/get-started/locally/\n",
    "%pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
    "%pip install torch_geometric\n",
    "\n",
    "import pickle\n",
    "import time\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, List, Optional\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch import Tensor\n",
    "from torch.nn import Linear, MSELoss, Module, ReLU, Sequential\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import MessagePassing, global_add_pool\n",
    "from wlplan.data import DomainDataset\n",
    "from wlplan.graph_generator import Graph, GraphGenerator, init_graph_generator\n",
    "\n",
    "# We assume we have generated the dataset from the introductory tutorial and load it\n",
    "with open(\"wlplan-blocks.pkl\", \"rb\") as f:\n",
    "    domain, dataset, y = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting from WLPlan to PyTorch Geometric\n",
    "The following code converts WLPlan classes into inputs suitable for PyTorch Geometric (PyG) models. We begin with a custom class that holds the inputs to our PyG models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class PyGGraph:\n",
    "    x: Tensor\n",
    "    edge_index: list[Tensor]  # list index corresponds to edge label\n",
    "\n",
    "    def to(self, device: torch.device) -> \"PyGGraph\":\n",
    "        return PyGGraph(\n",
    "            x=self.x.to(device),\n",
    "            edge_index=[edge_index.to(device) for edge_index in self.edge_index],\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define a function that converts a WLPlan graph into a PyG graph. It also takes a WLPlan graph generator, which holds information about the underlying graph representation, such as the number of node features and the number of edge labels (relations)."
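, "\n",
    "\n",
    "To make the target format concrete, the next cell is a small sketch of the same two steps on a hypothetical three-node graph with two edge labels. It does not call WLPlan: `toy_node_colours` and `toy_edges` are made up and only mimic the layout that the conversion function below reads from a WLPlan graph, namely one colour index per node and, for each node `u`, a list of `(relation, neighbour)` pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy graph (not generated by WLPlan): 3 nodes, 3 node colours, 2 edge labels\n",
    "toy_node_colours = [0, 2, 1]\n",
    "toy_edges = [[(0, 1)], [(1, 2), (0, 0)], []]  # toy_edges[u] = [(relation, neighbour), ...]\n",
    "toy_n_features, toy_n_relations = 3, 2\n",
    "\n",
    "# One-hot node features: row u has a 1 in column toy_node_colours[u]\n",
    "toy_x = torch.zeros(len(toy_node_colours), toy_n_features)\n",
    "toy_x[torch.arange(len(toy_node_colours)), torch.tensor(toy_node_colours)] = 1\n",
    "\n",
    "# One (2, e) tensor per relation: row 0 holds source nodes, row 1 holds target nodes\n",
    "toy_edge_index = [[[], []] for _ in range(toy_n_relations)]\n",
    "for u, neighbours in enumerate(toy_edges):\n",
    "    for r, v in neighbours:\n",
    "        toy_edge_index[r][0].append(u)\n",
    "        toy_edge_index[r][1].append(v)\n",
    "toy_edge_index = [torch.tensor(e, dtype=torch.long) for e in toy_edge_index]\n",
    "\n",
    "print(toy_x)\n",
    "for r, e in enumerate(toy_edge_index):\n",
    "    print(f\"relation {r}: {e.tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this toy graph, relation 0 holds the edges 0 → 1 and 1 → 0, relation 1 holds the single edge 1 → 2, and node 2 has no outgoing edges. The function below applies the same steps to an arbitrary WLPlan graph, sizing the tensors with the graph generator's `get_n_features()` and `get_n_relations()`."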
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wlplan_graph_to_pyg(graph_generator: GraphGenerator, graph: Graph) -> PyGGraph:\n",
    "    nodes = graph.node_colours\n",
    "    edges = graph.edges\n",
    "\n",
    "    # x is (n, d): one-hot encoding of each node's colour\n",
    "    x = torch.zeros(len(nodes), graph_generator.get_n_features())\n",
    "    x[torch.arange(len(nodes)), torch.tensor(nodes, dtype=torch.long)] = 1\n",
    "\n",
    "    # edge_index is a list of (2, e) tensors, one per edge label (relation)\n",
    "    edge_index = [[[], []] for _ in range(graph_generator.get_n_relations())]\n",
    "    for u, neighbours in enumerate(edges):\n",
    "        for r, v in neighbours:\n",
    "            edge_index[r][0].append(u)\n",
    "            edge_index[r][1].append(v)\n",
    "    edge_index = [torch.tensor(e, dtype=torch.long) for e in edge_index]\n",
    "\n",
    "    return PyGGraph(x=x, edge_index=edge_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we define a function that builds the training and validation data loaders for GNN training, holding out 15% of the data for validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_loaders(\n",
    "    domain_dataset: DomainDataset,\n",
    "    labels: Any,\n",
    "    graph_generator: GraphGenerator,\n",
    "    batch_size: int = 128,\n",
    ") -> tuple[DataLoader, DataLoader]:\n",
    "    graphs = graph_generator.to_graphs(domain_dataset)\n",
    "\n",
    "    pyg_dataset = []\n",
    "    for graph, label in zip(graphs, labels):\n",
    "        pyg_graph = wlplan_graph_to_pyg(graph_generator=graph_generator, graph=graph)\n",
    "\n",
    "        # The attribute must be named `edge_index` so that PyG offsets the node\n",
    "        # indices of every (2, e) tensor in the list when batching graphs together\n",
    "        data = Data(x=pyg_graph.x, edge_index=pyg_graph.edge_index, y=label)\n",
    "        pyg_dataset.append(data)\n",
    "\n",
    "    train_set, val_set = train_test_split(pyg_dataset, test_size=0.15, random_state=4550)\n",
    "    print(f\"{len(train_set)=}\")\n",
    "    print(f\"{len(val_set)=}\")\n",
    "\n",
    "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    return train_loader, val_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph Neural Networks\n",
    "\n",
    "Next, we implement a [Relational Graph Convolutional Network](https://arxiv.org/abs/1703.06103) (R-GCN), since R-GCNs support graphs with edge labels, as is the case with our WLPlan graphs. One minor difference is that we use a `max` aggregator instead of a `mean` aggregator to combine neighbouring node features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Message passing for a single edge label: linearly transform node features,\n",
    "# then aggregate the messages from neighbours with `aggr` (max by default)\n",
    "class LinearConv(MessagePassing):\n",
    "    def __init__(self, in_feat: int, out_feat: int, aggr: str = \"max\") -> None:\n",
    "        super().__init__(aggr=aggr)\n",
    "        self.f = Linear(in_feat, out_feat, bias=False)\n",
    "\n",
    "    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n",
    "        x = self.f(x)\n",
    "        x = self.propagate(edge_index=edge_index, x=x, size=None)\n",
    "        return x\n",
    "\n",
    "\n",
    "# One R-GCN layer: a root (self) transform plus one LinearConv per edge label\n",
    "class RGNNLayer(Module):\n",
    "    def __init__(self, n_relations: int, in_feat: int, out_feat: int) -> None:\n",
    "        super().__init__()\n",
    "        self.convs = torch.nn.ModuleList()\n",
    "        for _ in range(n_relations):\n",
    "            self.convs.append(LinearConv(in_feat, out_feat))\n",
    "        self.root = Linear(in_feat, out_feat, bias=True)\n",
    "\n",
    "    def forward(self, x: Tensor, edge_indices_list: List[Tensor]) -> Tensor:\n",
    "        x_out = self.root(x)\n",
    "        for i, conv in enumerate(self.convs):\n",
    "            x_out = x_out + conv(x, edge_indices_list[i])\n",
    "        return x_out\n",
    "\n",
    "\n",
    "class RGNN(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_relations: int,\n",
    "        in_feat: int,\n",
    "        out_feat: int,\n",
    "        n_hid: int = 64,\n",
    "        n_layers: int = 4,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "        self.n_relations = n_relations\n",
    "        self.in_feat = in_feat\n",
    "        self.out_feat = out_feat\n",
    "        self.n_hid = n_hid\n",
    "        self.n_layers = n_layers\n",
    "\n",
    "        self.emb = torch.nn.Linear(self.in_feat, self.n_hid)\n",
    "        self.layers = torch.nn.ModuleList()\n",
    "        for _ in range(self.n_layers):\n",
    "            layer = RGNNLayer(\n",
    "                n_relations=self.n_relations, in_feat=self.n_hid, out_feat=self.n_hid\n",
    "            )\n",
    "            self.layers.append(layer)\n",
    "        self.mlp = Sequential(\n",
    "            Linear(self.n_hid, self.n_hid), ReLU(), Linear(self.n_hid, self.out_feat)\n",
    "        )\n",
    "\n",
    "    @property\n",
    "    def num_parameters(self) -> int:\n",
    "        return sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
    "\n",
    "    def forward(\n",
    "        self, x: Tensor, edge_indices_list: List[Tensor], batch: Optional[Tensor]\n",
    "    ) -> Tensor:\n",
    "        # Embed nodes\n",
    "        x = self.emb(x)\n",
    "        for layer in self.layers:\n",
    "            x = layer(x, edge_indices_list)\n",
    "            x = F.relu(x)\n",
    "\n",
    "        # Readout layer: sum node embeddings per graph, then map to the output\n",
    "        x = global_add_pool(x, batch)\n",
    "        h = self.mlp(x)\n",
    "        h = h.squeeze(1)\n",
    "\n",
    "        return h"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Pipeline\n",
    "\n",
    "With the data preprocessing functions and the model implemented, we now get to the main part of the tutorial: training a GNN. We begin by calling our preprocessing functions and initialising the model and optimiser."
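, "\n",
    "\n",
    "Before that, it may help to sanity-check two assumptions the pipeline relies on, using two hypothetical toy graphs rather than the Blocks dataset. First, because the list of per-relation tensors is stored under the name `edge_index`, PyG should offset the node indices of every tensor in the list when it batches graphs. Second, thanks to the `global_add_pool` readout, the `RGNN` defined above should return one prediction per graph in the batch. The graphs `g1` and `g2` below are made up for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Batch, Data\n",
    "\n",
    "# Two hypothetical toy graphs with 2 and 3 nodes, 2 edge labels and 4 one-hot node features\n",
    "g1 = Data(\n",
    "    x=torch.eye(4)[[0, 1]],\n",
    "    edge_index=[torch.tensor([[0], [1]]), torch.tensor([[1], [0]])],\n",
    ")\n",
    "g2 = Data(\n",
    "    x=torch.eye(4)[[2, 3, 0]],\n",
    "    edge_index=[torch.tensor([[0, 1], [1, 2]]), torch.empty(2, 0, dtype=torch.long)],\n",
    ")\n",
    "toy_batch = Batch.from_data_list([g1, g2])\n",
    "\n",
    "# Node indices coming from g2 should be shifted by g1's 2 nodes in every relation\n",
    "for r, e in enumerate(toy_batch.edge_index):\n",
    "    print(f\"relation {r}: {e.tolist()}\")\n",
    "\n",
    "# The RGNN class defined above should output a tensor of shape (2,): one value per toy graph\n",
    "toy_model = RGNN(n_relations=2, in_feat=4, out_feat=1)\n",
    "toy_out = toy_model(x=toy_batch.x, edge_indices_list=toy_batch.edge_index, batch=toy_batch.batch)\n",
    "print(toy_out.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If batching works as intended, relation 0 should print `[[0, 2, 3], [1, 3, 4]]` and relation 1 should print `[[1], [0]]`. With these checks in mind, the next cell builds the data loaders for the real dataset and initialises the model and optimiser."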
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process dataset\n",
    "graph_generator = init_graph_generator(graph_representation=\"ilg\", domain=domain)\n",
    "train_loader, val_loader = get_data_loaders(\n",
    "    domain_dataset=dataset, labels=y, graph_generator=graph_generator\n",
    ")\n",
    "\n",
    "# Initialise model\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = RGNN(\n",
    "    n_relations=graph_generator.get_n_relations(),\n",
    "    in_feat=graph_generator.get_n_features(),\n",
    "    out_feat=1,  # a single scalar heuristic value per state\n",
    ")\n",
    "model = model.to(device)\n",
    "print(f\"{model.num_parameters=}\")\n",
    "\n",
    "# Initialise the optimiser and the learning rate scheduler\n",
    "optimiser = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    optimiser, mode=\"min\", factor=0.1, patience=10\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we implement the main training loop, which minimises the mean squared error with the Adam optimiser. The scheduler reduces the learning rate by a factor of 10 whenever the validation loss plateaus (with a patience of 10 epochs), and we stop training early once the learning rate falls below 1e-5."
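, "\n",
    "\n",
    "To see how this stopping rule interacts with the scheduler, the next cell is a small sketch that drives a `ReduceLROnPlateau` scheduler with the same settings using a synthetic validation loss that never improves. The dummy parameter and optimiser are hypothetical and exist only to give the scheduler a learning rate to update; no model is trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical dummy parameter and optimiser with the same initial learning rate as above\n",
    "dummy_param = torch.nn.Parameter(torch.zeros(1))\n",
    "dummy_optimiser = torch.optim.Adam([dummy_param], lr=0.001)\n",
    "dummy_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    dummy_optimiser, mode=\"min\", factor=0.1, patience=10\n",
    ")\n",
    "\n",
    "# Synthetic validation loss that stays flat, so no epoch after the first counts as an improvement\n",
    "for epoch in range(100):\n",
    "    dummy_scheduler.step(1.0)\n",
    "    lr = dummy_optimiser.param_groups[0][\"lr\"]\n",
    "    if lr < 1e-5:\n",
    "        print(f\"Early stopping at {epoch=} with {lr=:.1e}\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On a flat loss, each learning rate reduction takes a little more than 10 epochs, so training would stop after a few dozen epochs without progress. The real loop below applies the same rule to the actual validation loss."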
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = MSELoss()\n",
    "epochs = 1024\n",
    "for epoch in range(epochs):\n",
    "    t = time.time()\n",
    "\n",
    "    # Train step\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        y_true = data.y.float()\n",
    "        optimiser.zero_grad(set_to_none=True)\n",
    "        y_pred = model(x=data.x, edge_indices_list=data.edge_index, batch=data.batch)\n",
    "        loss = criterion(y_pred, y_true)\n",
    "        loss.backward()\n",
    "        optimiser.step()\n",
    "        train_loss += loss.item()\n",
    "    train_loss /= len(train_loader)\n",
    "\n",
    "    # Validation step (no gradients needed)\n",
    "    model.eval()\n",
    "    val_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for data in val_loader:\n",
    "            data = data.to(device)\n",
    "            y_true = data.y.float()\n",
    "            y_pred = model(x=data.x, edge_indices_list=data.edge_index, batch=data.batch)\n",
    "            val_loss += criterion(y_pred, y_true).item()\n",
    "    val_loss /= len(val_loader)\n",
    "    scheduler.step(val_loss)\n",
    "\n",
    "    t = time.time() - t\n",
    "    lr = optimiser.param_groups[0][\"lr\"]\n",
    "\n",
    "    print(\n",
    "        \", \".join(\n",
    "            [\n",
    "                f\"{epoch=}\",\n",
    "                f\"{t=:.8f}\",\n",
    "                f\"{train_loss=:.8f}\",\n",
    "                f\"{val_loss=:.8f}\",\n",
    "                f\"{lr=:.1e}\",\n",
    "            ]\n",
    "        )\n",
    "    )\n",
    "\n",
    "    if lr < 1e-5:\n",
    "        print(f\"Early stopping due to small {lr=:.1e}\")\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Depending on the random seed, the model should converge to training and validation losses below 1. This is not bad at all.\n",
    "\n",
    "However, recall how much code was required, how long training took, and the mean squared error achieved by the classical ML model in the [introductory tutorial](https://github.com/DillonZChen/wlplan/blob/main/docs/tutorials/1_introduction.ipynb). It seems that training a classical ML model on WLPlan features is much more efficient and effective than training a GNN for learning heuristic functions for PDDL planning.\n",
    "\n",
    "The paper [Return to Tradition: Learning Reliable Heuristics with Classical Machine Learning](https://arxiv.org/abs/2403.16508) further shows that training a classical ML model with WLPlan features leads to much more efficient planning than GNNs!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}