Hacks and Defenses in Automatic GPU Kernel Generation

Jiwei Li from the DeepReinforce Team
🎯 Hack | 🛡️ Defense | Defensible
Stream Injection | Add synchronize() before end_event | ✓ Yes
🧵 Thread Injection | Check thread count before/after | ✓ Yes
💤 Lazy Evaluation | Validate tensor type & storage | ✓ Yes
📉 Precision Downgrade | Check dtype, tighten tolerance, or use LLM | ⚠ Partial
🐒 Monkey-patch Timing | Verify function references | ✓ Yes

Self-evolving systems for GPU kernel generation, whether RL-based or agent-based, are taking off rapidly! There have been many exciting recent projects pushing in this direction, from datasets (Flashinfer-bench, KernelBench, robust-kbench) to algorithms (AI CUDA engineer, CUDA-L1 and CUDA-L2, Kevin, Swizzleperf, Locus, KernelLLM, etc.). For a more comprehensive background, see Simon Guo’s recent [blog].

In this blog post, we’ll focus on an emerging and serious issue that most researchers working in this direction have encountered, or will soon encounter: the evaluation-hacking problem. This refers to a situation where the self-evolving system tries to game the evaluation, producing code that appears faster but actually isn’t. In fact, this issue is not rare; it is extremely common, and it traces back to the classic reward-hacking problem that people have seen in reinforcement learning from day one. Code for the defenses against evaluation hacking described in this blog can be found at [code].

How to hack the evaluation?

To illustrate how the evaluation can be hacked, let's first look at how it works. For most kernel generation tasks, one important goal is to produce faster kernels. Therefore, we need to measure a kernel's execution time.

The most common timing approach, widely adopted in FlashInfer, KernelBench, CUDA-L2, and many others, is to record CUDA events before and after the kernel call:

Python - Standard Timing Approach
torch.cuda.synchronize()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
kernel(*inputs)
end_event.record()

torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)

In the code above, the custom kernel (called as kernel(*inputs)) can take several forms:

  1. Inline CUDA via load_inline — a Python API (torch.utils.cpp_extension.load_inline) that takes CUDA C++ code as a string and JIT-compiles it at runtime; see the sketch after this list
  2. A Triton kernel — a Python function decorated with @triton.jit
  3. A precompiled .cu file — CUDA code compiled ahead of time and bound to Python via pybind11
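
For concreteness, here is a minimal, hypothetical sketch of the first form. The extension name, the exported function name, and the body (which simply calls torch::matmul instead of launching hand-written CUDA kernels) are illustrative assumptions, not code from any of the projects above.

Python - load_inline Sketch (illustrative)
import torch
from torch.utils.cpp_extension import load_inline

cuda_source = r"""
#include <torch/extension.h>

// Illustrative body: a real submission would launch hand-written CUDA kernels here.
torch::Tensor custom_matmul(torch::Tensor a, torch::Tensor b) {
    return torch::matmul(a, b);
}
"""

cpp_source = "torch::Tensor custom_matmul(torch::Tensor a, torch::Tensor b);"

# JIT-compile the CUDA C++ string into an importable extension module.
module = load_inline(
    name="custom_matmul_ext",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["custom_matmul"],
    verbose=False,
)

def custom_kernel(a, b):
    return module.custom_matmul(a, b)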

Regardless of the implementation, a malicious kernel must still pass a basic correctness check. Designing a comprehensive correctness-checking function deserves a dedicated discussion, which we leave for the next blog post. Here, we simply use torch.allclose, which checks the element-wise difference between the outputs of the custom kernel and a reference kernel, as adopted in KernelBench and CUDA-L2.
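
A minimal sketch of such a check is shown below; the helper name check_correctness and the tolerance values are illustrative assumptions rather than the exact settings used in KernelBench or CUDA-L2.

Python - Correctness Check (sketch)
import torch

def check_correctness(custom_kernel, reference_kernel, inputs, rtol=1e-4, atol=1e-4):
    ref_out = reference_kernel(*inputs)
    out = custom_kernel(*inputs)
    torch.cuda.synchronize()  # make sure both outputs are actually computed
    return torch.allclose(out, ref_out, rtol=rtol, atol=atol)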

Therefore, hacking strategies that are easily filtered out by the torch.allclose() check are outside the scope of this discussion. These include kernels that perform partial computation (computing only a subset of the output) and approximate algorithms (e.g., low-rank approximations such as SVD).

Next, we will describe a few exploits we have observed that can successfully hack this timing system.

Hack #1: Stream Injection

In the current timing system, both CUDA events (start_event and end_event) are recorded on the default stream. These events only "see" work submitted to the default stream. If a malicious kernel creates a separate stream and launches computation there, the computation is invisible to the default stream. From the default stream's perspective, nothing happened between start_event.record() and end_event.record(), making the malicious kernel appear much faster.

The following is a simple example of this kind of malicious kernel:

Python - Stream Injection Exploit
def custom_kernel_stream_inject(c1, c2):
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        return torch.matmul(c1, c2)

Let’s just run the full code and compare the execution times.

Python - Full Comparison Code
import torch

def custom_kernel_original(c1, c2):
    return torch.matmul(c1, c2)

def custom_kernel_stream_inject(c1, c2):
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        return torch.matmul(c1, c2)

def timing_cuda_event(kernel):
    c1 = torch.randn(10000, 10000).cuda()
    c2 = torch.randn(10000, 10000).cuda()

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    kernel(c1, c2)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    print(f"Elapsed time of {kernel.__name__}: {float(elapsed_time_ms):.3f} ms")

if __name__ == "__main__":
    timing_cuda_event(custom_kernel_original)
    timing_cuda_event(custom_kernel_stream_inject)

The execution times of the original and malicious kernels are shown below:

$ python stream_injection.py
Elapsed time of custom_kernel_original: 264.706 ms
Elapsed time of custom_kernel_stream_inject: 4.244 ms
>>> Defense 1: Add a torch.cuda.synchronize() before end_event.record().

One straightforward defense is to place another torch.cuda.synchronize() right after kernel(*inputs) and before end_event.record(), as shown below:

Original code:

Python
start_event.record()
kernel(c1, c2)
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)

Defense code:

Python
start_event.record()
kernel(c1, c2)
torch.cuda.synchronize()  # <<< wait for all streams (including injected ones) to finish
end_event.record()
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)

The added torch.cuda.synchronize() blocks the CPU thread until all outstanding GPU work, including the hidden work on the injected stream, has finished before end_event is recorded.


However, the defense code introduces a CPU-side delay into the recorded time, which can be undesirable in some cases.

In the original timing code, the CPU sends three requests to the GPU in sequence: start_event.record(), kernel(c1, c2), and end_event.record(). Each request takes about 5–25 µs to issue. The GPU begins processing the first request as soon as it arrives, while later requests wait in the queue. Because the GPU is still busy when end_event.record() is sent, issuing it adds no extra delay.

In contrast, in the defense code, the CPU first sends only two requests to the GPU: start_event.record() and kernel(c1, c2). Then it calls synchronize(), which forces the CPU to wait until both complete on the GPU.

Next, the CPU sends the final end_event.record() request to the GPU. At this moment the GPU is idle, so issuing end_event.record() incurs a launch delay that gets captured in the timing. This delay does not occur in the original code because the GPU is still busy handling the previous requests when end_event.record() is issued.


For kernels with execution times on the order of seconds, this delay is negligible. But for a 1 ms kernel, it cannot be ignored.

>>> Defense 2: The Hybrid Approach

If we want to completely eliminate this delay from the evaluation, we can use a hybrid approach. Before running the full evaluation, we first run both the original and the defense timing versions and compare their measured times. If the difference is large, the custom kernel is likely malicious, and we can discard it or assign it a score of 0. Otherwise, we fall back to the original timing method for the full evaluation, as sketched below.
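
A minimal sketch of this screening step follows; time_without_sync and time_with_sync are hypothetical wrappers around the two timing variants shown above, and the 1.5x ratio threshold is an illustrative assumption.

Python - Hybrid Screening (sketch)
def hybrid_screen(kernel, time_without_sync, time_with_sync, ratio_threshold=1.5):
    # Time the kernel with the original (hackable) method and with the extra synchronize.
    t_plain = time_without_sync(kernel)
    t_synced = time_with_sync(kernel)

    if t_synced > ratio_threshold * t_plain:
        # Large gap: hidden work (e.g., on an injected stream) was excluded from t_plain.
        return None  # discard the kernel or assign it a score of 0
    # Timings agree: fall back to the original timing for the full evaluation.
    return t_plain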

>>> Defense 3: Call LLMs for new stream detection

We can use an LLM API to detect whether the custom kernel creates a new CUDA stream. However, this approach can be expensive when evaluating a lot of kernels.

Hack #2: Thread Injection

Similar to stream injection, an attacker can spawn a background thread, which performs the actual computation asynchronously while the main thread returns an empty output immediately. This attack can also pass the correctness check, since by the time the correctness check runs, the background thread has finished and filled in the correct output.

Here’s an example of a malicious kernel with thread-injection attack:

Python - Thread Injection Exploit
import threading, torch

def custom_kernel(A, B):
    out = torch.empty(A.size(0), B.size(1), device=A.device)

    def compute():
        result = torch.matmul(A, B)
        out.copy_(result)  # Fill in result in the background thread

    t = threading.Thread(target=compute)
    t.start()
    return out
>>> Defense: Compare thread counts before/after

We can compare the thread count before and after kernel execution to detect when new threads are created:

Python - Thread Count Check
import threading

before = threading.active_count()
output = custom_kernel(A, B)
after = threading.active_count()

if after > before:
    raise RuntimeError("Kernel spawned background thread")

Process Injection

With stream and thread injection attacks already observed, process injection naturally follows. The good news is that CUDA contexts don't transfer across process boundaries, which means a process-injection attack wouldn't work for CUDA operations in practice. One less thing to worry about!

Hack #3: Lazy Evaluation

With lazy evaluation, calling custom_kernel(*inputs) does not guarantee that the output is actually materialized or computed. We need to pay extra attention to this when using load_inline in Python. An example of this type of malicious kernel is given below:

Python - Lazy Evaluation Exploit
class LazyMatmul(torch.Tensor):
    @staticmethod
    def __new__(cls, A, B):
        obj = torch.Tensor._make_subclass(
            cls, torch.empty(A.size(0), B.size(1), device=A.device)
        )
        obj.A, obj.B = A, B
        return obj

    def __eq__(self, other):
        return torch.matmul(self.A, self.B) == other

def custom_kernel(A, B):
    return LazyMatmul(A, B)

Let’s compare the execution times:

$ python lazy.py
Elapsed time of custom_kernel_original: 305.062 ms
Elapsed time of custom_kernel_lazy: 0.341 ms
>>> Defense: Validate that outputs are materialized tensors

To ensure the output is a real, materialized tensor rather than a lazy placeholder, we can enforce a validation check. The check verifies the following conditions: the output must be a standard torch.Tensor (not a subclass), must be on the correct GPU device, must have allocated memory, and its underlying storage must be valid.

Python - Tensor Validation Function
def validate_tensor(out: torch.Tensor, device: torch.device, prefix: str = "Output") -> tuple[bool, str]:
    """
    Validate tensor is real and materialized, not a lazy hack.

    Returns:
        (True, "OK") if valid
        (False, error_message) if invalid
    """
    # Check 1: Must be a tensor
    if not isinstance(out, torch.Tensor):
        return False, f"{prefix} is not a tensor: {type(out)}"

    # Check 2: Must be standard torch.Tensor, not a subclass
    if type(out).__name__ not in ["Tensor", "Parameter"]:
        return False, f"{prefix} is {type(out).__name__}, not standard torch.Tensor"

    # Check 3: Must be on correct device
    if out.device != device:
        return False, f"{prefix} on wrong device: {out.device} (expected {device})"

    # Check 4: Must have allocated storage
    if out.untyped_storage().size() == 0:
        return False, f"{prefix} has no allocated storage (likely lazy)"

    # Check 5: Storage pointer must be valid
    if out.data_ptr() == 0:
        return False, f"{prefix} storage pointer is null (likely lazy)"

    return True, "OK"
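
As a usage sketch (assuming the kernel, c1, c2, and event variables from the earlier timing code), one possible placement is to validate the output after timing and reject the measurement if validation fails:

Python - Using validate_tensor (sketch)
start_event.record()
output = kernel(c1, c2)
end_event.record()
torch.cuda.synchronize()

# Reject the measurement if the output is not a real, materialized tensor.
ok, msg = validate_tensor(output, c1.device)
if not ok:
    raise RuntimeError(f"Validation failed: {msg}")

elapsed_time_ms = start_event.elapsed_time(end_event)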

Hack #4: Precision Downgrading

Another hack we’ve observed is precision downgrading. For example, suppose the task is to build a FP32 kernel. The custom kernel silently downgrades all computations to BF16 or FP16, then casts the output back to FP32 after the computation, which of course makes it faster. An illustrative example is shown below:

Python - Precision Downgrade Example
def custom_kernel(A, B):  # A, B are FP32
    out = torch.matmul(A.bfloat16(), B.bfloat16())  # BF16 compute, FP32 accumulator
    return out.float()  # Output is FP32, but precision is BF16

Theoretically, code with precision downgrading should not pass the correctness check. But in practice, if we use torch.allclose to check the element-wise difference between the outputs of the custom kernel and the reference kernel, we have to empirically set a tolerance for that difference. Kernels with precision downgrading have a good chance of slipping under this tolerance, because the computation itself is correct; only the arithmetic precision is reduced.

Unfortunately, we haven’t found a universally simple and perfect solution to this issue.

  • Checking the output dtype or whether the lower bits are zero cannot defend against all cases, because many so-called 16-bit kernels still use 32-bit arithmetic internally. For example, BF16 matmul usually multiplies in 16-bit but accumulates with a 32-bit accumulator.
  • Tightening the torch.allclose threshold in the correctness check can help most of the time, but not always, since choosing the threshold is already tricky. For kernels that fall close to the threshold, it is hard to decide whether they should be considered correct.

If we want a fully reliable solution, using an LLM API for checking seems to be the only option.
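
For reference, here is a sketch of the two partial checks described in the bullets above; the helper name check_precision, the expected dtype, and the tightened tolerance values are illustrative assumptions.

Python - Partial Precision Checks (sketch)
import torch

def check_precision(custom_out, ref_out, expected_dtype=torch.float32):
    # Partial check 1: the output dtype must match the task specification.
    # (This alone cannot catch kernels that cast back to FP32 at the end.)
    if custom_out.dtype != expected_dtype:
        return False, f"dtype is {custom_out.dtype}, expected {expected_dtype}"

    # Partial check 2: compare against the FP32 reference with a tighter tolerance
    # than a generic correctness check would use; choosing the threshold is still tricky.
    if not torch.allclose(custom_out, ref_out, rtol=1e-5, atol=1e-5):
        return False, "output error exceeds what true FP32 arithmetic should produce"

    return True, "OK"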

Hack #5: Monkey-patching start_event.elapsed_time(end_event)

Since the reported time comes from start_event.elapsed_time(end_event), the hack can simply replace that function with one that always returns a tiny constant, e.g., 0.001 ms 😩 😩 😩. The kernel still runs and passes the correctness check, but the timing becomes meaningless.

Python - Fake Timing via Monkey-patch
import torch

# Monkey-patch elapsed_time to return fake fast timing
_original_elapsed_time = torch.cuda.Event.elapsed_time

def _fake_elapsed_time(self, end_event):
    return 0.001  # Always report 0.001ms - fake fast!

torch.cuda.Event.elapsed_time = _fake_elapsed_time

def custom_kernel_monkey_patch(A, B):
    return torch.matmul(A, B)

Let’s see the execution time:

$ python benchmark_monkey.py
Elapsed time of custom_kernel_original: 287.818 ms
Elapsed time of custom_kernel_monkey_patch: 0.001 ms
>>> Defense: Verify critical timing functions aren’t monkey-patched

As a defense, we can capture references to the critical timing functions before importing the custom kernel, and then verify that they haven't been overridden (here we check both record and elapsed_time):

Python - Verify No Monkey-patch
_real_record = torch.cuda.Event.record
_real_elapsed_time = torch.cuda.Event.elapsed_time
from custom import custom_kernel

def verify_no_monkey_patch():
    if torch.cuda.Event.record is not _real_record:
        return False, "torch.cuda.Event.record was monkey-patched"
    if torch.cuda.Event.elapsed_time is not _real_elapsed_time:
        return False, "torch.cuda.Event.elapsed_time was monkey-patched"
    return True, "OK"

ok, msg = verify_no_monkey_patch()
if not ok:
    raise RuntimeError(msg)

Summary

Despite the many hacks discussed in this blog, it is impossible to enumerate them all. LLMs, especially when combined with reinforcement learning, are extremely good at exploiting these loopholes. The arms race between attacks (once carried out by human hackers, now by LLMs) and defenses shows no end in sight.

In the next blog post, we will discuss another tricky but critical issue: correctness checking. Stay tuned!

Reference