Point Cloud Classification with Graph Neural Networks Using PyTorch Geometric
When we think about machine learning on 3D data, we usually think about images or voxel grids — discretised, regular structures that fit neatly into convolution pipelines. But a lot of 3D data in the real world doesn’t look like that. It looks like a point cloud: a raw, unordered set of coordinates in space, sometimes accompanied by extra measurements like surface normals or colour.
Protein structures, LiDAR scans, molecular conformations, cryo-EM densities: all of these can naturally be represented as point clouds. PyTorch Geometric (PyG) makes it straightforward to apply Graph Neural Networks directly to this type of data, without voxelising it or projecting it onto a grid.
This post walks through PyG’s point cloud tutorial from first principles. No prior graph ML experience assumed.
What Is a Point Cloud?
A point cloud is simply a collection of points in 3D space:
P = { (x₁, y₁, z₁), (x₂, y₂, z₂), ..., (xₙ, yₙ, zₙ) }
There is no inherent connectivity — no bonds, no edges, no grid. The points can be in any order, and there can be any number of them. This is what makes point clouds tricky for standard deep learning: CNNs expect a fixed-size, ordered grid, but a point cloud is neither.
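A toy PyTorch check makes the ordering problem concrete. This is an illustration added here, not part of the PyG tutorial. The same five points, listed in two different orders, produce different flattened vectors. A symmetric reduction such as a per-coordinate max gives identical results for both orderings.

```python
import torch

torch.manual_seed(0)
points = torch.randn(5, 3)              # 5 points in 3D
perm = torch.tensor([3, 0, 4, 1, 2])    # a fixed, non-identity reordering
shuffled = points[perm]                 # same set of points, different order

# Flattening into a fixed-order vector (what a grid/MLP pipeline would need)
# gives a different input for the same set of points...
flat_same = torch.equal(points.flatten(), shuffled.flatten())

# ...while a symmetric reduction over the point dimension ignores the order.
max_same = torch.equal(points.max(dim=0).values, shuffled.max(dim=0).values)

print(flat_same, max_same)  # False True
```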
Graph Neural Networks sidestep this by constructing a graph from the points — connecting nearby points with edges — and then running message passing on that graph to learn local geometric structure.
The PyG Data Object
Everything in PyTorch Geometric revolves around the Data object — a flexible container for graph-structured data. For point clouds specifically, the key attribute is pos:
```python
from torch_geometric.data import Data
import torch
# A simple point cloud: 5 points in 3D
data = Data(pos=torch.randn(5, 3))
print(data)
# Data(pos=[5, 3])
```
data.pos holds the 3D coordinates of each point. Other common attributes you’ll encounter:
| Attribute | Shape | Meaning |
|---|---|---|
| data.pos | [N, 3] | 3D coordinates of N points |
| data.x | [N, F] | Node feature vectors (optional) |
| data.edge_index | [2, E] | Graph connectivity: pairs of connected point indices |
| data.y | [1] or [N] | Label(s), graph-level or point-level |
| data.batch | [N] | Batch assignment vector (added automatically during batching) |
The Data object is intentionally minimal — you only include what you have.
The Dataset: GeometricShapes
PyG includes GeometricShapes, a synthetic dataset of 3D meshes representing simple geometric shapes (cubes, spheres, pyramids, and more). It’s a convenient sandbox for learning.
```python
from torch_geometric.datasets import GeometricShapes
dataset = GeometricShapes(root='data/GeometricShapes')
print(dataset)
# GeometricShapes(40)
data = dataset[0]
print(data)
# Data(pos=[32, 3], face=[3, 30], y=[1])
```
Each object is stored as a mesh — vertices (pos) and their triangular connectivity (face). But we want to work with point clouds, not meshes. This is where transforms come in.
Transforms: From Mesh to Point Cloud to Graph
PyG has a rich transforms system (torch_geometric.transforms). Transforms are functions that take a Data object and return a modified one. You can compose multiple transforms together.
Step 1: Sample points from the mesh surface
SamplePoints draws a fixed number of points uniformly over the mesh's surface. Each sample lands on a face with probability proportional to that face's area. It then discards the mesh connectivity (face), leaving only a point cloud.
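The idea behind area-weighted surface sampling can be sketched in plain PyTorch. This is a simplified re-implementation for illustration, not PyG's actual SamplePoints code, and `sample_surface` is a hypothetical helper name. It assumes PyG's `face` layout of shape [3, F]. First it picks a triangle with probability proportional to its area. Then it draws a uniform point inside that triangle using barycentric coordinates.

```python
import torch

def sample_surface(pos: torch.Tensor, face: torch.Tensor, num: int) -> torch.Tensor:
    """Hypothetical helper: sample `num` points uniformly over a triangle mesh.

    pos:  [V, 3] vertex coordinates
    face: [3, F] vertex indices of each triangle (same layout as PyG's `face`)
    """
    # Corner coordinates of every triangle, each of shape [F, 3]
    a, b, c = pos[face[0]], pos[face[1]], pos[face[2]]
    # Triangle areas: half the norm of the cross product of two edges
    area = torch.linalg.cross(b - a, c - a, dim=1).norm(dim=1) / 2
    # Choose a triangle per sample, with probability proportional to its area
    idx = torch.multinomial(area, num, replacement=True)
    # Uniform barycentric coordinates: draw u, v in [0, 1) and fold any
    # sample with u + v > 1 back into the triangle
    u, v = torch.rand(num, 1), torch.rand(num, 1)
    flip = (u + v) > 1
    u = torch.where(flip, 1 - u, u)
    v = torch.where(flip, 1 - v, v)
    return a[idx] + u * (b[idx] - a[idx]) + v * (c[idx] - a[idx])

# Usage: a unit square in the z = 0 plane, made of triangles (0,1,2) and (0,2,3)
square_pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
square_face = torch.tensor([[0, 0], [1, 2], [2, 3]])
pts = sample_surface(square_pos, square_face, num=256)
print(pts.shape)  # torch.Size([256, 3])
```

Calling it twice returns different points. That is the same stochastic behaviour discussed in the note below.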
```python
import torch_geometric.transforms as T
dataset.transform = T.SamplePoints(num=256)
data = dataset[0]
print(data)
# Data(pos=[256, 3], y=[1])
```
We now have 256 points sampled from the surface of the shape. The face attribute is gone — we only have coordinates.
Note: SamplePoints is stochastic. Because it is attached as dataset.transform, it runs every time an example is accessed, so you get a different set of points each time you load the same shape. This acts as free data augmentation during training.
Step 2: Build a graph with k-nearest neighbours
A point cloud alone has no edges. To run a GNN, we need to connect the points. The standard approach is k-nearest neighbour (k-NN) graph construction: for each point, connect it to its k geometrically closest neighbours.
```python
dataset.transform = T.Compose([
    T.SamplePoints(num=256),
    T.KNNGraph(k=6),
])
data = dataset[0]
print(data)
# Data(pos=[256, 3], edge_index=[2, 1536], y=[1])
```
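We now have 1536 edges, exactly 6 per point (256 × 6 = 1536). By default KNNGraph adds no self-loops and does not make the graph undirected. As a result, every point is the destination of exactly 6 edges, one from each of its nearest neighbours. edge_index is a [2, E] tensor. edge_index[0] contains source (neighbour) indices and edge_index[1] contains destination (central point) indices.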
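To see what the k-NN construction computes, here is a brute-force sketch in plain PyTorch. PyG's KNNGraph uses an optimised implementation. This O(N²) version, with the hypothetical helper name `knn_edge_index`, is only meant to show the resulting edge_index layout: row 0 holds neighbours and row 1 holds central points.

```python
import torch

def knn_edge_index(pos: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: directed k-NN graph as a [2, N * k] edge_index.

    Row 0 = source (neighbour j), row 1 = destination (central point i),
    matching PyG's default source-to-target convention. No self-loops.
    """
    n = pos.size(0)
    dist = torch.cdist(pos, pos)                        # [N, N] pairwise distances
    dist.fill_diagonal_(float('inf'))                   # a point is not its own neighbour
    nbr = dist.topk(k, dim=1, largest=False).indices    # [N, k] nearest neighbours per point
    src = nbr.reshape(-1)                               # neighbours, grouped by central point
    dst = torch.arange(n).repeat_interleave(k)          # each central point repeated k times
    return torch.stack([src, dst], dim=0)

pos = torch.randn(256, 3)
edge_index = knn_edge_index(pos, k=6)
print(edge_index.shape)  # torch.Size([2, 1536])
```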
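Why k-NN and not a fixed radius? k-NN gives every point exactly k neighbours regardless of local density, which keeps the per-point workload uniform. A radius-based graph (the RadiusGraph transform, or the radius_graph function) is an alternative when you care more about physical scale than computational regularity. With a radius graph, dense regions get many neighbours and sparse regions get few or none.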
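The density difference shows up on a toy cloud with uneven density: a tight cluster plus a widely spread set of points. This sketch uses plain PyTorch, and the radius r = 0.2 is an arbitrary choice for illustration. It counts neighbours per point under both constructions.

```python
import torch

torch.manual_seed(0)
# Uneven density: 100 tightly clustered points plus 100 widely spread points
dense = 0.05 * torch.randn(100, 3)
sparse = 2.0 * torch.randn(100, 3)
pos = torch.cat([dense, sparse], dim=0)

dist = torch.cdist(pos, pos)
dist.fill_diagonal_(float('inf'))   # ignore self-distances

# k-NN: mark each point's k closest points, then count neighbours per point
k = 6
nearest = dist.topk(k, dim=1, largest=False).indices
knn_adj = torch.zeros_like(dist).scatter_(1, nearest, 1.0)
knn_degree = knn_adj.sum(dim=1).long()

# Radius graph: every point within distance r counts as a neighbour
r = 0.2
radius_degree = (dist < r).sum(dim=1)

print('k-NN degree range:  ', int(knn_degree.min()), int(knn_degree.max()))
print('radius degree range:', int(radius_degree.min()), int(radius_degree.max()))
```

The k-NN degrees are all equal to k. The radius-graph degrees range from zero for isolated points to dozens inside the cluster.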
PointNet++: The Architecture
PointNet++ (Qi et al., NeurIPS 2017) is one of the foundational architectures for learning from point clouds. It works by iteratively:
- Grouping: building a local neighbourhood around each point (the original paper uses a fixed-radius ball query; k-NN is a common alternative)
- Aggregating: running a GNN layer that collects information from each point's neighbours
- Downsampling: reducing the number of points (like pooling in a CNN) to capture larger-scale structure
For simplicity, we focus here on the grouping and aggregation steps, which are the core of the method.
The PointNet++ Layer: MessagePassing from Scratch
PyG provides a MessagePassing base class that handles all the bookkeeping of graph message propagation — you just define what message to send, and how to aggregate them.
The PointNet++ update rule for node i at layer ℓ+1 is:
$$h_i^{(\ell+1)} = \max_{j \in \mathcal{N}(i)} \text{MLP}\!\left(h_j^{(\ell)},\ \mathbf{p}_j - \mathbf{p}_i\right)$$

In words: for each neighbour j of point i, concatenate j's features with the relative position of j with respect to i. Pass the result through a small MLP, then take the element-wise maximum over all neighbours. The relative position p_j - p_i is what gives the layer its geometric awareness: the network learns to respond to local shape, not just content.
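Written out in plain PyTorch, the update rule takes three steps: gather per-edge tensors, apply the MLP to each edge, and take a max per destination node. This roughly mirrors what MessagePassing.propagate() does for this layer. It is a simplified sketch with a hand-written toy edge list, not PyG's actual internals.

```python
import torch
from torch.nn import Sequential, Linear, ReLU

torch.manual_seed(0)

# Toy inputs: 4 points, initial features h = coordinates (as in this post),
# and a hand-written edge list. Row 0 = source j (neighbour), row 1 = destination i.
pos = torch.randn(4, 3)
h = pos.clone()
edge_index = torch.tensor([[1, 2, 0, 3, 0, 1],
                           [0, 0, 1, 1, 2, 3]])
mlp = Sequential(Linear(3 + 3, 8), ReLU(), Linear(8, 8))

src, dst = edge_index
# 1. Gather per-edge tensors: these play the role of h_j, pos_j and pos_i
h_j, pos_j, pos_i = h[src], pos[src], pos[dst]
# 2. Message: one MLP output per edge, computed from [h_j, p_j - p_i]
msg = mlp(torch.cat([h_j, pos_j - pos_i], dim=-1))   # [E, 8]
# 3. Aggregate: element-wise max over the messages arriving at each node i
out = torch.zeros(pos.size(0), msg.size(1)).scatter_reduce(
    0, dst.unsqueeze(-1).expand_as(msg), msg, reduce='amax', include_self=False)
print(out.shape)  # torch.Size([4, 8])
```

With include_self=False, a node that received no messages would keep its initial zero row. Here every node has at least one incoming edge.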
Here is a full implementation from scratch using MessagePassing:
```python
import torch
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import MessagePassing

class PointNetLayer(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        # "max" aggregation: take the maximum over all incoming messages
        super().__init__(aggr='max')
        # MLP: maps (neighbour features + relative position) → message
        # Input size = in_channels (neighbour features) + 3 (relative xyz)
        self.mlp = Sequential(
            Linear(in_channels + 3, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
        )

    def forward(self, h: Tensor, pos: Tensor, edge_index: Tensor) -> Tensor:
        # Trigger message passing. Passing pos=pos lets message() ask for
        # pos_j (source/neighbour positions) and pos_i (destination positions);
        # PyG indexes them per edge automatically.
        return self.propagate(edge_index, h=h, pos=pos)

    def message(self, h_j: Tensor, pos_j: Tensor, pos_i: Tensor) -> Tensor:
        # h_j: features of neighbouring points (the "_j" suffix = source nodes)
        # pos_j - pos_i: relative position of neighbour w.r.t. current point
        # The "_j" and "_i" suffixes are a PyG convention for source/destination
        relative_pos = pos_j - pos_i
        msg_input = torch.cat([h_j, relative_pos], dim=-1)
        return self.mlp(msg_input)
```
A few things worth understanding here:
The _j and _i suffix convention. In PyG’s MessagePassing, when you name a tensor in message() with a _j suffix, PyG automatically selects the features of the source nodes (neighbours). The _i suffix selects features of the destination node (the central point). This is just syntactic sugar — PyG does the indexing for you.
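Why relative position, not absolute? Using pos_j - pos_i instead of absolute coordinates makes the positional part of each message translation-invariant. It responds to the shape of the local neighbourhood, not to where that neighbourhood sits in global space, which helps generalisation. One caveat: the PointNet model below feeds the raw coordinates in as the initial features (h = pos), so the network as a whole still sees absolute position. Full translation invariance would require input features that do not depend on it.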
Why max aggregation? Max is permutation-invariant (the order of neighbours doesn't matter) and keeps the strongest incoming signal in each feature channel. The original PointNet paper found max pooling to work better than average pooling. A common intuition is that the most distinctive geometric feature often comes from a single key neighbour rather than from the average.
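Both properties can be checked numerically. The sketch below re-implements the layer's computation as a plain function (`pointnet_layer` is a hypothetical name, not a PyG API) and uses a hand-written toy graph. It compares outputs after translating the whole cloud and after reordering the edges.

```python
import torch
from torch.nn import Sequential, Linear, ReLU

torch.manual_seed(0)

def pointnet_layer(h, pos, edge_index, mlp):
    """Plain-PyTorch PointNet++ layer: max over MLP([h_j, p_j - p_i])."""
    src, dst = edge_index
    msg = mlp(torch.cat([h[src], pos[src] - pos[dst]], dim=-1))
    out = msg.new_zeros(pos.size(0), msg.size(1))
    return out.scatter_reduce(0, dst.unsqueeze(-1).expand_as(msg), msg,
                              reduce='amax', include_self=False)

pos = torch.randn(4, 3)
edge_index = torch.tensor([[1, 2, 0, 3, 0, 1],
                           [0, 0, 1, 1, 2, 3]])
mlp = Sequential(Linear(3 + 3, 8), ReLU(), Linear(8, 8))
shift = torch.tensor([1.0, -2.0, 0.5])    # global translation of the whole cloud
perm = torch.tensor([5, 3, 1, 0, 4, 2])   # reordering of the edges (neighbour order)

def close(a, b):
    # Small tolerance for float rounding introduced by the shift
    return torch.allclose(a, b, atol=1e-4)

with torch.no_grad():
    # Position-independent input features: translation changes nothing
    h_const = torch.ones(4, 3)
    same_const = close(pointnet_layer(h_const, pos, edge_index, mlp),
                       pointnet_layer(h_const, pos + shift, edge_index, mlp))
    # Raw coordinates as input features (h = pos): translation leaks in
    same_pos = close(pointnet_layer(pos, pos, edge_index, mlp),
                     pointnet_layer(pos + shift, pos + shift, edge_index, mlp))
    # Max aggregation ignores the order in which messages arrive
    same_perm = close(pointnet_layer(pos, pos, edge_index, mlp),
                      pointnet_layer(pos, pos, edge_index[:, perm], mlp))

print(same_const, same_pos, same_perm)  # True False True
```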
Building the Full Network
Now we assemble the PointNetLayer into a full classifier. PyG handles mini-batching by concatenating all points from all examples in the batch into a single tensor, using a batch vector to track which points belong to which example.
```python
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.nn import global_max_pool

# Uses PointNetLayer from the previous section
class PointNet(torch.nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        # Two PointNet++ layers: 3D coords → 32 features → 64 features
        self.conv1 = PointNetLayer(in_channels=3, out_channels=32)
        self.conv2 = PointNetLayer(in_channels=32, out_channels=64)
        # Final classifier
        self.classifier = torch.nn.Linear(64, num_classes)

    def forward(self, pos: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        # Initial node features = raw 3D coordinates
        h = pos
        # Layer 1: each point aggregates from its k-NN neighbours
        h = self.conv1(h, pos, edge_index)
        h = F.relu(h)
        # Layer 2: aggregate again with updated features
        h = self.conv2(h, pos, edge_index)
        h = F.relu(h)
        # Global readout: reduce all N points → one vector per graph
        # global_max_pool uses the `batch` vector to know which points belong together
        graph_embedding = global_max_pool(h, batch)  # shape: [B, 64]
        # Classify
        return self.classifier(graph_embedding)
```
Understanding global_max_pool and the batch vector
This is one of PyG’s most important design choices, and it’s worth understanding clearly.
When you load a mini-batch of 10 point clouds, each with 256 points, PyG doesn’t create a list of 10 separate tensors. Instead, it concatenates everything into a single tensor of shape [2560, 3] — all 2560 points together — and creates a batch vector:
```
batch = [0, 0, ..., 0,  1, 1, ..., 1,  ...,  9, 9, ..., 9]
         (256 zeros)    (256 ones)           (256 nines)
```
This batch vector tells global_max_pool which points belong to which example, so it can reduce each group of 256 points into a single 64-dimensional embedding. The output shape is [10, 64] — one vector per shape in the batch.
This approach is computationally efficient because all operations (message passing, MLP forward passes) run on contiguous tensors, not Python lists of variable-size arrays.
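The batching and pooling steps can be reproduced in plain PyTorch to show what global_max_pool computes from the batch vector. This is a sketch, not PyG's implementation. It deliberately uses clouds of different sizes, because the batch vector handles unequal sizes in exactly the same way.

```python
import torch

# Three toy point clouds of different sizes, each point with 4 features
clouds = [torch.randn(5, 4), torch.randn(3, 4), torch.randn(7, 4)]

# Concatenate all points and record which cloud each point came from
h = torch.cat(clouds, dim=0)                                   # [15, 4]
batch = torch.cat([torch.full((c.size(0),), i) for i, c in enumerate(clouds)])
# batch = [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]

# Global max pool: per-graph, per-feature maximum over that graph's points
num_graphs = int(batch.max()) + 1
pooled = h.new_full((num_graphs, h.size(1)), float('-inf')).scatter_reduce(
    0, batch.unsqueeze(-1).expand_as(h), h, reduce='amax')
print(pooled.shape)  # torch.Size([3, 4])
```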
Using the Built-in PointNetConv
If you don’t want to implement the layer from scratch, PyG ships a ready-to-use version:
```python
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
from torch_geometric.nn import PointNetConv

local_nn = Sequential(
    Linear(3 + 3, 64),  # 3 (neighbour features) + 3 (relative pos)
    ReLU(),
    Linear(64, 64),
    BatchNorm1d(64),
)
conv = PointNetConv(local_nn=local_nn, add_self_loops=True)
# Called as conv(x, pos, edge_index); here x would be the 3D coordinates themselves
```
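PointNetConv implements the same operation, with extra options such as a global_nn applied after aggregation. It also adds self-loops by default (add_self_loops=True). Each point therefore sends a message to itself as well, built from its own features and a zero relative position, which our from-scratch layer does not do.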
Putting It All Together: Dataset, Loader, and Training
```python
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import GeometricShapes
from torch_geometric.loader import DataLoader

# Assumes PointNetLayer and PointNet from the previous sections are defined

# --- Data ---
# The transform runs on every access, so each epoch sees freshly sampled points
transform = T.Compose([T.SamplePoints(num=256), T.KNNGraph(k=6)])
train_dataset = GeometricShapes(root='data/GeometricShapes', train=True, transform=transform)
test_dataset = GeometricShapes(root='data/GeometricShapes', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10)

# --- Model, optimizer, loss ---
model = PointNet(num_classes=train_dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

# --- Training loop ---
def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        optimizer.zero_grad()
        logits = model(data.pos, data.edge_index, data.batch)
        loss = criterion(logits, data.y)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
    return total_loss / len(train_loader.dataset)

@torch.no_grad()
def test(loader):
    model.eval()
    total_correct = 0
    for data in loader:
        logits = model(data.pos, data.edge_index, data.batch)
        pred = logits.argmax(dim=-1)
        total_correct += int((pred == data.y).sum())
    return total_correct / len(loader.dataset)

for epoch in range(1, 51):
    loss = train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.2%} | Test: {test_acc:.2%}')
```
The training loop is standard PyTorch — the only PyG-specific things are the DataLoader (which handles batching of graphs) and the data.pos, data.edge_index, data.batch attributes you pass to the model.
Going Further: Farthest Point Sampling
The tutorial also introduces Farthest Point Sampling (FPS) — a downsampling strategy used in PointNet++ to progressively reduce the number of points across layers, similar to how a CNN reduces spatial resolution via pooling.
FPS iteratively selects the point whose distance to its nearest already-selected point is largest. This spreads the samples out and gives good coverage of the shape's surface:
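The greedy loop behind FPS is short enough to write out in plain PyTorch. This is a sketch for illustration, not torch_cluster's implementation, and `farthest_point_sampling` is a hypothetical name. It always starts from index `start`, whereas torch_cluster's fps starts from a random point by default.

```python
import torch

def farthest_point_sampling(pos: torch.Tensor, num_samples: int, start: int = 0) -> torch.Tensor:
    """Hypothetical helper: greedy FPS returning `num_samples` indices into pos ([N, 3])."""
    n = pos.size(0)
    selected = torch.empty(num_samples, dtype=torch.long)
    # Squared distance from every point to its nearest already-selected point
    min_dist = torch.full((n,), float('inf'))
    current = start
    for s in range(num_samples):
        selected[s] = current
        # Update each point's distance to the selected set using the newest pick
        d = (pos - pos[current]).pow(2).sum(dim=1)
        min_dist = torch.minimum(min_dist, d)
        # Next pick: the point farthest from everything selected so far
        current = int(min_dist.argmax())
    return selected

pos = torch.randn(256, 3)
idx = farthest_point_sampling(pos, num_samples=64)   # 25% of the points
downsampled_pos = pos[idx]
print(downsampled_pos.shape)  # torch.Size([64, 3])
```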
```python
from torch_cluster import fps
# Sample 25% of the points, keeping those most spread out
sampled_indices = fps(data.pos, ratio=0.25)
downsampled_pos = data.pos[sampled_indices]
```
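After downsampling, you'd group each sampled point with its neighbours in the previous, denser point set and run another PointNet++ layer on that graph. PointNet++ builds these neighbourhoods with a ball query; k-NN is a common substitute. This hierarchical approach (aggregate, downsample, repeat) is what gives PointNet++ its ability to capture multi-scale geometric structure.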
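The grouping step after downsampling produces a bipartite graph from the full point set to the sampled centroids. Here is a plain-PyTorch sketch of that step, under a few assumptions. It uses k-NN instead of the paper's ball query. The helper name `group_around_centroids` is hypothetical. The centroid indices come from a random stand-in rather than real FPS.

```python
import torch

def group_around_centroids(pos, centroid_idx, k):
    """Hypothetical helper: connect each centroid to its k nearest points in the full cloud.

    Returns a bipartite edge_index [2, M * k]: row 0 indexes the full point set
    (sources), row 1 indexes the M centroids (destinations). Each centroid's own
    point is among its neighbours (distance 0), acting like a self-loop.
    """
    centroids = pos[centroid_idx]                          # [M, 3]
    dist = torch.cdist(centroids, pos)                     # [M, N]
    nbr = dist.topk(k, dim=1, largest=False).indices       # [M, k]
    src = nbr.reshape(-1)
    dst = torch.arange(centroids.size(0)).repeat_interleave(k)
    return torch.stack([src, dst], dim=0)

pos = torch.randn(256, 3)
centroid_idx = torch.randperm(256)[:64]   # stand-in for FPS indices
edge_index = group_around_centroids(pos, centroid_idx, k=16)
# Relative positions p_j - p_i per edge, as used in the PointNet++ message
rel_pos = pos[edge_index[0]] - pos[centroid_idx][edge_index[1]]
print(edge_index.shape, rel_pos.shape)  # torch.Size([2, 1024]) torch.Size([1024, 3])
```

Each edge then feeds a message into a PointNet++ layer whose outputs live on the 64 centroids. The next level repeats the process on that smaller set.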
Key Concepts to Remember
From mesh to point cloud: SamplePoints does this in one line. The stochastic sampling acts as free data augmentation.
From point cloud to graph: KNNGraph(k=6) connects each point to its 6 nearest neighbours. The graph is re-constructed dynamically every time you access the data.
Message passing in PyG: Subclass MessagePassing, define message(), and use _j/_i suffixes to access neighbour/central node attributes. PyG handles all the scatter-gather indexing under the hood.
Relative position is the key: The term pos_j - pos_i in the message function is what gives the network its geometric awareness. Without it, the messages would carry no information about where each neighbour sits relative to the central point.
Batching via the batch vector: PyG concatenates all graphs in a batch into one large disconnected graph. global_max_pool(h, batch) then recovers one embedding per graph using the batch assignment vector.
Comparison with DGL
If you read the previous post on DGL, you’ll notice both libraries share the message-passing paradigm. The key practical differences for point cloud work:
| | PyG | DGL |
|---|---|---|
| Graph construction | KNNGraph, RadiusGraph transforms | dgl.knn_graph function |
| Point cloud support | First-class (SamplePoints, PointNetConv) | General-purpose |
| Message passing | MessagePassing base class, _j/_i suffix convention | update_all + fn.* built-ins |
| Batching | batch vector, global_*_pool | dgl.batch, dgl.mean_nodes |
| Built-in PointNet | PointNetConv | Not built-in |
For point cloud tasks specifically, PyG’s transform pipeline and built-in PointNet support give it an edge. For heterogeneous molecular graphs, DGL’s heterograph API is more expressive.
Next post: building atom-level node features from protein–ligand complexes with RDKit and MDAnalysis, and connecting them to a PyG or DGL training pipeline.