Shaoyu Yang
杨少宇
🛠️AI Infra Testing①: PyTorch model introduction, debugging, and compiler correctness

S1: Introduction to the components of Deep Learning models
It is widely accepted that a Deep Learning (DL) model is essentially a piece of code (in PyTorch, a Python class that inherits from torch.nn.Module). A complete DL model test case that can be executed by a DL framework (e.g., PyTorch) consists of two key components (here we only discuss the simplest case, because real-world models, especially Large Models, are much larger):
- A Python class inherited from torch.nn.Module
- Tensor inputs
Here is a specific model corresponding to these two components.
import torch

# A Python class inherited from `torch.nn.Module`
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_channels=1, out_channels=3, kernel_size=1)
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x):
        x = x.unsqueeze(1)      # tensor shape: (1, 3) -> (1, 1, 3)
        x = self.conv(x)        # tensor shape: (1, 1, 3) -> (1, 3, 3)
        x = x.mean(dim=-1)      # tensor shape: (1, 3, 3) -> (1, 3)
        return self.linear(x)   # tensor shape: (1, 3) -> (1, 1)

x = torch.randn(1, 3)  # Tensor inputs
m = Model()            # Model initialization
output = m(x)
print(output)
"""
tensor([[-0.0931]])
"""
OK, now let me explain the code in more detail.
- The Model class is a Python class inherited from torch.nn.Module, which can be taken as a so-called DL model (program). In the __init__ method, we need to instantiate some PyTorch APIs (also called operators) with their attributes, as they are essentially Python classes. For example, torch.nn.Conv1d in the code above has several attributes (e.g., in_channels, out_channels, and kernel_size) that must be explicitly assigned values. The forward method implements the core logic of the forward computation, which can be abstracted into a computational graph. A DL computational graph is a directed acyclic graph (DAG), and DL model inference is the process in which the tensor inputs flow from the entry point to the exit point of this DAG.
- Tensor input(s) are of torch.Tensor type. This is very easy to understand, because what a DL model does is compute over tensors. PyTorch provides several APIs to create tensors (in this case, we use torch.randn()); a few more examples follow this list.
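Besides torch.randn(), here are a few other commonly used tensor-creation APIs; this is only an illustration, and the shapes and dtypes below are arbitrary examples rather than anything required by the Model above.
import torch

x1 = torch.randn(1, 3)                     # samples from a standard normal distribution
x2 = torch.rand(2, 2)                      # samples from a uniform distribution on [0, 1)
x3 = torch.zeros(4, dtype=torch.float32)   # all zeros
x4 = torch.ones(2, 3, dtype=torch.int64)   # all ones with an integer dtype
x5 = torch.tensor([[1.0, 2.0, 3.0]])       # builds a tensor from a Python list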
After we define these two core components, what remains is very easy: first, instantiate the model class (m = Model()); then, pass the tensor input(s) into the instantiated model m, which calls the forward method to complete the forward computation. The sketch below makes the underlying computational graph explicit.
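To make the computational-graph (DAG) view of forward concrete, here is a minimal sketch that traces the Model class defined above with torch.fx.symbolic_trace and prints the resulting graph; the exact node names in the printout may differ across PyTorch versions.
import torch
import torch.fx

# Trace the forward method into an FX graph: a DAG whose nodes are the operators
# (unsqueeze -> conv -> mean -> linear) and whose edges are the tensors flowing between them.
traced = torch.fx.symbolic_trace(Model())
print(traced.graph)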
S2: Debugging PyTorch models
Sometimes we encounter errors when running such code. It may be a potential bug in PyTorch, but most of the time the code itself is buggy (often because it was generated by a DL fuzzer or contains a typo). For example, look at the code below:
class MatrixModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mm_layer = torch.nn.Linear(10, 10)

    def forward(self, x1, inp):
        # Constraint 1: row-column alignment
        v1 = torch.mm(x1, inp)
        # Constraint 2: same shape or broadcastable
        v2 = v1 + inp
        return v2

# The tensor shape definitions do not satisfy the constraints
x1 = torch.randn(2, 10)
inp = torch.randn(2, 10)
m = MatrixModel()
out = m(x1, inp)  # raises a RuntimeError: the two (2, 10) matrices are not row-column aligned for torch.mm
Running this code raises a RuntimeError because the constraint relationships are not satisfied. Let's check the forward function: torch.mm requires x1 and inp to be row-column aligned ([a, b] * [b, c]). However, the shapes of x1 and inp are both [2, 10], so PyTorch throws a RuntimeError. How should we fix this model? Note that the second operation (v2 = v1 + inp) applies an additional constraint: same shape or broadcastable. This means that after the first operation (v1 = torch.mm(x1, inp)), v1 should have the same shape as inp. Consequently, the fix is easy: we just need to change the shape of x1 from [2, 10] to [2, 2], so that torch.mm produces a [2, 10] result. Then the model executes successfully, as sketched below.
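Here is a minimal sketch of the fix described above, assuming the MatrixModel class and the import from the previous snippet are in scope:
# Fixed input shapes: [2, 2] x [2, 10] satisfies the row-column alignment of torch.mm,
# and the resulting [2, 10] tensor matches the shape of inp for the element-wise addition.
x1 = torch.randn(2, 2)
inp = torch.randn(2, 10)
m = MatrixModel()
out = m(x1, inp)
print(out.shape)  # torch.Size([2, 10])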
I give the above example to show that, before we conclude that a DL model triggers a bug, we should check whether it is simply an invalid model causing a false alarm. Such invalid models (no matter whether manually crafted or fuzzer-generated) are meaningless in AI Infra testing.
S3: Correctness of torch.compile()
torch.compile() is the most important feature in PyTorch 2.x. The two core components of torch.compile() are Dynamo and Inductor. Dynamo captures the PyTorch dynamic computational graph and performs some initial optimizations, while Inductor conducts more fine-grained optimizations and lowers the computational graph to the target code. Unfortunately, torch.compile() is evolving rapidly, and many bugs are hidden in it.
If we use torch.compile() to compile a PyTorch model, we call this execution mode compiler mode; otherwise, we call it eager mode. As a DL compiler, torch.compile() should in theory only optimize model performance (i.e., reduce training or inference time) without changing the computed results. However, this is not always the case.
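Here is a minimal sketch of the two modes on a tiny toy module (TinyModel is just a placeholder name for illustration); ideally, the compiled outputs match the eager outputs within floating-point tolerance.
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x):
        return torch.relu(self.linear(x))

m = TinyModel()                # eager mode: the Python forward runs directly
compiled_m = torch.compile(m)  # compiler mode: Dynamo captures the graph, Inductor generates code
x = torch.randn(1, 3)

eager_out = m(x)
compiled_out = compiled_m(x)   # the first call triggers the actual compilation
print(torch.allclose(eager_out, compiled_out, rtol=1e-5, atol=1e-6))  # ideally True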
From the perspective of software correctness, the most important principle that torch.compile() needs to follow is: aligning with eager semantics. Specifically, based on my experience, there are two common misalignments (inconsistencies):
- Behaviour Inconsistency: one of the modes (i.e., compiler or eager) throws an error (a so-called crash), but the other one does not. Let me take PyTorch#143729 below as an example. Let's check the forward function! First, torch.frexp accepts the input x and returns two outputs, x_frac and x_exp. These two tensors have different dtypes (x_frac is a float32 mantissa and x_exp is an int32 exponent). Then, these two tensors are element-wise multiplied. Eager mode can handle this situation and outputs correct results. Unfortunately, Inductor cannot and throws CppCompileError: C++ compile error, because the CPU Inductor backend fails to deduce the data type of each output correctly due to a missing output index. This PR (PyTorch#143746) fixes the issue; more details can be found there. As mentioned at the beginning, eager passes the check while Inductor throws an error. The behaviours they exhibit are different, so this is a bug of torch.compile().
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        x_frac, x_exp = torch.frexp(x)  # x_frac: float32 mantissa, x_exp: int32 exponent
        x = x_frac * x_exp
        return x

x = torch.randn(4, 1)  # setting the first dimension to 4 is enough to trigger the error
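To actually observe the behaviour inconsistency, we can run the model in both modes and catch the crash. Below is a minimal differential-testing sketch (not the exact reproduction script from the issue), assuming the Model class and x above; on PyTorch versions before the fix in PyTorch#143746, the compiled run on CPU raises CppCompileError while eager succeeds.
m = Model()
eager_out = m(x)               # eager mode succeeds and returns a tensor

compiled_m = torch.compile(m)
try:
    compiled_out = compiled_m(x)
except Exception as e:         # Inductor's C++ compilation fails on affected versions
    print(f"compiler mode crashed: {type(e).__name__}")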
- Numerical Inconsistency: the numerical difference between the output tensors of eager mode and compiler mode exceeds the tolerance threshold. Let me take PyTorch#151198 as an example below. The core components of this script are the DL model, which consists of four operators (i.e., torch.nn.ReflectionPad3d, view(), torch.chunk, and torch.nn.PairwiseDistance), and the tensor inputs. We run the model with the tensor inputs in eager mode and with the Inductor backend, respectively. Additionally, as suggested by the PyTorch bug report requirements, we also run the model with torch.float64 dtype as a high-precision reference. Finally, we get three output results and compare them with torch._dynamo.utils.same. As you can see, it outputs False, which means the difference between the eager outputs and the Inductor outputs exceeds the tolerance threshold (the default setting in torch._dynamo.utils.same). I reported this bug to the PyTorch community, and the issue was labeled as high priority and utmost priority (i.e., ubn). Finally, this bug was fixed in PyTorch#152993. If you are interested in this bug, please refer to that PR for more details.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config
import os

config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.pad = torch.nn.ReflectionPad3d(1)
        self.dist = torch.nn.PairwiseDistance(p=2)

    def forward(self, x):
        x = self.pad(x)
        x = x.view(x.size(0), -1)
        x = torch.chunk(x, 2, dim=1)
        x = self.dist(x[0], x[1])
        return x

model = Model().eval().cuda()
x = torch.randn(2, 3, 4, 4, 4).cuda()
inputs = [x]

def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output

output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')
fp64 = run_test(model.to(dtype=torch.float64), [x.to(dtype=torch.float64)], 'eager')
print(output)
print(c_output)
print(fp64)
print(torch._dynamo.utils.same(output, c_output, fp64))
print(torch.max(torch.abs(output - c_output)))
"""
tensor([23.1078, 21.4387], device='cuda:0')
tensor([22.9208, 22.6405], device='cuda:0', dtype=torch.float64)
False
tensor(1.2018, device='cuda:0')
"""
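To give a concrete feel for what "exceeds the tolerance threshold" means, here is a simplified explicit check (a stand-in for illustration, not the actual logic inside torch._dynamo.utils.same), assuming output and c_output from the script above:
# Maximum element-wise deviation between eager and compiler-mode outputs.
abs_diff = torch.max(torch.abs(output - c_output))
print(abs_diff)  # about 1.2 in this reproduction, far beyond normal float32 rounding noise

# Even a loose tolerance check fails, which points to a real numerical bug rather than rounding error.
print(torch.allclose(output, c_output, rtol=1e-3, atol=1e-3))  # False on affected versions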
The above two bugs were both found by an LLM-empowered fuzzer designed by me (it will be open-sourced in the near future). You can reproduce them in PyTorch 2.6. Both of these inconsistencies reflect the incorrectness of torch.compile(). If you find one of them, congratulations! You may have found a potential bug in torch.compile(), and you can report it to the community. I will introduce how to report bugs in the next blog.