JAX-Driven Heat Diffusion: Sparse and Vectorized PDE Simulation

JAX
Scipy
NumPy
Author

Emil Blaignan

Published

February 24, 2025

Introduction

Heat diffusion is a fundamental process in physics and engineering, described by the heat equation, a partial differential equation (PDE) modeling thermal energy transfer. In this project, we simulate 2D heat diffusion using multiple computational approaches:

  • Matrix-vector multiplication (explicit finite difference method).
  • Sparse matrix representation in JAX (efficient storage and computation).
  • NumPy vectorized operations (direct array manipulation with np.roll).
  • JAX-optimized just-in-time (JIT) compilation (high-speed execution).

The objective is to compare performance and efficiency across these methods while maintaining correct boundary conditions (to simulate heat diffusion properly) and stable numerical updates. Sparse representations and JIT-optimized computations significantly reduce runtime complexity compared to basic implementations. The results will be visualized using heatmaps at regular intervals to analyze diffusion patterns. Additionally, results will be accompanied by runtimes to compare code efficiency.

The Math of Heat Diffusion Simulations

Before we dive into the code that simulates heat diffusion, we must understand the math behind it.

The Continuous 2D Heat Equation

The heat diffusion equation describes how heat spreads over time in a two-dimensional space. The general form is:

\[\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}\]

where:

  • \(u(x, y, t)\) represents the temperature at position \((x, y)\) at time \(t\).
  • The right-hand side represents the Laplacian operator, which models how heat spreads from a point to its neighbors.

Discretization Using Finite Differences

We approximate the second derivatives using the finite difference method. Using a uniform grid with spacing \(\Delta x = \Delta y\) we approximate:

\[\frac{\partial^2 u}{\partial x^2} \approx \frac{u_{i+1,j} - 2u_{i,j} + u_{i-1,j}}{\Delta x^2}\]

\[\frac{\partial^2 u}{\partial y^2} \approx \frac{u_{i,j+1} - 2u_{i,j} + u_{i,j-1}}{\Delta y^2}\]

Substituting these into the heat equation:

\[\frac{u_{i,j}^{k+1} - u_{i,j}^{k}}{\Delta t} = \frac{u_{i+1,j}^{k} - 2u_{i,j}^{k} + u_{i-1,j}^{k}}{\Delta x^2} + \frac{u_{i,j+1}^{k} - 2u_{i,j}^{k} + u_{i,j-1}^{k}}{\Delta y^2} \]

Rearranging for \(u_{i,j}^{k+1}\):

\[u_{i,j}^{k+1} = u_{i,j}^{k} + \epsilon \left( u_{i+1,j}^{k} + u_{i-1,j}^{k} + u_{i,j+1}^{k} + u_{i,j-1}^{k} - 4u_{i,j}^{k} \right)\]

where \(\epsilon = \frac{ \Delta t}{\Delta x^2}\) is the stability parameter.

Matrix Representation of the Finite Difference Method

We represent the system in matrix form:

\[ u_{k+1} = u_k + \epsilon A u_k \]

where:

  • \(u_k\) is the flattened temperature grid at time \(k\).
  • \(A\) is the finite difference matrix that applies the heat diffusion stencil.

The finite difference matrix \(A\) for a grid of size \(N \times N\) is:

\[ A = \begin{bmatrix} T & I & 0 & 0 & \dots & 0 \\ I & T & I & 0 & \dots & 0 \\ 0 & I & T & I & \dots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 & 0 & \dots & T & I \\ 0 & 0 & 0 & \dots & I & T \end{bmatrix} \]

where: - \(T\) is an \(N \times N\) tridiagonal matrix representing left-right interactions:

\[ T = \begin{bmatrix} -4 & 1 & 0 & \dots & 0 \\ 1 & -4 & 1 & \dots & 0 \\ 0 & 1 & -4 & \dots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \dots & -4 \end{bmatrix} \]

  • \(I\) is an \(N \times N\) identity matrix that accounts for top/bottom interactions.

This matrix enforces:

  • Self-weight (-4) at each grid point.
  • Neighboring weights (+1) for adjacent grid points.

Initializing Our Heat Diffusion Simulation

For our simulations we will use a grid size N = 101 and stability parameter \(\epsilon = 0.2\).

N = 101
epsilon = 0.2

We initialize the simulation with a single heat source at the center.

We’ll use the following initial condition:

import numpy as np
from matplotlib import pyplot as plt
# construct initial condition: 1 unit of heat at midpoint. 
u0 = np.zeros((N, N))
u0[int(N/2), int(N/2)] = 1.0
plt.imshow(u0)

Heat Diffusion With Matrix-Vector Multiplication

In order to begin our 2D heat simulations we’ll create a file heat_equation.py with the following packages:

import numpy as np
import jax.numpy as jnp
import jax
from scipy.sparse import diags
from jax.experimental.sparse import BCOO

Using the math above we use matrix-vector multiplication to simulate heat diffusion. The grid is flattened into a 1D vector to apply the finite difference matrix efficiently:

\[ u_{k+1} = u_k + \epsilon A u_k \]

where:

  • \(A\) is the finite difference matrix of size \(N^2 \times N^2\).
  • \(U^k\) is the flattened temperature grid at time \(k\).
  • \(\epsilon\) is the stability constant.

This formula is represented by the following function:

from heat_equation import advance_time_matvecmul
import inspect
print(inspect.getsource(advance_time_matvecmul))
def advance_time_matvecmul(A, u, epsilon):
    """Advances the simulation by one timestep, via matrix-vector multiplication
    Args:
        A: The 2d finite difference matrix, N^2 x N^2. 
        u: N x N grid state at timestep k.
        epsilon: stability constant.

    Returns:
        N x N Grid state at timestep k+1.
    """
    N = u.shape[0]
    u = u + epsilon * (A @ u.flatten()).reshape((N, N))
    return u

In order for the function above to run we also need to write a function get_A() that constructs the matrix A:

from heat_equation import get_A
print(inspect.getsource(get_A))
def get_A(N):
    """Constructs the finite difference matrix A for the heat equation.
    Args:
        N: Grid size

    Returns:
        A: Finite difference matrix (N^2 x N^2).
    """
    n = N * N
    diagonals = [
        -4 * np.ones(n),  # Main diagonal
        np.ones(n - 1),    # Right neighbor
        np.ones(n - 1),    # Left neighbor
        np.ones(n - N),    # Upper neighbor
        np.ones(n - N)     # Lower neighbor
    ]
    # Apply boundary conditions (preventing wrap-around)
    diagonals[1][(N-1)::N] = 0
    diagonals[2][(N-1)::N] = 0  

    # Construct the finite difference matrix
    A = (
        np.diag(diagonals[0]) +
        np.diag(diagonals[1], 1) +
        np.diag(diagonals[2], -1) +
        np.diag(diagonals[3], N) +
        np.diag(diagonals[4], -N)
    )
    return A

To test our functions and create a heat simulation, we’ll write a function that allows us to specify our type of function that updates the heat distribution from heat_equations.py and the version of matrix A we’ll use in each test.

The following function wraps our testing and visualization for all test cases:

import matplotlib.pyplot as plt
import time

def heat_diffusion_test(func, 
                        u_init, 
                        epsilon, 
                        iterations=2700, 
                        snapshot_interval=300, 
                        matrix_A=None):
    """
    Runs and visualizes the heat diffusion simulation.

    Args:
        - func (function): The function that updates the heat distribution.
        - u_init (np.ndarray or jnp.ndarray): Initial heat distribution (N x N grid).
        - epsilon: Stability constant.
        - iterations: Total number of iterations to run.
        - snapshot_interval: Interval at which to capture snapshots.
        - matrix_A: matrix A.
    
    Returns:
        - Heat simulation visualizations
        - Total execution time of the simulation.
    """
    # Copy the initial condition
    u = np.copy(u_init) if isinstance(u_init, np.ndarray) else jnp.array(u_init)

    # Store snapshots for visualization
    snapshots = []

    # Precompile JAX function if applicable
    if "jax" in str(type(func)):
        u = func(u, epsilon)  # JAX functions require precompilation
        u.block_until_ready()

    # Run the simulation and time it
    start_time = time.time()
    for i in range(1, iterations + 1):
        if matrix_A is not None:  # For matrix-vector multiplication
            u = func(A, u, epsilon)
        else:
            u = func(u, epsilon)

        if i % snapshot_interval == 0:
            snapshots.append(np.array(u))  # Convert JAX arrays to NumPy if needed

    # Ensure JAX computations complete
    if isinstance(u, jnp.ndarray):
        u.block_until_ready()
    
    end_time = time.time()
    elapsed_time = end_time - start_time

    # Plot results in a 3x3 grid
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    for idx, snapshot in enumerate(snapshots):
        ax = axes[idx // 3, idx % 3]
        ax.imshow(snapshot, cmap="viridis", interpolation="nearest")
        ax.set_title(f"Iteration {(idx + 1) * snapshot_interval}")

    plt.tight_layout()
    plt.show()

    return elapsed_time

We can now test our first heat simulation using get_A() and heat_diffusion_test() with advance_time_matvecmul() passed as the function in the test arguments:

import heat_equation

A = get_A(N)
time = heat_diffusion_test(advance_time_matvecmul, u0, epsilon, matrix_A=A)
print(f"Total execution time: {time:.2f} seconds")

Total execution time: 35.84 seconds

This method, while mathematically correct, is computationally expensive (as seen by the runtime of approximately 36 seconds). In the next sections, we optimize it using sparse matrices and vectorized operations.

Optimizing Computation With Sparse Matrices

The finite difference matrix \(A\) is mostly zeros, leading to wasted computations. Using JAX sparse matrices, we store only nonzero values, reducing memory usage and improving efficiency.

We’ll write a function get_sparse_A() that computes the sparse matrix using the batched coordinate BCOO format to only use \(O(N^2)\) space for the matrix, and only take \(O(N^2)\) time for each update:

from heat_equation import get_sparse_A
print(inspect.getsource(get_sparse_A))
def get_sparse_A(N):
    """Constructs the sparse matrix A using JAX sparse format."""
    A = get_A(N) # Get the dense matrix A
    A_sp_matrix = BCOO.fromdense(jnp.array(A)) # Convert dense matrix to JAX sparse format (BCOO)
    return A_sp_matrix

Here, A_sp_matrix is the sparse matrix of A which should make the compute more efficient. We can test this by rerunning the heat simulation using the get_sparse_A() matrix and the same heat_diffusion_test() function, comparing the runtimes:

import heat_equation

A = get_sparse_A(N)
time = heat_diffusion_test(advance_time_matvecmul, u0, epsilon, matrix_A=A)
print(f"Total execution time: {time:.2f} seconds")

Total execution time: 1.86 seconds

Runtime for the heat simulation using the sparse matrix of A is more than 10x faster (19.27x) at approximately 2 seconds. Still, we can make the heat simulation more efficient by replacing advance_time_matvecmul(), which utilizes matrix-vector multiplications with vectorized array operations like np.roll().

Direct NumPy Vectorization

The matrix-vector multiplication approach is useful, particularly in other PDE problems, such as Poisson equations, where the matrix equation must be solved. However, for the heat equation, it is not necessary in terms of computation. Instead, we can use NumPy’s np.roll() to compute updates directly.

We’ll do this by writing the function advance_time_numpy() in heat_equation.py:

from heat_equation import advance_time_numpy
print(inspect.getsource(advance_time_numpy))
def advance_time_numpy(u, epsilon):
    """Advances the simulation using NumPy's np.roll for boundary handling."""
    u_padded = np.pad(u, pad_width=1, mode='constant', constant_values=0)  # Pad u with zeros
    u_next = u + epsilon * (  # Update u using finite differences
        np.roll(u_padded, 1, axis=0)[1:-1, 1:-1] +  # Shift up
        np.roll(u_padded, -1, axis=0)[1:-1, 1:-1] +  # Shift down
        np.roll(u_padded, 1, axis=1)[1:-1, 1:-1] +  # Shift left
        np.roll(u_padded, -1, axis=1)[1:-1, 1:-1] -  # Shift right
        4 * u  # Subtract central value
    )
    return u_next  # Return updated interior values

The function above works by padding the grid with zeros and computes the transformation using np.roll(). We can understand the function with the visual below:

Next, we can rerun the simulation using the advance_time_numpy() function as an argument in the visualize_heat_diffusion() function and compare runtimes:

import heat_equation

time = heat_diffusion_test(advance_time_numpy, u0, epsilon)
print(f"Total execution time: {time:.2f} seconds")

Total execution time: 0.21 seconds

Runtime for the heat simulation using the direct NumPy vectorization is more than 100x faster (170.67x) than the base matrix-vector simulation we did at 0.21 seconds. While already fast, we can achieve a quicker runtime by applying the jit function to essentially the same function advance_time_numpy(), using the JAX package.

JAX JIT Compilation for Maximum Speed

Using JAX’s @jax.jit, we can compile the function into optimized machine code, further reducing execution time.

We’ll do this by writing the function advance_time_jax() in heat_equation.py with @jax.jit above the function:

from heat_equation import advance_time_jax
print(inspect.getsource(advance_time_jax))
def advance_time_jax(u, epsilon):
    """Advances the simulation using JAX and just-in-time compilation."""
    u_padded = np.pad(u, pad_width=1, mode='constant', constant_values=0)  # Pad u with zeros
    u_next = u + epsilon * (  # Update u using finite differences
        jnp.roll(u_padded, 1, axis=0)[1:-1, 1:-1] +  # Shift up
        jnp.roll(u_padded, -1, axis=0)[1:-1, 1:-1] +  # Shift down
        jnp.roll(u_padded, 1, axis=1)[1:-1, 1:-1] +  # Shift left
        jnp.roll(u_padded, -1, axis=1)[1:-1, 1:-1] -  # Shift right
        4 * u  # Subtract central value
    )
    return u_next  # Return updated interior values

The only difference between the direct NumPy vectorization function is that we use @jax.jit above our function and replace the NumPy calls np with the JAX NumPy calls jnp.

We’ll run the heat diffusion test one last time using the JAX optimized function and compare execution speed:

import heat_equation

time = heat_diffusion_test(advance_time_jax, u0, epsilon)
print(f"Total execution time: {time:.2f} seconds")

Total execution time: 0.06 seconds

Runtime for the heat simulation using the JAX jit operation is more than 2x faster (3.5x) than the direct NumPy vectorization simulation we did at 0.06 seconds. This is nearly 600x faster than the base matrix-vector simulation.

Conclusion: Applications to Deep Learning

Although this blog focuses on simulating heat diffusion, the computational techniques used are directly applicable to deep learning and neural networks. The methods explored, matrix operations, sparse representations, vectorized computations, and JIT compilation, are fundamental to training and optimizing deep learning models.