| 🎯Hack | 🛡️Defense | ✅Defensible |
|---|---|---|
| ⚡Stream Injection | Add synchronize() before end_event | ✓ Yes |
| 🧵Thread Injection | Check thread count before/after | ✓ Yes |
| 💤Lazy Evaluation | Validate tensor type & storage | ✓ Yes |
| 📉Precision Downgrade | Check dtype, tighten tolerance, or use LLM | ⚠ Partial |
| 🐒Monkey-patch Timing | Verify function references | ✓ Yes |
Self-evolving systems for GPU kernel generation, whether RL-based or agent-based, are taking off rapidly! Many exciting recent projects are pushing in this direction, from datasets (FlashInfer-Bench, KernelBench, robust-kbench) to algorithms (AI CUDA Engineer, CUDA-L1 and CUDA-L2, Kevin, SwizzlePerf, Locus, KernelLLM, etc.). For a more comprehensive background, see Simon Guo’s recent [blog].
In this blog, we’ll focus on an emerging and serious issue that most researchers working in this direction have encountered, or will soon encounter: the evaluation-hacking problem. This refers to a situation where the self-evolving system tries to game the evaluation, producing code that appears faster but actually isn’t. Far from being rare, this issue is extremely common, and it traces back to the classic reward-hacking problem that has plagued reinforcement learning from day one. Code for the defenses against evaluation hacking described in this blog can be found at [code].
How to hack the evaluation?
To illustrate how the evaluation can be hacked, let's first look at how it works. For most kernel generation tasks, one important goal is to produce faster kernels. Therefore, we need to measure a kernel's execution time.
The most common timing approach, widely adopted in FlashInfer, KernelBench, CUDA-L2, and many others, is to record CUDA events before and after the kernel call:
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
kernel(*inputs)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
In the code above, kernel is the custom kernel under evaluation, and it can take several forms (a minimal load_inline sketch follows this list):

- Inline CUDA via load_inline: a Python API (torch.utils.cpp_extension.load_inline) that takes CUDA C++ code as a string and JIT-compiles it at runtime
- A Triton kernel: a Python function decorated with @triton.jit
- A precompiled .cu file: CUDA code compiled ahead of time and bound to Python via pybind11
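For concreteness, the load_inline path might look roughly like the sketch below. The kernel itself (a vector add), the extension name add_ext, and the function name add_cuda are illustrative placeholders rather than anything the evaluation harness prescribes; inputs are assumed to be contiguous FP32 CUDA tensors.

import torch
from torch.utils.cpp_extension import load_inline

cuda_source = r"""
#include <torch/extension.h>

__global__ void add_kernel(const float* a, const float* b, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] + b[i];
}

torch::Tensor add_cuda(torch::Tensor a, torch::Tensor b) {
    auto out = torch::empty_like(a);
    int n = a.numel();
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    add_kernel<<<blocks, threads>>>(a.data_ptr<float>(), b.data_ptr<float>(),
                                    out.data_ptr<float>(), n);
    return out;
}
"""

# The declaration goes in cpp_sources; the definition (and the kernel) in cuda_sources.
ext = load_inline(
    name="add_ext",
    cpp_sources="torch::Tensor add_cuda(torch::Tensor a, torch::Tensor b);",
    cuda_sources=cuda_source,
    functions=["add_cuda"],
)

def custom_kernel(a, b):
    return ext.add_cuda(a, b)

Triton and precompiled .cu kernels are wrapped the same way: whatever the packaging, the evaluation harness only ever sees a Python callable.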
Regardless of the implementation, the malicious kernel must satisfy
a basic correctness check. Designing a comprehensive
correctness-checking function deserves a dedicated discussion, which
we will leave for the next blog. Here, we simply use
torch.allclose, which checks the element-wise
difference between the outputs of the custom kernel and a reference
kernel, as adopted in KernelBench and CUDA-L2.
Therefore, hacking strategies that are easily filtered out by the
torch.allclose() check are outside the scope of this
discussion. Such kernels include those performing partial computation
(computing only a subset of the output) and those using approximate
algorithms (e.g., low-rank approximation via SVD).
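For concreteness, the correctness check we assume throughout this post can be sketched as follows (the tolerances are illustrative and must be set empirically, a point we return to in the precision-downgrading section):

import torch

def check_correctness(custom_kernel, reference_kernel, make_inputs, rtol=1e-3, atol=1e-3):
    # Generate fresh inputs and compare the two outputs element-wise.
    inputs = make_inputs()
    out_custom = custom_kernel(*inputs)
    out_ref = reference_kernel(*inputs)
    torch.cuda.synchronize()  # make sure both outputs are fully computed
    return torch.allclose(out_custom, out_ref, rtol=rtol, atol=atol)

# Example usage (shapes illustrative):
# make_inputs = lambda: (torch.randn(1024, 1024, device="cuda"),
#                        torch.randn(1024, 1024, device="cuda"))
# assert check_correctness(custom_kernel, torch.matmul, make_inputs)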
Next, we will describe a few exploits we have observed that can successfully hack this timing system.
Hack #1: Stream Injection
In the current timing system, both CUDA events (start_event
and end_event) are recorded on the default stream.
These events only "see" work submitted to the default stream. If a
malicious kernel creates a separate stream and launches computation
there, the computation is invisible to the default stream. From the
default stream's perspective, nothing happened between
start_event.record() and
end_event.record(), making the malicious kernel appear
much faster.
The following is a simple example of this kind of malicious kernel:
def custom_kernel_stream_inject(c1, c2):
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        return torch.matmul(c1, c2)
Let’s just run the full code and compare the execution time.
import torch

def custom_kernel_original(c1, c2):
    return torch.matmul(c1, c2)

def custom_kernel_stream_inject(c1, c2):
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        return torch.matmul(c1, c2)

def timing_cuda_event(kernel):
    c1 = torch.randn(10000, 10000).cuda()
    c2 = torch.randn(10000, 10000).cuda()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    kernel(c1, c2)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Elapsed time of {kernel.__name__}: {float(elapsed_time_ms):.3f} ms")

if __name__ == "__main__":
    timing_cuda_event(custom_kernel_original)
    timing_cuda_event(custom_kernel_stream_inject)
The execution times of the original and malicious kernels are shown below:
One straightforward defense is to place another
torch.cuda.synchronize() right after
kernel(*inputs) and before
end_event.record(), as shown below:
Original code:
start_event.record()
kernel(c1, c2)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
Defense code:
start_event.record()
kernel(c1, c2)
torch.cuda.synchronize()  # <<< force hidden streams to finish before recording end_event
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
The added torch.cuda.synchronize() blocks the CPU
threads so that all hidden work in the injected stream must finish
before the event is recorded.
However, the defense code introduces a CPU-side delay that is captured in the measurement, which can be undesirable in some cases.
In the original timing code, the CPU sends three requests to the GPU
in sequence: start_event.record(),
kernel(c1, c2), and end_event.record().
Each request takes about 5–25 µs to issue. The GPU begins processing
the first request upon its arrival, and the requests that arrive
later will wait in the queue. Because the GPU is already busy when
end_event.record() is sent, there is no extra delay.
In contrast, in the defense code, the CPU first sends only two
requests to the GPU: start_event.record() and
kernel(c1, c2). Then it calls
synchronize(), which forces the CPU to wait until both
complete on the GPU.
Next, the CPU sends the last end_event.record() request
to the GPU. At this moment, the GPU is idle, so issuing
end_event.record() will incur a delay, which will be
captured in timing. This delay does not occur in the original code
because, when end_event.record() is issued, the GPU is
busy handling previous requests.
For kernels whose execution times are on the order of seconds, this delay is negligible. But for a 1 ms kernel, it cannot be ignored.
If we want to completely eliminate this delay from the evaluation, we can use a hybrid approach, as sketched below. Before running the full evaluation, we first run both the original and the defended timing on the custom kernel and compare their results. If the difference is large, the custom kernel is likely malicious, and we can discard it or assign it a score of 0. Otherwise, we fall back to the original timing method for the actual evaluation.
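A minimal sketch of this hybrid screen, assuming both harnesses return elapsed_time_ms rather than printing it, and that timing_cuda_event_defended is the same harness with the extra synchronize() before end_event.record() (the 2x threshold is illustrative):

def screen_for_stream_injection(kernel, threshold=2.0):
    # A benign kernel reports similar times under both harnesses; a stream-injecting
    # kernel looks dramatically faster under the original (unsynchronized) timing.
    t_original = timing_cuda_event(kernel)            # no sync before end_event
    t_defended = timing_cuda_event_defended(kernel)   # sync before end_event
    if t_defended > threshold * max(t_original, 1e-6):
        return None  # likely malicious: discard or assign score 0
    return t_original  # safe to use the lower-overhead measurement

In practice, the threshold should leave headroom for the launch-latency overhead discussed above, especially for sub-millisecond kernels.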
Alternatively, we can use an LLM API to inspect the kernel source and detect whether it creates a new CUDA stream. However, this approach can be expensive when evaluating many kernels.
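For illustration, such a check could be wired up roughly as below; ask_llm is a hypothetical prompt-in, answer-out helper standing in for whatever LLM API is available, not a real interface:

def llm_flags_stream_injection(kernel_source: str, ask_llm) -> bool:
    # ask_llm is a hypothetical callable: prompt string in, answer string out.
    prompt = (
        "Does the following PyTorch/CUDA kernel create or switch to a non-default "
        "CUDA stream (e.g., torch.cuda.Stream, cudaStreamCreate)? Answer YES or NO.\n\n"
        + kernel_source
    )
    return ask_llm(prompt).strip().upper().startswith("YES")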
Hack #2: Thread Injection
Similar to stream injection, an attacker can spawn a background thread, which performs the actual computation asynchronously while the main thread returns an empty output immediately. This attack can also pass the correctness check, since by the time the correctness check runs, the background thread has finished and filled in the correct output.
Here’s an example of a malicious kernel with thread-injection attack:
import threading, torch

def custom_kernel(A, B):
    out = torch.empty(A.size(0), B.size(1), device=A.device)

    def compute():
        result = torch.matmul(A, B)
        out.copy_(result)  # Fill in result in the background thread

    t = threading.Thread(target=compute)
    t.start()
    return out
We can compare the thread count before and after kernel execution to detect when new threads are created:
import threading

before = threading.active_count()
output = custom_kernel(A, B)
after = threading.active_count()
if after > before:
    raise RuntimeError("Kernel spawned background thread")
A similar attack can also be carried out via process injection, i.e., spawning a background process instead of a thread to perform the real work.
Hack #3: Lazy Evaluation
In lazy evaluation, calling custom_kernel(*inputs) does not
ensure that the output is actually materialized or computed. We need
to pay extra attention to this when using load_inline in
Python. An example of this type of malicious kernel is given below:
class LazyMatmul(torch.Tensor):
    @staticmethod
    def __new__(cls, A, B):
        obj = torch.Tensor._make_subclass(
            cls, torch.empty(A.size(0), B.size(1), device=A.device)
        )
        obj.A, obj.B = A, B
        return obj

    def __eq__(self, other):
        return torch.matmul(self.A, self.B) == other

def custom_kernel(A, B):
    return LazyMatmul(A, B)
Let’s compare the execution times:
To ensure that the output is actually materialized before the time
measurement ends, we can enforce a validation check. The check
verifies the following conditions: the output must be a
standard torch.Tensor (not a subclass), must be on the
correct GPU device, must have allocated memory, and the
corresponding storage must be valid.
def validate_tensor(out: torch.Tensor, device: torch.device, prefix: str = "Output") -> tuple[bool, str]:
    """
    Validate tensor is real and materialized, not a lazy hack.
    Returns:
        (True, "OK") if valid
        (False, error_message) if invalid
    """
    # Check 1: Must be a tensor
    if not isinstance(out, torch.Tensor):
        return False, f"{prefix} is not a tensor: {type(out)}"
    # Check 2: Must be standard torch.Tensor, not a subclass
    if type(out).__name__ not in ["Tensor", "Parameter"]:
        return False, f"{prefix} is {type(out).__name__}, not standard torch.Tensor"
    # Check 3: Must be on correct device
    if out.device != device:
        return False, f"{prefix} on wrong device: {out.device} (expected {device})"
    # Check 4: Must have allocated storage
    if out.untyped_storage().size() == 0:
        return False, f"{prefix} has no allocated storage (likely lazy)"
    # Check 5: Storage pointer must be valid
    if out.data_ptr() == 0:
        return False, f"{prefix} storage pointer is null (likely lazy)"
    return True, "OK"
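A usage sketch, run right after the kernel call (shapes and device are illustrative):

A = torch.randn(1024, 1024, device="cuda")
B = torch.randn(1024, 1024, device="cuda")
out = custom_kernel(A, B)
ok, msg = validate_tensor(out, torch.device("cuda:0"))
if not ok:
    raise RuntimeError(f"Lazy-evaluation check failed: {msg}")

Running this against the LazyMatmul kernel above trips check 2 (the output is a subclass, not a standard torch.Tensor), while an honest kernel passes.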
Hack #4: Precision Downgrading
Another hack we’ve observed is precision downgrading. For example, suppose the task is to build an FP32 kernel. The custom kernel silently downgrades all computations to BF16 or FP16, then casts the output back to FP32 after the computation, which of course makes it faster. An illustrative example is shown below:
def custom_kernel(A, B):  # A, B are FP32
    out = torch.matmul(A.bfloat16(), B.bfloat16())  # BF16 compute (FP32 accumulation internally)
    return out.float()  # Output is cast back to FP32, but precision is BF16
Theoretically, code with precision downgrading should not pass
the correctness check. But in practice, if we are using
torch.allclose to check the element-wise difference
between the outputs of the custom kernel and the reference kernel,
we need to empirically set a threshold for the difference. Kernels
with precision downgrading have a good chance of slipping under
this threshold, because the computation itself is correct and only
the arithmetic precision is reduced.
Unfortunately, we haven’t found a universally simple and perfect solution to this issue.
- Checking the output dtype, or whether the lower bits of the output are zero, cannot defend against all cases, because many so-called 16-bit kernels still use 32-bit arithmetic internally. For example, a BF16 matmul usually multiplies in 16-bit but accumulates with a 32-bit accumulator. (A sketch of such a heuristic check follows this list.)
- Tightening the torch.allclose threshold in the correctness check can help most of the time, but not always, since choosing the threshold is already tricky. For kernels that fall close to the threshold, it is hard to decide whether they should be considered correct.
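For reference, the lower-bits heuristic mentioned in the first bullet can be sketched as follows; it is a cheap pre-filter under the assumption that a genuine FP32 output is very unlikely to be exactly representable in BF16 everywhere:

def looks_precision_downgraded(out: torch.Tensor) -> bool:
    # Wrong dtype is an immediate red flag.
    if out.dtype != torch.float32:
        return True
    # If a BF16 round trip changes nothing, every value already fits in BF16,
    # which suggests the result was computed (or stored) at reduced precision.
    roundtrip = out.bfloat16().float()
    return torch.equal(out, roundtrip)

As noted above, this misses downgraded kernels whose 32-bit accumulation reintroduces low-order bits, so it can only serve as a first filter.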
If we want a fully reliable solution, using an LLM API for checking seems to be the only option.
Hack #5: Monkey-patching start_event.elapsed_time(end_event)
Since the timing relies on
start_event.elapsed_time(end_event) to report the
measured time, the hack can simply replace it with a function that
always returns a tiny constant (0.001 ms in the example below) 😩 😩 😩 .
The kernel still runs and passes the correctness check, but the
timing becomes meaningless.
import torch

# Monkey-patch elapsed_time to return fake fast timing
_original_elapsed_time = torch.cuda.Event.elapsed_time

def _fake_elapsed_time(self, end_event):
    return 0.001  # Always report 0.001 ms - fake fast!

torch.cuda.Event.elapsed_time = _fake_elapsed_time

def custom_kernel_monkey_patch(A, B):
    return torch.matmul(A, B)
Let’s see the execution time:
As a defense, we can capture references to the timing functions
before importing the custom kernel and verify that they have not
been overridden:
_real_elapsed_time = torch.cuda.Event.elapsed_time
_real_record = torch.cuda.Event.record

from custom import custom_kernel  # import the (possibly malicious) kernel afterwards

def verify_no_monkey_patch():
    if torch.cuda.Event.elapsed_time is not _real_elapsed_time:
        return False, "torch.cuda.Event.elapsed_time was monkey-patched"
    if torch.cuda.Event.record is not _real_record:
        return False, "torch.cuda.Event.record was monkey-patched"
    return True, "OK"

ok, msg = verify_no_monkey_patch()
if not ok:
    raise RuntimeError(msg)
Summary
Despite the many hacks discussed in this blog, it is impossible to enumerate them all. LLMs, especially when combined with reinforcement learning, are extremely good at exploiting these loopholes. The arms race between attacks (once carried out by human hackers, now by LLMs) and defenses shows no end in sight.
In the next blog post, we will discuss another tricky but critical issue: correctness checking. Stay tuned!
References
- Towards Automated GPU Kernel Generation. Simon Guo & Alex Zhang. [Blog]
- KernelBench: Can LLMs Write Efficient GPU Kernels? Anne Ouyang*, Simon Guo*, Simran Arora, Alex L. Zhang, William Hu, Christopher Ré, Azalia Mirhoseini. [Github, Paper]
- CUDA-L2: Surpassing cuBLAS Performance for Matrix Multiplication through Reinforcement Learning. Songqiao Su, Xiaofei Sun, Xiaoya Li, Albert Wang, Jiwei Li, Chris Shum. [Github, Paper]
- CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning. Xiaoya Li, Xiaofei Sun, Albert Wang, Jiwei Li, Chris Shum. [Github, Paper]
- SwizzlePerf: Hardware-Aware LLMs for GPU Kernel Performance Optimization. Arya Tschand, Muhammad Awad, Ryan Swann, Kesavan Ramakrishnan, Jeffrey Ma, Keith Lowery, Ganesh Dasika, Vijay Janapa Reddi. [Paper]
- FlashInferBench: Building the Virtuous Cycle for AI-driven LLM Systems. FlashInfer Team. [Github, Blog]
- Kevin: Multi-Turn RL for Generating CUDA Kernels. Carlo Baronio, Pietro Marsella, Ben Pan, Simon Guo, Silas Alberti. [Paper]
- FlashInfer: Efficient and Customizable Attention Engine for LLM Inference Serving. Zihao Ye, Lequn Chen, Ruihang Lai, Wuwei Lin, Yineng Zhang, Stephanie Wang, Tianqi Chen, Baris Kasikci, Vinod Grover, Arvind Krishnamurthy, Luis Ceze. [Github, Paper]
- KernelLLM: Making Kernel Development More Accessible. Zacharias V. Fisches, Sahan Paliskara, Simon Guo, Alex Zhang, Joe Spisak, Chris Cummins, Hugh Leather, Gabriel Synnaeve, Joe Isaacson, Aram Markosyan, Mark Saroufim. [HuggingFace]