Reconstructing Shader Inputs with SlangPY and PyTorch

Shaders define visual effects by transforming inputs into rendered outputs. But what if we could reverse that process and recover shader inputs directly from the final image? Or adjust parameters so that the output matches a given reference? With SlangPY and PyTorch, we can.

Once a shader is exposed to PyTorch through SlangPY, its inputs become trainable parameters. Gradient descent can then optimize those inputs until the output closely matches a target image, much like training a neural network.

This means you can, for example:

  • Reconstruct shader parameters to match a photo, concept art, or a manually tuned reference
  • Tune visual effects automatically against a chosen objective
  • Reverse-engineer material parameters from screenshots or renders
  • Automatically optimize shader parameters for the best balance between performance and quality

What is Gradient Descent?

Gradient descent is an iterative method for finding the values of variables that minimize an error. At each step, it measures how a small change in each variable affects the result, then adjusts the variables in the direction that reduces the error.

[Figure: Gradient descent in 2D. Source: https://en.wikipedia.org/wiki/Gradient_descent]
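
To make this concrete, here is a minimal sketch of gradient descent on a one-variable toy problem in PyTorch (no shaders involved yet):

import torch

# Toy example: find the x that minimizes f(x) = (x - 3)^2.
x = torch.tensor(0.0, requires_grad=True)
lr = 0.1  # learning rate: the step size

for _ in range(100):
    loss = (x - 3.0) ** 2     # how wrong is the current x?
    loss.backward()           # compute d(loss)/dx
    with torch.no_grad():
        x -= lr * x.grad      # step in the direction that reduces the loss
        x.grad.zero_()        # clear the gradient for the next step

print(x.item())  # ~3.0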

In our case, we use it to tweak inputs to a shader—like brightness or textures—so that the final image looks as close as possible to a given reference image.

SlangPY bridges Slang shaders and PyTorch, enabling automatic differentiation. With this setup, shader inputs become learnable variables, making it possible to reconstruct or fine-tune them using gradient-based optimization.
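
For reference, loading a Slang module from Python looks roughly like this. The exact device-creation options and PyTorch-interop details depend on your SlangPy version, so treat this as an assumed sketch rather than copy-paste setup (the file name brightness.slang is also just an assumption):

import slangpy as spy

# Create a GPU device and load the Slang module containing our shaders.
device = spy.create_device()
module = spy.Module.load_from_file(device, "brightness.slang")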

This article covers two examples: a basic brightness adjustment and the reconstruction of a full 512×512 grayscale mask.


1. Recovering a Brightness Offset

We begin by solving for a 3-component brightness vector so that the shader’s output matches a reference image.

Shader (Slang):

// [Differentiable] tells Slang to generate the backward (derivative) code for this function.
[Differentiable]
float3 brightness(float3 amount, float3 pixel)
{
    pixel += amount;  // add a per-channel offset
    return pixel;
}

Python Training Loop:

import torch
import torch.nn.functional as F

# `module`, `source_texture`, and `target` come from the SlangPy setup above.
amount = torch.tensor([-0.1, 0.1, 0.0], device='cuda', requires_grad=True)  # Learnable brightness offset on GPU
optimizer = torch.optim.Adam([amount], lr=1e-3)  # Adam optimizer with standard learning rate

for _ in range(1000):
    optimizer.zero_grad()  # Clear previous gradients
    output = module.brightness(amount, source_texture)  # Apply shader to input texture
    loss = F.l1_loss(output, target)  # Measure difference from reference image
    loss.backward()  # Compute gradients w.r.t. brightness
    optimizer.step()  # Update brightness to reduce loss

Result:

Final amount: [-0.33533716 -0.3212491  -0.3279705 ]

After a few hundred iterations, the vector converges and the shader output aligns with the reference image. This demonstrates end-to-end gradient flow through a Slang shader.
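
If you want to sanity-check the optimization loop without a GPU or the SlangPy setup, a pure-PyTorch stand-in for the brightness function exercises the same gradient flow. This is an illustration with made-up data, not the SlangPy path:

import torch
import torch.nn.functional as F

def brightness(amount, pixel):
    # Same math as the Slang shader: add a per-channel offset.
    return pixel + amount

source = torch.rand(64, 64, 3)                # made-up source image
target = source - 0.33                        # reference built with a known offset
amount = torch.zeros(3, requires_grad=True)   # learnable brightness offset

optimizer = torch.optim.Adam([amount], lr=1e-2)
for _ in range(500):
    optimizer.zero_grad()
    loss = F.l1_loss(brightness(amount, source), target)
    loss.backward()
    optimizer.step()

print(amount)  # converges toward [-0.33, -0.33, -0.33]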


2. Reconstructing a Grayscale Mask

In the second example, we reconstruct an entire 512×512 grayscale mask. The mask is used as input to a shader that produces a colorful RGB output with distortions, making the task slightly more challenging and realistic.

Visual Example:

[Figure: the original grayscale mask, created in Photoshop; this is what we will try to recreate]
[Figure: the RGB output the shader generates from this mask; we will use it as the reference]

Objective: based on the output above, find the grayscale mask that, when passed into the shader, produces an output matching the reference image.
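
One more piece of setup: the training loop in the next section passes a uv_grid of per-pixel coordinates to the shader. The original setup code isn’t shown, but such a grid might be built like this (assuming the shader expects normalized [0, 1] coordinates):

import torch

H, W = 512, 512
v, u = torch.meshgrid(
    torch.linspace(0, 1, H, device='cuda'),
    torch.linspace(0, 1, W, device='cuda'),
    indexing='ij',
)
uv_grid = torch.stack((u, v), dim=-1)  # shape (H, W, 2), one UV pair per pixel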


3. Training Loop to Recover the Mask

Below is a simplified training loop to reconstruct the grayscale input mask:

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR

# `module` and `reference` come from the SlangPy setup; `uv_grid` is built above.
H, W, STEPS = 512, 512, 3000
mask = torch.ones((1, 1, H, W), device='cuda') * 0.3  # Start with a constant grayscale mask
mask -= torch.randn_like(mask) * 0.05                 # Perturb the initialization with small noise
mask.requires_grad_(True)                             # Make the mask learnable

opt = torch.optim.AdamW([mask], lr=2e-3, betas=(0.9, 0.999), weight_decay=1e-7, amsgrad=True)  # AdamW optimizer
sched = OneCycleLR(opt, max_lr=1e-3, total_steps=STEPS, pct_start=0.1, div_factor=10, final_div_factor=1e4)  # Learning rate scheduler

for step in range(STEPS):
    opt.zero_grad()  # Reset gradients
    output = module.rainbow(uv_grid, mask.squeeze())  # Shader output using current mask
    loss = F.l1_loss(output, reference)  # Difference from reference image
    loss.backward()  # Compute gradients
    opt.step()       # Update the mask
    sched.step()     # Advance the learning-rate schedule

    with torch.no_grad():
        mask.clamp_(0, 1)  # Keep values in the valid [0, 1] range

Each iteration includes:

  1. Computing shader output.
  2. Measuring loss against the reference.
  3. Backpropagating to compute gradients.
  4. Updating the mask to minimize loss.

Using adaptive optimizers like AdamW with schedulers such as OneCycleLR helps improve convergence.
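
To inspect progress, you can periodically dump the current mask estimate to disk. One hypothetical way, assuming torchvision is available:

from torchvision.utils import save_image

# Values are already clamped to [0, 1], so no rescaling is needed.
save_image(mask.detach().cpu(), 'recovered_mask.png')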

And here’s a visualization of the training process:

[Animation: the mask estimate converging toward the original over the 3000 optimization steps]

4. Summary

This approach opens doors to inverse rendering, procedural texture fitting, and new ML-powered artistic workflows.

Try it yourself with the full example notebook: View full code and demo
