<?xml version="1.0" encoding="utf-8" standalone="yes" ?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>Deep Learning | Yassir Boulaamane</title>
    <link>https://yboulaamane.github.io/tags/deep-learning/</link>
      <atom:link href="https://yboulaamane.github.io/tags/deep-learning/index.xml" rel="self" type="application/rss+xml" />
    <description>Deep Learning</description>
    <generator>Hugo Blox Builder (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Sun, 05 Apr 2026 00:00:00 +0000</lastBuildDate>
    <image>
      <url>https://yboulaamane.github.io/media/icon_hu_4d696a8ace2a642b.png</url>
      <title>Deep Learning</title>
      <link>https://yboulaamane.github.io/tags/deep-learning/</link>
    </image>
    
    <item>
      <title>Getting Started with Graph Neural Networks for Protein–Ligand Complexes Using DGL</title>
      <link>https://yboulaamane.github.io/blog/getting-started-with-graph-neural-networks-for-proteinligand-complexes-using-dgl/</link>
      <pubDate>Sun, 05 Apr 2026 00:00:00 +0000</pubDate>
      <guid>https://yboulaamane.github.io/blog/getting-started-with-graph-neural-networks-for-proteinligand-complexes-using-dgl/</guid>
      <description>&lt;p&gt;So you have a set of docked protein–ligand complexes and you want to train a machine learning model on them. Traditional approaches would have you reach for fingerprints or descriptor vectors: flat, fixed-size arrays that compress the 3D interaction into a single number per compound. They work, but they discard a lot of structural information.&lt;/p&gt;
&lt;p&gt;Graph Neural Networks (GNNs) offer a richer alternative: represent the molecule &lt;em&gt;as a graph&lt;/em&gt;, where atoms are nodes and bonds (or non-covalent contacts) are edges, then let the network learn directly from that topology. The 
 (DGL) makes this surprisingly approachable in PyTorch.&lt;/p&gt;
&lt;p&gt;This post walks you through the core ideas from the DGL User Guide, translated into the language of computational drug discovery. No prior GNN experience required.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;why-graphs-for-molecules&#34;&gt;Why Graphs for Molecules?&lt;/h2&gt;
&lt;p&gt;A molecule has a natural graph structure. Atoms are nodes, bonds are edges. When you extend this to a protein–ligand complex, you can add edges for non-covalent contacts (hydrogen bonds, hydrophobic interactions, π-stacking) connecting ligand atoms to binding-pocket residues.&lt;/p&gt;
&lt;p&gt;The key advantage: a GNN can aggregate information from a node&amp;rsquo;s &lt;em&gt;neighborhood&lt;/em&gt; iteratively, so after a few layers, each atom&amp;rsquo;s representation encodes not just its own identity but the chemical context around it. This is exactly what a human medicinal chemist does mentally when reading a 2D structure.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;the-dgl-core-dglgraph&#34;&gt;The DGL Core: &lt;code&gt;DGLGraph&lt;/code&gt;&lt;/h2&gt;
&lt;p&gt;Everything in DGL revolves around the &lt;code&gt;DGLGraph&lt;/code&gt; object. Think of it as a PyTorch-native container for graph-structured data.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dgl&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Define edges as (source_nodes, destination_nodes)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;src&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;dst&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;src&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dst&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Graph(num_nodes=4, num_edges=4, ndata_schemes={}, edata_schemes={})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;For molecular graphs, edges are usually undirected (bond A→B is the same as B→A), so use &lt;code&gt;dgl.to_bidirected()&lt;/code&gt;:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;to_bidirected&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id=&#34;attaching-features&#34;&gt;Attaching Features&lt;/h3&gt;
&lt;p&gt;Node and edge features are stored as PyTorch tensors directly on the graph:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Node features: e.g., one-hot atomic number, partial charge, aromaticity...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ndata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;feat&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Edge features: e.g., bond type, interatomic distance...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;dist&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_edges&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This is the bridge between your cheminformatics featurization pipeline (RDKit, MDAnalysis, etc.) and the GNN. Whatever features you compute, just stick them in &lt;code&gt;ndata&lt;/code&gt; or &lt;code&gt;edata&lt;/code&gt;.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;heterogeneous-graphs-protein--ligand-together&#34;&gt;Heterogeneous Graphs: Protein + Ligand Together&lt;/h2&gt;
&lt;p&gt;If you want to model protein–ligand contacts explicitly (not just the ligand in isolation), DGL supports &lt;strong&gt;heterogeneous graphs&lt;/strong&gt;: graphs with multiple node and edge types.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;heterograph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;({&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;ligand_atom&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;bond&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;ligand_atom&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;lig_src&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;lig_dst&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;ligand_atom&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;contacts&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;protein_residue&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;lig_ids&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;res_ids&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;protein_residue&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;contacts&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;ligand_atom&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;res_ids&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;lig_ids&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;p&#34;&gt;})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;ligand_atom&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;feat&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ligand_atom_features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;protein_residue&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;feat&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;residue_features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This is a more faithful representation of the complex: the network can learn that certain binding-pocket residues modulate the ligand&amp;rsquo;s representation depending on how they interact.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;message-passing-how-gnns-actually-work&#34;&gt;Message Passing: How GNNs Actually Work&lt;/h2&gt;
&lt;p&gt;This is the heart of DGL and the heart of every GNN. Message passing works in two steps per layer:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Step 1: Edge-wise (messages):&lt;/strong&gt; For each edge, compute a &amp;ldquo;message&amp;rdquo; from the source node&amp;rsquo;s features (and optionally the edge&amp;rsquo;s own features).&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Step 2: Node-wise (aggregation + update):&lt;/strong&gt; Each node collects all incoming messages, aggregates them (sum, mean, or max), and updates its own features.&lt;/p&gt;
&lt;p&gt;Mathematically, for node &lt;em&gt;v&lt;/em&gt; at layer &lt;em&gt;t+1&lt;/em&gt;:&lt;/p&gt;
$$h_v^{(t+1)} = \psi\!\left(h_v^{(t)},\ \rho\!\left(\left\{ \phi(h_u^{(t)}, h_v^{(t)}, w_{uv}) : u \in \mathcal{N}(v) \right\}\right)\right)$$&lt;p&gt;where φ is the message function, ρ is the aggregation (e.g., mean), and ψ is the update function.&lt;/p&gt;
&lt;p&gt;In DGL code, the fastest way uses &lt;strong&gt;built-in functions&lt;/strong&gt;:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dgl.function&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;fn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Message: copy source node feature → edge mailbox&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Reduce: average all incoming messages → update node feature&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;update_all&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;fn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;copy_u&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;m&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;m&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This runs as fused CUDA kernels, much faster than writing Python loops.&lt;/p&gt;
&lt;p&gt;For custom behavior (e.g., distance-weighted messages):&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;weighted_message&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edges&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# edges.src[&amp;#39;h&amp;#39;]: source node features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# edges.data[&amp;#39;dist&amp;#39;]: edge distance&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;m&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edges&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;src&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exp&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;-&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edges&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;dist&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])}&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;update_all&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;weighted_message&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;m&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h2 id=&#34;built-in-gnn-layers&#34;&gt;Built-in GNN Layers&lt;/h2&gt;
&lt;p&gt;DGL ships with ready-to-use layers in &lt;code&gt;dgl.nn.pytorch&lt;/code&gt;. The most useful for molecular tasks:&lt;/p&gt;
&lt;table&gt;
  &lt;thead&gt;
      &lt;tr&gt;
          &lt;th&gt;Layer&lt;/th&gt;
          &lt;th&gt;When to use&lt;/th&gt;
      &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;GraphConv&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Simplest baseline, no edge features&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;GATConv&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Learns per-edge attention weights; great for contacts with varying importance&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;NNConv&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Edge-feature-conditioned messages; use when bond type or distance matters&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;HeteroGraphConv&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Wraps per-relation convolutions for heterogeneous graphs&lt;/td&gt;
      &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;A simple two-layer GAT:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dgl.nn.pytorch&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dglnn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;F&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;MolGNN&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Module&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;in_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gat1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dglnn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GATConv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_heads&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gat2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dglnn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GATConv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_heads&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;F&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;elu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gat1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;  &lt;span class=&#34;c1&#34;&gt;# (N, hidden*4)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;gat2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;             &lt;span class=&#34;c1&#34;&gt;# (N, hidden)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h2 id=&#34;graph-level-prediction-readout&#34;&gt;Graph-Level Prediction: Readout&lt;/h2&gt;
&lt;p&gt;After message passing, each node has a learned embedding. For a graph-level prediction (e.g., is this compound active?), you need to compress the variable-number of node embeddings into a single fixed-size vector. This is called &lt;strong&gt;readout&lt;/strong&gt;.&lt;/p&gt;
&lt;p&gt;DGL provides built-in readout functions:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Mean of all node features → one vector per graph&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;graph_embedding&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean_nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Also available: dgl.sum_nodes, dgl.max_nodes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;For a batch of graphs (many compounds at once):&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;g2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;g3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;...&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;  &lt;span class=&#34;c1&#34;&gt;# one disconnected super-graph&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# ... run message passing on bg ...&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;embeddings&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean_nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;  &lt;span class=&#34;c1&#34;&gt;# shape (batch_size, hidden_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&lt;code&gt;dgl.batch()&lt;/code&gt; is what makes mini-batch training efficient: DGL treats the batch as a single large graph (with no edges between the individual molecules), runs all convolutions in one GPU call, then recovers per-graph embeddings with readout.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;the-full-graph-classification-pipeline&#34;&gt;The Full Graph Classification Pipeline&lt;/h2&gt;
&lt;p&gt;Putting it all together: a classifier that takes a molecule graph and outputs a binary label:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dgl&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dgl.nn.pytorch&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;dglnn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;nn&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;F&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;GraphClassifier&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Module&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;in_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_classes&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dglnn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GATConv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_heads&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dglnn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;GATConv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;4&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_heads&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;classify&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;n_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Message passing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;F&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;elu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;flatten&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;squeeze&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Store updated features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;local_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ndata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;c1&#34;&gt;# Readout: mean over all atoms → graph embedding&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;graph_feat&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean_nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;classify&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graph_feat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Training loop sketch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GraphClassifier&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_dim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;hidden_dim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Adam&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;lr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;1e-3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;loss_fn&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;100&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batched_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dataloader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;feats&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batched_graphs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ndata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;feat&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batched_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feats&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loss_fn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zero_grad&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;backward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;hr&gt;
&lt;h2 id=&#34;building-a-dataset&#34;&gt;Building a Dataset&lt;/h2&gt;
&lt;p&gt;DGL has a dataset API that handles caching processed graphs to disk, important when building graphs from PDB/SDF files is slow.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;MolDataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;DGLDataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;raw_dir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;name&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;mol_dataset&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;raw_dir&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;raw_dir&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;process&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[],&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;[]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;path&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;_get_files&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;_build_graph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;   &lt;span class=&#34;c1&#34;&gt;# your featurization here&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;g&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;append&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;label&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__getitem__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__len__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;has_cache&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;os&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;path&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;exists&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;save&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;                        &lt;span class=&#34;p&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;labels&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)})&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;load&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label_dict&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;load_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;save_path&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;labels&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;label_dict&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;labels&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;tolist&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Pair this with a DataLoader using &lt;code&gt;dgl.dataloading.GraphDataLoader&lt;/code&gt; (or PyTorch&amp;rsquo;s standard DataLoader with a custom &lt;code&gt;collate_fn&lt;/code&gt; that calls &lt;code&gt;dgl.batch()&lt;/code&gt;).&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;multiple-instances-per-compound-the-bag-approach&#34;&gt;Multiple Instances per Compound: the Bag Approach&lt;/h2&gt;
&lt;p&gt;A common scenario in ensemble docking: you dock one compound against &lt;em&gt;N&lt;/em&gt; different receptor conformations and get &lt;em&gt;N&lt;/em&gt; poses, each a separate graph. You want to make one prediction per compound, not one per pose.&lt;/p&gt;
&lt;p&gt;This is a natural fit for &lt;strong&gt;Multiple Instance Learning (MIL)&lt;/strong&gt;. The trick with DGL:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;predict_bag&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bag_of_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Stack all N pose-graphs into one batched graph&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bag_of_graphs&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;feats&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ndata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;feat&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# Run GNN — poses stay disconnected (no cross-pose edges)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;encode&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;feats&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;          &lt;span class=&#34;c1&#34;&gt;# node embeddings&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;with&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;local_scope&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;ndata&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;pose_embeds&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dgl&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean_nodes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bg&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;s1&#34;&gt;&amp;#39;h&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;   &lt;span class=&#34;c1&#34;&gt;# (N, h_dim) — one per pose&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;c1&#34;&gt;# MIL aggregation over the bag&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;bag_embed&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pose_embeds&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mean&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;keepdim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;   &lt;span class=&#34;c1&#34;&gt;# mean pooling&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;classify&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;bag_embed&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The key insight: &lt;code&gt;dgl.batch()&lt;/code&gt; keeps the poses structurally independent (no edges cross between them), so the GNN respects the fact that each pose is a separate physical structure. The MIL aggregation then decides how to combine them for the final prediction.&lt;/p&gt;
&lt;p&gt;For a more expressive aggregation, use &lt;strong&gt;attention MIL&lt;/strong&gt;:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;AttentionMIL&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Module&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;attention&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h_dim&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Tanh&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pose_embeds&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;           &lt;span class=&#34;c1&#34;&gt;# (N, h_dim)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;attention&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pose_embeds&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;  &lt;span class=&#34;c1&#34;&gt;# (N, 1)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;weights&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;softmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;scores&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;weights&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pose_embeds&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;c1&#34;&gt;# (h_dim,)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Attention MIL has a nice bonus: &lt;code&gt;weights&lt;/code&gt; tells you which poses contributed most to the prediction, interpretable by design.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;which-chapters-to-read-first&#34;&gt;Which Chapters to Read First&lt;/h2&gt;
&lt;p&gt;If you&amp;rsquo;re using this post as a reading guide alongside the 
:&lt;/p&gt;
&lt;table&gt;
  &lt;thead&gt;
      &lt;tr&gt;
          &lt;th&gt;Priority&lt;/th&gt;
          &lt;th&gt;Chapter&lt;/th&gt;
          &lt;th&gt;Why&lt;/th&gt;
      &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐⭐⭐&lt;/td&gt;
          &lt;td&gt;Ch 1.2–1.5: Graphs, features, heterograph&lt;/td&gt;
          &lt;td&gt;Core data structure&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐⭐⭐&lt;/td&gt;
          &lt;td&gt;Ch 2.1–2.2: Message passing API&lt;/td&gt;
          &lt;td&gt;The computation engine&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐⭐⭐&lt;/td&gt;
          &lt;td&gt;Ch 5.4: Graph classification&lt;/td&gt;
          &lt;td&gt;Your prediction task&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐⭐&lt;/td&gt;
          &lt;td&gt;Ch 3: GNN modules&lt;/td&gt;
          &lt;td&gt;Ready-to-use layers&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐⭐&lt;/td&gt;
          &lt;td&gt;Ch 4: Data pipeline&lt;/td&gt;
          &lt;td&gt;Dataset + caching&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;⭐&lt;/td&gt;
          &lt;td&gt;Ch 6–8: Stochastic / distributed / mixed precision&lt;/td&gt;
          &lt;td&gt;Only if dataset is huge&lt;/td&gt;
      &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;
&lt;hr&gt;
&lt;h2 id=&#34;summary&#34;&gt;Summary&lt;/h2&gt;
&lt;p&gt;DGL wraps the complexity of graph-structured computation behind a clean PyTorch API. For molecular applications, the workflow is:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;Build a &lt;code&gt;DGLGraph&lt;/code&gt; for each complex: atoms as nodes, bonds/contacts as edges, features from your cheminformatics toolkit&lt;/li&gt;
&lt;li&gt;Choose a GNN layer: &lt;code&gt;GATConv&lt;/code&gt; is a strong default&lt;/li&gt;
&lt;li&gt;Run message passing: each atom aggregates information from its neighbors&lt;/li&gt;
&lt;li&gt;Apply readout (&lt;code&gt;dgl.mean_nodes&lt;/code&gt;) to get a fixed-size graph embedding&lt;/li&gt;
&lt;li&gt;Feed into a standard MLP classifier&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;The library handles batching, GPU transfer, and efficient sparse operations under the hood. Once you get the featurization right, the training code looks almost identical to a standard PyTorch image classifier.&lt;/p&gt;
&lt;hr&gt;
</description>
    </item>
    
    <item>
      <title>Point Cloud Classification with Graph Neural Networks Using PyTorch Geometric</title>
      <link>https://yboulaamane.github.io/blog/point-cloud-classification-with-graph-neural-networks-using-pytorch-geometric/</link>
      <pubDate>Sun, 05 Apr 2026 00:00:00 +0000</pubDate>
      <guid>https://yboulaamane.github.io/blog/point-cloud-classification-with-graph-neural-networks-using-pytorch-geometric/</guid>
      <description>&lt;p&gt;When we think about machine learning on 3D data, we usually think about images or voxel grids — discretised, regular structures that fit neatly into convolution pipelines. But a lot of 3D data in the real world doesn&amp;rsquo;t look like that. It looks like a &lt;em&gt;point cloud&lt;/em&gt;: a raw, unordered set of coordinates in space, sometimes accompanied by extra measurements like surface normals or colour.&lt;/p&gt;
&lt;p&gt;Protein structures, LiDAR scans, molecular conformations, cryo-EM densities — all of these are fundamentally point clouds. 
 (PyG) makes it straightforward to apply Graph Neural Networks directly to this type of data, without voxelising it or projecting it onto a grid.&lt;/p&gt;
&lt;p&gt;This post walks through PyG&amp;rsquo;s point cloud tutorial from first principles. No prior graph ML experience assumed.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;what-is-a-point-cloud&#34;&gt;What Is a Point Cloud?&lt;/h2&gt;
&lt;p&gt;A point cloud is simply a collection of points in 3D space:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-fallback&#34; data-lang=&#34;fallback&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;P = { (x₁, y₁, z₁), (x₂, y₂, z₂), ..., (xₙ, yₙ, zₙ) }
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;There is no inherent connectivity — no bonds, no edges, no grid. The points can be in any order, and there can be any number of them. This is what makes point clouds tricky for standard deep learning: CNNs expect a fixed-size, ordered grid, but a point cloud is neither.&lt;/p&gt;
&lt;p&gt;Graph Neural Networks sidestep this by &lt;em&gt;constructing&lt;/em&gt; a graph from the points — connecting nearby points with edges — and then running message passing on that graph to learn local geometric structure.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;the-pyg-data-object&#34;&gt;The PyG &lt;code&gt;Data&lt;/code&gt; Object&lt;/h2&gt;
&lt;p&gt;Everything in PyTorch Geometric revolves around the &lt;code&gt;Data&lt;/code&gt; object — a flexible container for graph-structured data. For point clouds specifically, the key attribute is &lt;code&gt;pos&lt;/code&gt;:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.data&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Data&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# A simple point cloud: 5 points in 3D&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;randn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;5&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;))&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Data(pos=[5, 3])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&lt;code&gt;data.pos&lt;/code&gt; holds the 3D coordinates of each point. Other common attributes you&amp;rsquo;ll encounter:&lt;/p&gt;
&lt;table&gt;
  &lt;thead&gt;
      &lt;tr&gt;
          &lt;th&gt;Attribute&lt;/th&gt;
          &lt;th&gt;Shape&lt;/th&gt;
          &lt;th&gt;Meaning&lt;/th&gt;
      &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;data.pos&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;[N, 3]&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;3D coordinates of N points&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;data.x&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;[N, F]&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Node feature vectors (optional)&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;data.edge_index&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;[2, E]&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Graph connectivity — pairs of connected point indices&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;data.y&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;[1]&lt;/code&gt; or &lt;code&gt;[N]&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Label(s) — graph-level or point-level&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;&lt;code&gt;data.batch&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;[N]&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Batch assignment vector (added automatically during batching)&lt;/td&gt;
      &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;The &lt;code&gt;Data&lt;/code&gt; object is intentionally minimal — you only include what you have.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;the-dataset-geometricshapes&#34;&gt;The Dataset: GeometricShapes&lt;/h2&gt;
&lt;p&gt;PyG includes &lt;code&gt;GeometricShapes&lt;/code&gt;, a synthetic dataset of 3D meshes representing simple geometric shapes (cubes, spheres, pyramids, and more). It&amp;rsquo;s a convenient sandbox for learning.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.datasets&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GeometricShapes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GeometricShapes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;root&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;data/GeometricShapes&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# GeometricShapes(40)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Data(pos=[32, 3], face=[3, 30], y=[1])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;Each object is stored as a &lt;strong&gt;mesh&lt;/strong&gt; — vertices (&lt;code&gt;pos&lt;/code&gt;) and their triangular connectivity (&lt;code&gt;face&lt;/code&gt;). But we want to work with point clouds, not meshes. This is where transforms come in.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;transforms-from-mesh-to-point-cloud-to-graph&#34;&gt;Transforms: From Mesh to Point Cloud to Graph&lt;/h2&gt;
&lt;p&gt;PyG has a rich transforms system (&lt;code&gt;torch_geometric.transforms&lt;/code&gt;). Transforms are functions that take a &lt;code&gt;Data&lt;/code&gt; object and return a modified one. You can compose multiple transforms together.&lt;/p&gt;
&lt;h3 id=&#34;step-1-sample-points-from-the-mesh-surface&#34;&gt;Step 1: Sample points from the mesh surface&lt;/h3&gt;
&lt;p&gt;&lt;code&gt;SamplePoints&lt;/code&gt; uniformly samples a fixed number of 3D points from the mesh&amp;rsquo;s surface, proportional to face area. It discards the mesh connectivity and replaces it with a point cloud.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.transforms&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;T&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SamplePoints&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Data(pos=[256, 3], y=[1])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We now have 256 points sampled from the surface of the shape. The &lt;code&gt;face&lt;/code&gt; attribute is gone — we only have coordinates.&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Note:&lt;/strong&gt; &lt;code&gt;SamplePoints&lt;/code&gt; is stochastic — you get a different set of points every time you access the same example. This acts as data augmentation during training for free.&lt;/p&gt;&lt;/blockquote&gt;
&lt;h3 id=&#34;step-2-build-a-graph-with-k-nearest-neighbours&#34;&gt;Step 2: Build a graph with k-nearest neighbours&lt;/h3&gt;
&lt;p&gt;A point cloud alone has no edges. To run a GNN, we need to connect the points. The standard approach is &lt;strong&gt;k-nearest neighbour (k-NN) graph construction&lt;/strong&gt;: for each point, connect it to its &lt;em&gt;k&lt;/em&gt; geometrically closest neighbours.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Compose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SamplePoints&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;KNNGraph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;6&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;p&#34;&gt;])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Data(pos=[256, 3], edge_index=[2, 1536], y=[1])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;We now have 1536 edges — exactly 6 per point. &lt;code&gt;edge_index&lt;/code&gt; is a &lt;code&gt;[2, E]&lt;/code&gt; tensor where &lt;code&gt;edge_index[0]&lt;/code&gt; contains source node indices and &lt;code&gt;edge_index[1]&lt;/code&gt; contains destination node indices.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Why k-NN and not a fixed radius?&lt;/strong&gt; k-NN always gives exactly &lt;em&gt;k&lt;/em&gt; neighbours per point regardless of local density, which keeps the graph regular. A radius-based graph (&lt;code&gt;radius_graph&lt;/code&gt;) is an alternative when you care more about physical scale than computational regularity.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;pointnet-the-architecture&#34;&gt;PointNet++: The Architecture&lt;/h2&gt;
&lt;p&gt;
 (Qi et al., NeurIPS 2017) is the foundational architecture for learning from point clouds. It works by iteratively:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;Grouping&lt;/strong&gt; — constructing a local neighbourhood graph (e.g., via k-NN)&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Aggregating&lt;/strong&gt; — running a GNN layer that collects information from each point&amp;rsquo;s neighbours&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Downsampling&lt;/strong&gt; — reducing the number of points (like pooling in a CNN) to capture larger-scale structure&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;For simplicity, we focus here on the grouping and aggregation steps, which are the core of the method.&lt;/p&gt;
&lt;h3 id=&#34;the-pointnet-layer-messagepassing-from-scratch&#34;&gt;The PointNet++ Layer: MessagePassing from Scratch&lt;/h3&gt;
&lt;p&gt;PyG provides a &lt;code&gt;MessagePassing&lt;/code&gt; base class that handles all the bookkeeping of graph message propagation — you just define &lt;em&gt;what message&lt;/em&gt; to send, and &lt;em&gt;how&lt;/em&gt; to aggregate them.&lt;/p&gt;
&lt;p&gt;The PointNet++ update rule for node &lt;em&gt;i&lt;/em&gt; at layer ℓ+1 is:&lt;/p&gt;
$$h_i^{(\ell+1)} = \max_{j \in \mathcal{N}(i)} \text{MLP}\!\left(h_j^{(\ell)},\ \mathbf{p}_j - \mathbf{p}_i\right)$$&lt;p&gt;In words: for each neighbour &lt;em&gt;j&lt;/em&gt; of point &lt;em&gt;i&lt;/em&gt;, concatenate &lt;em&gt;j&lt;/em&gt;&amp;rsquo;s features with the &lt;em&gt;relative position&lt;/em&gt; of &lt;em&gt;j&lt;/em&gt; with respect to &lt;em&gt;i&lt;/em&gt;, pass that through a small MLP, then take the maximum over all neighbours. The relative position &lt;code&gt;p_j - p_i&lt;/code&gt; is what gives the layer its geometric awareness — the network learns to respond to local shape, not just content.&lt;/p&gt;
&lt;p&gt;Here is a full implementation from scratch using &lt;code&gt;MessagePassing&lt;/code&gt;:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ReLU&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;MessagePassing&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;PointNetLayer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;MessagePassing&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;in_channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# &amp;#34;max&amp;#34; aggregation: take the maximum over all incoming messages&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;aggr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;max&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# MLP: maps (neighbour features + relative position) → message&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Input size = in_channels (neighbour features) + 3 (relative xyz)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mlp&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_channels&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;ReLU&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;            &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Trigger message passing:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# pos=(pos, pos) passes source and destination positions to message()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;propagate&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;message&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;h_j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos_j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos_i&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# h_j: features of neighbouring points (the &amp;#34;j&amp;#34; suffix = source nodes)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# pos_j - pos_i: relative position of neighbour w.r.t. current point&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# The &amp;#34;_j&amp;#34; and &amp;#34;_i&amp;#34; suffixes are a PyG convention for source/destination&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;relative_pos&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos_j&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos_i&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;msg_input&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;cat&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h_j&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;relative_pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;],&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;dim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;mlp&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;msg_input&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;A few things worth understanding here:&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;The &lt;code&gt;_j&lt;/code&gt; and &lt;code&gt;_i&lt;/code&gt; suffix convention.&lt;/strong&gt; In PyG&amp;rsquo;s &lt;code&gt;MessagePassing&lt;/code&gt;, when you name a tensor in &lt;code&gt;message()&lt;/code&gt; with a &lt;code&gt;_j&lt;/code&gt; suffix, PyG automatically selects the features of the source nodes (neighbours). The &lt;code&gt;_i&lt;/code&gt; suffix selects features of the destination node (the central point). This is just syntactic sugar — PyG does the indexing for you.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Why relative position, not absolute?&lt;/strong&gt; Using &lt;code&gt;pos_j - pos_i&lt;/code&gt; instead of absolute coordinates makes the layer &lt;em&gt;translation-invariant&lt;/em&gt; — it responds to the shape of the local neighbourhood, not where it is in global space. This is important for generalisation.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Why max aggregation?&lt;/strong&gt; Max pooling is permutation-invariant (order of neighbours doesn&amp;rsquo;t matter) and selects the most &amp;ldquo;active&amp;rdquo; incoming signal. It empirically works better than mean for point clouds, because the most distinctive geometric feature often comes from a single key neighbour, not the average.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;building-the-full-network&#34;&gt;Building the Full Network&lt;/h2&gt;
&lt;p&gt;Now we assemble the &lt;code&gt;PointNetLayer&lt;/code&gt; into a full classifier. PyG handles mini-batching by concatenating all points from all examples in the batch into a single tensor, using a &lt;code&gt;batch&lt;/code&gt; vector to track which points belong to which example.&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;F&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;global_max_pool&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;class&lt;/span&gt; &lt;span class=&#34;nc&#34;&gt;PointNet&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Module&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;nb&#34;&gt;super&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;fm&#34;&gt;__init__&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Two PointNet++ layers: 3D coords → 32 features → 64 features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;PointNetLayer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_channels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;PointNetLayer&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;in_channels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;32&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;out_channels&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Final classifier&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;classifier&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;forward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Tensor&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Initial node features = raw 3D coordinates&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Layer 1: each point aggregates from its k-NN neighbours&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;F&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Layer 2: aggregate again with updated features&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;conv2&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;h&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;F&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;relu&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Global readout: reduce all N points → one vector per graph&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# global_max_pool uses the `batch` vector to know which points belong together&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;graph_embedding&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;global_max_pool&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;h&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;  &lt;span class=&#34;c1&#34;&gt;# shape: [B, 64]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;c1&#34;&gt;# Classify&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;bp&#34;&gt;self&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;classifier&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;graph_embedding&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;h3 id=&#34;understanding-global_max_pool-and-the-batch-vector&#34;&gt;Understanding &lt;code&gt;global_max_pool&lt;/code&gt; and the &lt;code&gt;batch&lt;/code&gt; vector&lt;/h3&gt;
&lt;p&gt;This is one of PyG&amp;rsquo;s most important design choices, and it&amp;rsquo;s worth understanding clearly.&lt;/p&gt;
&lt;p&gt;When you load a mini-batch of 10 point clouds, each with 256 points, PyG doesn&amp;rsquo;t create a list of 10 separate tensors. Instead, it concatenates everything into a single tensor of shape &lt;code&gt;[2560, 3]&lt;/code&gt; — all 2560 points together — and creates a &lt;code&gt;batch&lt;/code&gt; vector:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-fallback&#34; data-lang=&#34;fallback&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;batch = [0, 0, ..., 0, 1, 1, ..., 1, ..., 9, 9, ..., 9]
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;          ↑ 256 zeros   ↑ 256 ones         ↑ 256 nines
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;This &lt;code&gt;batch&lt;/code&gt; vector tells &lt;code&gt;global_max_pool&lt;/code&gt; which points belong to which example, so it can reduce each group of 256 points into a single 64-dimensional embedding. The output shape is &lt;code&gt;[10, 64]&lt;/code&gt; — one vector per shape in the batch.&lt;/p&gt;
&lt;p&gt;This approach is computationally efficient because all operations (message passing, MLP forward passes) run on contiguous tensors, not Python lists of variable-size arrays.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;using-the-built-in-pointnetconv&#34;&gt;Using the Built-in &lt;code&gt;PointNetConv&lt;/code&gt;&lt;/h2&gt;
&lt;p&gt;If you don&amp;rsquo;t want to implement the layer from scratch, PyG ships a ready-to-use version:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;PointNetConv&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ReLU&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;BatchNorm1d&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;local_nn&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;Sequential&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;3&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;    &lt;span class=&#34;c1&#34;&gt;# 3 (neighbour features) + 3 (relative pos)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;ReLU&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;Linear&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;BatchNorm1d&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;64&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;conv&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;PointNetConv&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;local_nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;local_nn&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;add_self_loops&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;&lt;code&gt;PointNetConv&lt;/code&gt; implements the same operation but with additional options like a &lt;code&gt;global_nn&lt;/code&gt; (applied after aggregation) and support for self-loops.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;putting-it-all-together-dataset-loader-and-training&#34;&gt;Putting It All Together: Dataset, Loader, and Training&lt;/h2&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.transforms&lt;/span&gt; &lt;span class=&#34;k&#34;&gt;as&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;T&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.datasets&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GeometricShapes&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.loader&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;DataLoader&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_geometric.nn&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;global_max_pool&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# --- Data ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Compose&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;([&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;SamplePoints&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;256&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;T&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;KNNGraph&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;k&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;6&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)])&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;train_dataset&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GeometricShapes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;root&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;data/GeometricShapes&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;train_dataset&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;test_dataset&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;GeometricShapes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;root&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;data/GeometricShapes&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;False&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;test_dataset&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;transform&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;train_loader&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;DataLoader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;shuffle&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;kc&#34;&gt;True&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;test_loader&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;DataLoader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;batch_size&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;10&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# --- Model, optimizer, loss ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;PointNet&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_classes&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;optim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;Adam&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;parameters&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(),&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;lr&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.01&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;criterion&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;torch&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;nn&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# --- Training loop ---&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;():&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;total_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;train_loader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;zero_grad&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;criterion&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;backward&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;optimizer&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;step&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;total_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;float&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;*&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;num_graphs&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;total_loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_loader&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;nd&#34;&gt;@torch.no_grad&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;def&lt;/span&gt; &lt;span class=&#34;nf&#34;&gt;test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;eval&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;total_correct&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;0&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;loader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;:&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;model&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;edge_index&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;batch&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;logits&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;argmax&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dim&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=-&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;        &lt;span class=&#34;n&#34;&gt;total_correct&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;+=&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;int&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;((&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pred&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;==&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;y&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sum&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;())&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;k&#34;&gt;return&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;total_correct&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;/&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;len&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loader&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;dataset&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;k&#34;&gt;for&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt; &lt;span class=&#34;ow&#34;&gt;in&lt;/span&gt; &lt;span class=&#34;nb&#34;&gt;range&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;mi&#34;&gt;1&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;mi&#34;&gt;51&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;):&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;train&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;()&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;train_acc&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_loader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;n&#34;&gt;test_acc&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;test&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_loader&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;    &lt;span class=&#34;nb&#34;&gt;print&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;sa&#34;&gt;f&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;Epoch &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;epoch&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;02d&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;}&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt; | Loss: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;loss&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;.4f&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;}&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt; | Train: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;train_acc&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;.2%&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;}&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt; | Test: &lt;/span&gt;&lt;span class=&#34;si&#34;&gt;{&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;test_acc&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;:&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;.2%&lt;/span&gt;&lt;span class=&#34;si&#34;&gt;}&lt;/span&gt;&lt;span class=&#34;s1&#34;&gt;&amp;#39;&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;The training loop is standard PyTorch — the only PyG-specific things are the DataLoader (which handles batching of graphs) and the &lt;code&gt;data.pos&lt;/code&gt;, &lt;code&gt;data.edge_index&lt;/code&gt;, &lt;code&gt;data.batch&lt;/code&gt; attributes you pass to the model.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;going-further-farthest-point-sampling&#34;&gt;Going Further: Farthest Point Sampling&lt;/h2&gt;
&lt;p&gt;The tutorial also introduces &lt;strong&gt;Farthest Point Sampling (FPS)&lt;/strong&gt; — a downsampling strategy used in PointNet++ to progressively reduce the number of points across layers, similar to how a CNN reduces spatial resolution via pooling.&lt;/p&gt;
&lt;p&gt;FPS iteratively selects the point furthest from all already-selected points, ensuring good coverage of the shape&amp;rsquo;s surface:&lt;/p&gt;
&lt;div class=&#34;highlight&#34;&gt;&lt;pre tabindex=&#34;0&#34; class=&#34;chroma&#34;&gt;&lt;code class=&#34;language-python&#34; data-lang=&#34;python&#34;&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;kn&#34;&gt;from&lt;/span&gt; &lt;span class=&#34;nn&#34;&gt;torch_cluster&lt;/span&gt; &lt;span class=&#34;kn&#34;&gt;import&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fps&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;c1&#34;&gt;# Sample 25% of the points, keeping those most spread out&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;sampled_indices&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;fps&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;(&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;,&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;ratio&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;=&lt;/span&gt;&lt;span class=&#34;mf&#34;&gt;0.25&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;)&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;span class=&#34;line&#34;&gt;&lt;span class=&#34;cl&#34;&gt;&lt;span class=&#34;n&#34;&gt;downsampled_pos&lt;/span&gt; &lt;span class=&#34;o&#34;&gt;=&lt;/span&gt; &lt;span class=&#34;n&#34;&gt;data&lt;/span&gt;&lt;span class=&#34;o&#34;&gt;.&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;pos&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;[&lt;/span&gt;&lt;span class=&#34;n&#34;&gt;sampled_indices&lt;/span&gt;&lt;span class=&#34;p&#34;&gt;]&lt;/span&gt;
&lt;/span&gt;&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;p&gt;After downsampling, you&amp;rsquo;d rebuild a new k-NN graph on the reduced point set and run another PointNet++ layer. This hierarchical approach — aggregate, downsample, repeat — is what gives PointNet++ its ability to capture multi-scale geometric structure.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;key-concepts-to-remember&#34;&gt;Key Concepts to Remember&lt;/h2&gt;
&lt;p&gt;&lt;strong&gt;From mesh to point cloud:&lt;/strong&gt; &lt;code&gt;SamplePoints&lt;/code&gt; does this in one line. The stochastic sampling acts as free data augmentation.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;From point cloud to graph:&lt;/strong&gt; &lt;code&gt;KNNGraph(k=6)&lt;/code&gt; connects each point to its 6 nearest neighbours. The graph is re-constructed dynamically every time you access the data.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Message passing in PyG:&lt;/strong&gt; Subclass &lt;code&gt;MessagePassing&lt;/code&gt;, define &lt;code&gt;message()&lt;/code&gt;, and use &lt;code&gt;_j&lt;/code&gt;/&lt;code&gt;_i&lt;/code&gt; suffixes to access neighbour/central node attributes. PyG handles all the scatter-gather indexing under the hood.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Relative position is the key:&lt;/strong&gt; The term &lt;code&gt;pos_j - pos_i&lt;/code&gt; in the message function is what gives the network its geometric awareness. Without it, the network would be blind to the shape of the neighbourhood.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Batching via the &lt;code&gt;batch&lt;/code&gt; vector:&lt;/strong&gt; PyG concatenates all graphs in a batch into one large disconnected graph. &lt;code&gt;global_max_pool(h, batch)&lt;/code&gt; then recovers one embedding per graph using the batch assignment vector.&lt;/p&gt;
&lt;hr&gt;
&lt;h2 id=&#34;comparison-with-dgl&#34;&gt;Comparison with DGL&lt;/h2&gt;
&lt;p&gt;If you read the 
, you&amp;rsquo;ll notice both libraries share the message-passing paradigm. The key practical differences for point cloud work:&lt;/p&gt;
&lt;table&gt;
  &lt;thead&gt;
      &lt;tr&gt;
          &lt;th&gt;&lt;/th&gt;
          &lt;th&gt;PyG&lt;/th&gt;
          &lt;th&gt;DGL&lt;/th&gt;
      &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
      &lt;tr&gt;
          &lt;td&gt;Graph construction&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;KNNGraph&lt;/code&gt;, &lt;code&gt;radius_graph&lt;/code&gt; transforms&lt;/td&gt;
          &lt;td&gt;Manual or external&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;Point cloud support&lt;/td&gt;
          &lt;td&gt;First-class (&lt;code&gt;SamplePoints&lt;/code&gt;, &lt;code&gt;PointNetConv&lt;/code&gt;)&lt;/td&gt;
          &lt;td&gt;General-purpose&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;Message passing&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;MessagePassing&lt;/code&gt; base class, &lt;code&gt;_j/_i&lt;/code&gt; suffix convention&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;update_all&lt;/code&gt; + &lt;code&gt;fn.*&lt;/code&gt; built-ins&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;Batching&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;batch&lt;/code&gt; vector, &lt;code&gt;global_*_pool&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;dgl.batch&lt;/code&gt;, &lt;code&gt;dgl.mean_nodes&lt;/code&gt;&lt;/td&gt;
      &lt;/tr&gt;
      &lt;tr&gt;
          &lt;td&gt;Built-in PointNet&lt;/td&gt;
          &lt;td&gt;&lt;code&gt;PointNetConv&lt;/code&gt;&lt;/td&gt;
          &lt;td&gt;Not built-in&lt;/td&gt;
      &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;For point cloud tasks specifically, PyG&amp;rsquo;s transform pipeline and built-in PointNet support give it an edge. For heterogeneous molecular graphs, DGL&amp;rsquo;s heterograph API is more expressive.&lt;/p&gt;
&lt;hr&gt;
&lt;p&gt;&lt;em&gt;Next post: building atom-level node features from protein–ligand complexes with RDKit and MDAnalysis, and connecting them to a PyG or DGL training pipeline.&lt;/em&gt;&lt;/p&gt;
</description>
    </item>
    
  </channel>
</rss>
