Getting Started with Graph Neural Networks for Protein–Ligand Complexes Using DGL

Apr 5, 2026 · 8 min read

So you have a set of docked protein–ligand complexes and you want to train a machine learning model on them. Traditional approaches would have you reach for fingerprints or descriptor vectors: flat, fixed-size arrays that compress the 3D interaction down to a single vector per compound. They work, but they discard a lot of structural information.

Graph Neural Networks (GNNs) offer a richer alternative: represent the molecule as a graph, where atoms are nodes and bonds (or non-covalent contacts) are edges, then let the network learn directly from that topology. The Deep Graph Library (DGL) makes this surprisingly approachable in PyTorch.

This post walks you through the core ideas from the DGL User Guide, translated into the language of computational drug discovery. No prior GNN experience required.


Why Graphs for Molecules?

A molecule has a natural graph structure. Atoms are nodes, bonds are edges. When you extend this to a protein–ligand complex, you can add edges for non-covalent contacts (hydrogen bonds, hydrophobic interactions, π-stacking) connecting ligand atoms to binding-pocket residues.

The key advantage: a GNN can aggregate information from a node’s neighborhood iteratively, so after a few layers, each atom’s representation encodes not just its own identity but the chemical context around it. This is exactly what a human medicinal chemist does mentally when reading a 2D structure.
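The "neighborhood after a few layers" idea can be made concrete with a toy sketch: after k message-passing layers, a node's representation depends on its k-hop neighborhood. Here is a minimal hop-expansion on a hypothetical 4-atom chain (the adjacency dict is illustrative, not any particular molecule):

```python
# After k message-passing layers, a node "sees" its k-hop neighborhood.
# Toy 4-atom chain 0-1-2-3, stored as an adjacency dict.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

def k_hop(node, k):
    """Return the set of nodes reachable from `node` within k hops."""
    seen, frontier = {node}, {node}
    for _ in range(k):
        frontier = {v for u in frontier for v in adj[u]} - seen
        seen |= frontier
    return seen
```

With one layer, atom 0 only sees its direct neighbor; with two layers it already encodes the atom two bonds away, which is why stacking a few GNN layers captures local chemical context.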


The DGL Core: DGLGraph

Everything in DGL revolves around the DGLGraph object. Think of it as a PyTorch-native container for graph-structured data.

import dgl
import torch

# Define edges as (source_nodes, destination_nodes)
src = torch.tensor([0, 0, 1, 2])
dst = torch.tensor([1, 2, 2, 3])

g = dgl.graph((src, dst))
print(g)
# Graph(num_nodes=4, num_edges=4, ndata_schemes={}, edata_schemes={})

For molecular graphs, edges are usually undirected (bond A→B is the same as B→A), so convert with dgl.to_bidirected(). Do this before attaching node features, since by default it does not copy them (or pass copy_ndata=True):

g = dgl.to_bidirected(g)

Attaching Features

Node and edge features are stored as PyTorch tensors directly on the graph:

# Node features: e.g., one-hot atomic number, partial charge, aromaticity...
g.ndata['feat'] = torch.randn(g.num_nodes(), 32)

# Edge features: e.g., bond type, interatomic distance...
g.edata['dist'] = torch.randn(g.num_edges(), 1)

This is the bridge between your cheminformatics featurization pipeline (RDKit, MDAnalysis, etc.) and the GNN. Whatever features you compute, just stick them in ndata or edata.
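As a sketch of what that featurization step looks like, here is a minimal one-hot atom encoder in plain Python (the atom-type vocabulary and the `one_hot_atom` helper are illustrative, not part of any library; in practice you would derive the symbols from RDKit and wrap the result in `torch.tensor` before assigning to `g.ndata['feat']`):

```python
# Hypothetical one-hot featurization sketch: map an atom symbol to a
# fixed-length vector, with a catch-all "other" slot for rare elements.
ATOM_TYPES = ['C', 'N', 'O', 'S', 'other']

def one_hot_atom(symbol):
    vec = [0.0] * len(ATOM_TYPES)
    idx = ATOM_TYPES.index(symbol) if symbol in ATOM_TYPES else len(ATOM_TYPES) - 1
    vec[idx] = 1.0
    return vec

# One row per atom; torch.tensor(feats) would become g.ndata['feat'].
feats = [one_hot_atom(s) for s in ['C', 'O', 'P']]
```

Real pipelines concatenate several such blocks (hybridization, formal charge, aromaticity), but the pattern is the same: one row per node, assigned to ndata in node order.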


Heterogeneous Graphs: Protein + Ligand Together

If you want to model protein–ligand contacts explicitly (not just the ligand in isolation), DGL supports heterogeneous graphs: graphs with multiple node and edge types.

g = dgl.heterograph({
    ('ligand_atom', 'bond', 'ligand_atom'): (lig_src, lig_dst),
    ('ligand_atom', 'contacts', 'protein_residue'): (lig_ids, res_ids),
    ('protein_residue', 'contacts', 'ligand_atom'): (res_ids, lig_ids),
})

g.nodes['ligand_atom'].data['feat'] = ligand_atom_features
g.nodes['protein_residue'].data['feat'] = residue_features

This is a more faithful representation of the complex: the network can learn that certain binding-pocket residues modulate the ligand’s representation depending on how they interact.


Message Passing: How GNNs Actually Work

This is the heart of DGL and the heart of every GNN. Message passing works in two steps per layer:

Step 1: Edge-wise (messages): For each edge, compute a “message” from the source node’s features (and optionally the edge’s own features).

Step 2: Node-wise (aggregation + update): Each node collects all incoming messages, aggregates them (sum, mean, or max), and updates its own features.

Mathematically, for node v at layer t+1:

$$h_v^{(t+1)} = \psi\!\left(h_v^{(t)},\ \rho\!\left(\left\{ \phi(h_u^{(t)}, h_v^{(t)}, w_{uv}) : u \in \mathcal{N}(v) \right\}\right)\right)$$

where φ is the message function, ρ is the aggregation (e.g., mean), and ψ is the update function.
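To make the update concrete, here is one mean-aggregation step written out in plain Python, using scalar node features and the simplest choices: φ copies the source feature, ρ is the mean, and ψ just replaces the node's feature with the aggregate (nodes with no incoming edges keep their old value):

```python
# One mean-aggregation message-passing step, spelled out by hand.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]   # (src, dst) pairs
h = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}       # scalar node features

inbox = {v: [] for v in h}
for u, v in edges:
    inbox[v].append(h[u])                   # message: copy source feature

# Aggregate: mean of incoming messages; keep old value if no messages.
h_new = {v: (sum(m) / len(m) if m else h[v]) for v, m in inbox.items()}
# node 2 receives messages from nodes 0 and 1 -> (1.0 + 2.0) / 2 = 1.5
```

This is exactly what the DGL one-liner below computes, vectorized over all edges at once.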

In DGL code, the fastest way uses built-in functions:

import dgl.function as fn

# Message: copy source node feature → edge mailbox
# Reduce: average all incoming messages → update node feature
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))

This runs as fused, optimized kernels (CUDA on GPU), much faster than writing the equivalent Python loop.

For custom behavior (e.g., distance-weighted messages):

def weighted_message(edges):
    # edges.src['h']: source node features
    # edges.data['dist']: edge distance
    return {'m': edges.src['h'] * torch.exp(-edges.data['dist'])}

g.update_all(weighted_message, fn.sum('m', 'h'))
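The exp(−dist) factor in that message function is a smooth distance decay. A quick numeric look (the distance scale here is illustrative, in arbitrary units) shows why it is a reasonable weighting for contacts:

```python
import math

# How exp(-d) downweights messages as contact distance grows.
weights = {d: math.exp(-d) for d in (1.0, 2.0, 3.0)}
# a contact at distance 1 passes ~37% of the source feature, at 3 only ~5%
```

In practice you would tune the length scale (e.g., exp(-d / sigma)) to match the physical distances in your contact graph.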

Built-in GNN Layers

DGL ships with ready-to-use layers in dgl.nn.pytorch. The most useful for molecular tasks:

- GraphConv: simplest baseline; no edge features
- GATConv: learns per-edge attention weights; great for contacts with varying importance
- NNConv: edge-feature-conditioned messages; use when bond type or distance matters
- HeteroGraphConv: wraps per-relation convolutions for heterogeneous graphs

A simple two-layer GAT:

import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F

class MolGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.gat1 = dglnn.GATConv(in_dim, hidden_dim, num_heads=4)
        self.gat2 = dglnn.GATConv(hidden_dim * 4, hidden_dim, num_heads=1)

    def forward(self, g, feat):
        h = F.elu(self.gat1(g, feat).flatten(1))  # (N, hidden*4)
        h = self.gat2(g, h).squeeze(1)             # (N, hidden)
        return h

Graph-Level Prediction: Readout

After message passing, each node has a learned embedding. For a graph-level prediction (e.g., is this compound active?), you need to compress a variable number of node embeddings into a single fixed-size vector. This is called readout.

DGL provides built-in readout functions:

# Mean of all node features → one vector per graph
graph_embedding = dgl.mean_nodes(g, 'h')

# Also available: dgl.sum_nodes, dgl.max_nodes

For a batch of graphs (many compounds at once):

bg = dgl.batch([g1, g2, g3, ...])  # one disconnected super-graph
# ... run message passing on bg ...
embeddings = dgl.mean_nodes(bg, 'h')  # shape (batch_size, hidden_dim)

dgl.batch() is what makes mini-batch training efficient: DGL treats the batch as a single large graph (with no edges between the individual molecules), runs all convolutions in one GPU call, then recovers per-graph embeddings with readout.
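What per-graph readout on a batched graph boils down to is a segment-wise mean over the flattened node features. A plain-Python sketch with scalar node features (the numbers are made up for illustration):

```python
# mean_nodes on a batch, by hand: node features for two graphs are
# flattened into one list, and per-graph node counts mark the segments.
node_feats = [1.0, 3.0,  2.0, 4.0, 6.0]   # graph A: 2 nodes, graph B: 3 nodes
batch_num_nodes = [2, 3]

embeddings, offset = [], 0
for n in batch_num_nodes:
    seg = node_feats[offset:offset + n]
    embeddings.append(sum(seg) / n)        # per-graph mean readout
    offset += n
# graph A -> mean(1, 3) = 2.0, graph B -> mean(2, 4, 6) = 4.0
```

DGL does the same thing with vector features and a single segment-reduction kernel, using the node counts it recorded when building the batch.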


The Full Graph Classification Pipeline

Putting it all together: a classifier that takes a molecule graph and outputs a binary label:

import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes=2):
        super().__init__()
        self.conv1 = dglnn.GATConv(in_dim, hidden_dim, num_heads=4)
        self.conv2 = dglnn.GATConv(hidden_dim * 4, hidden_dim, num_heads=1)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, feat):
        # Message passing
        h = F.elu(self.conv1(g, feat).flatten(1))
        h = self.conv2(g, h).squeeze(1)

        # Store updated features
        with g.local_scope():
            g.ndata['h'] = h
            # Readout: mean over all atoms → graph embedding
            graph_feat = dgl.mean_nodes(g, 'h')

        return self.classify(graph_feat)


# Training loop sketch
model = GraphClassifier(in_dim=32, hidden_dim=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(100):
    for batched_graphs, labels in dataloader:
        feats = batched_graphs.ndata['feat']
        logits = model(batched_graphs, feats)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Building a Dataset

DGL has a dataset API that handles caching processed graphs to disk, which is important when building graphs from PDB/SDF files is slow.

import os

class MolDataset(dgl.data.DGLDataset):
    def __init__(self, raw_dir):
        super().__init__(name='mol_dataset', raw_dir=raw_dir)

    def process(self):
        self.graphs, self.labels = [], []
        for path in self._get_files():           # your file discovery here
            g, label = self._build_graph(path)   # your featurization here
            self.graphs.append(g)
            self.labels.append(label)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)

    def has_cache(self):
        return os.path.exists(self.save_path)

    def save(self):
        dgl.save_graphs(self.save_path, self.graphs,
                        {'labels': torch.tensor(self.labels)})

    def load(self):
        self.graphs, label_dict = dgl.load_graphs(self.save_path)
        self.labels = label_dict['labels'].tolist()

Pair this with a DataLoader using dgl.dataloading.GraphDataLoader (or PyTorch’s standard DataLoader with a custom collate_fn that calls dgl.batch()).
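The core of such a custom collate_fn is just un-zipping the (graph, label) pairs the dataset yields; dgl.batch and torch.tensor are then applied to the two resulting lists. Sketched here with string stand-ins for the graphs so the un-zipping step is visible on its own:

```python
# A collate_fn receives a list of dataset items; the first step is to
# split it into a list of graphs and a list of labels.
samples = [('gA', 0), ('gB', 1), ('gC', 0)]   # stand-ins for (DGLGraph, label)
graphs, labels = map(list, zip(*samples))
# graphs == ['gA', 'gB', 'gC'], labels == [0, 1, 0]
# a real collate_fn would return dgl.batch(graphs), torch.tensor(labels)
```

GraphDataLoader does exactly this for you, which is why it is the less error-prone default.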


Multiple Instances per Compound: the Bag Approach

A common scenario in ensemble docking: you dock one compound against N different receptor conformations and get N poses, each a separate graph. You want to make one prediction per compound, not one per pose.

This is a natural fit for Multiple Instance Learning (MIL). The trick with DGL:

def predict_bag(model, bag_of_graphs):
    # Stack all N pose-graphs into one batched graph
    bg = dgl.batch(bag_of_graphs)
    feats = bg.ndata['feat']

    # Run GNN — poses stay disconnected (no cross-pose edges)
    h = model.encode(bg, feats)          # node embeddings

    with bg.local_scope():
        bg.ndata['h'] = h
        pose_embeds = dgl.mean_nodes(bg, 'h')   # (N, h_dim) — one per pose

    # MIL aggregation over the bag
    bag_embed = pose_embeds.mean(0, keepdim=True)   # mean pooling
    return model.classify(bag_embed)

The key insight: dgl.batch() keeps the poses structurally independent (no edges cross between them), so the GNN respects the fact that each pose is a separate physical structure. The MIL aggregation then decides how to combine them for the final prediction.

For a more expressive aggregation, use attention MIL:

class AttentionMIL(nn.Module):
    def __init__(self, h_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(h_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, pose_embeds):           # (N, h_dim)
        scores = self.attention(pose_embeds)  # (N, 1)
        weights = torch.softmax(scores, dim=0)
        return (weights * pose_embeds).sum(0) # (h_dim,)

Attention MIL has a nice bonus: the learned weights tell you which poses contributed most to the prediction, so the model is interpretable by design.
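Numerically, the attention pooling is just a softmax over per-pose scores followed by a weighted sum. A hand-worked example with 1-D pose embeddings (scores and embeddings are made-up numbers for illustration):

```python
import math

# Attention pooling by hand: softmax over per-pose scores, weighted sum.
scores = [0.0, 1.0, 0.0]        # hypothetical attention scores for 3 poses
embeds = [1.0, 4.0, 2.0]        # 1-D pose embeddings for illustration

exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]       # sums to 1
bag = sum(w * h for w, h in zip(weights, embeds))
# the middle pose (highest score) dominates: weights ~ [0.21, 0.58, 0.21]
```

If all scores were equal, the weights would collapse to 1/N and the pooling would reduce to the plain mean from predict_bag above; the attention network learns when to deviate from that.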


Which Chapters to Read First

If you’re using this post as a reading guide alongside the DGL User Guide:

- ⭐⭐⭐ Ch 1.2–1.5 (graphs, features, heterographs): the core data structure
- ⭐⭐⭐ Ch 2.1–2.2 (message passing API): the computation engine
- ⭐⭐⭐ Ch 5.4 (graph classification): your prediction task
- ⭐⭐ Ch 3 (GNN modules): ready-to-use layers
- ⭐⭐ Ch 4 (data pipeline): dataset building and caching
- Ch 6–8 (stochastic / distributed / mixed precision): only if your dataset is huge

Summary

DGL wraps the complexity of graph-structured computation behind a clean PyTorch API. For molecular applications, the workflow is:

  1. Build a DGLGraph for each complex: atoms as nodes, bonds/contacts as edges, features from your cheminformatics toolkit
  2. Choose a GNN layer: GATConv is a strong default
  3. Run message passing: each atom aggregates information from its neighbors
  4. Apply readout (dgl.mean_nodes) to get a fixed-size graph embedding
  5. Feed into a standard MLP classifier

The library handles batching, GPU transfer, and efficient sparse operations under the hood. Once you get the featurization right, the training code looks almost identical to a standard PyTorch image classifier.