Package pytorch_numba_extension_jit
This package aims to simplify the use of Numba-CUDA kernels within projects built on the PyTorch deep learning framework. By annotating a function written in the style of a Numba-CUDA kernel with type hints from this package, jit() can generate PyTorch Custom Operator bindings that allow the kernel to be used within a traced (e.g. torch.compile) environment.

Furthermore, by setting to_extension=True, the kernel can also be transformed into PTX, and C++ code can be generated to invoke the kernel with minimal overhead.
As a toy example, consider the task of creating a copy of a 1D array:
>>> import torch
>>> from numba import cuda
>>> import pytorch_numba_extension_jit as pnex
>>> @pnex.jit(n_threads="a.numel()")
... def copy(
...     a: pnex.In(dtype="f32", shape=(None,)),
...     result: pnex.Out(dtype="f32", shape="a"),
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         result[x] = a[x]
>>> A = torch.arange(5, dtype=torch.float32, device="cuda")
>>> copy(A)
tensor([0., 1., 2., 3., 4.], device='cuda:0')
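
Because the result is a registered PyTorch custom operator, it can also be called from compiled code. A minimal sketch (assuming the copy kernel and tensor A defined above; the printed output is illustrative):

>>> @torch.compile(fullgraph=True)
... def double_copy(t):
...     return copy(t) * 2
>>> double_copy(A)
tensor([0., 2., 4., 6., 8.], device='cuda:0')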
For more examples of usage, see jit() and the examples directory of the project.
Entrypoint
def jit(
    *,
    n_threads: str | tuple[str, str] | tuple[str, str, str],
    to_extension: bool = False,
    cache_id: str = None,
    verbose: bool = False,
    threads_per_block: int | tuple[int, int] | tuple[int, int, int] = None,
    max_registers: int = None,
) -> Callable[[Callable[..., None]], torch._library.custom_ops.CustomOpDef]
Compile a Python function in the form of a Numba-CUDA kernel to a PyTorch operator
All parameters must be annotated with one of the argument types exported by this module. The resulting operator takes In/InMut/Scalar parameters as arguments, while returning Out parameters.

The keyword-only argument n_threads must be specified to indicate how many threads the resulting kernel should be launched with. The dimensionality of n_threads determines the dimensionality of the launched kernel, while threads_per_block controls the size of each block.

With to_extension=True, this function will also compile the PTX generated by Numba to a PyTorch native C++ extension, thereby reducing the overhead per call. If the resulting compilation times (several seconds on the first compilation, cached afterwards) are not acceptable, this additional compilation step can be skipped with to_extension=False.

Parameters
n_threads : str, tuple[str, str], tuple[str, str, str]
    Expression(s) that evaluate to the total number of threads the kernel should be launched with. Thread axes are filled in the order X, Y, Z: passing a single string n_threads is therefore equivalent to passing (n_threads, 1, 1), with only the X thread-dimension being non-unit.
    In practice, this number is divided by threads_per_block and rounded up to obtain the number of blocks for a single kernel invocation (blocks per grid).
to_extension : bool = False
    Whether the function should be compiled to a PyTorch C++ extension or instead be left as a wrapped Numba-CUDA kernel. The signature of the returned function is identical in both cases, but compiling an extension can take 5+ seconds, while skipping the extension incurs a small runtime overhead on every call.
    For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), the next best option is to use to_extension and minimise call overhead.
cache_id : str, optional
    The name to save the compiled extension under: clashing cache_ids will result in recompilations (clashing functions will evict each other from the cache), but not miscompilations (the results will be correct).
    Only used when to_extension=True.
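
As an illustration of these options, the toy copy kernel from the package overview could be compiled to a C++ extension under a stable cache name (a sketch; the cache_id value is arbitrary):

>>> @pnex.jit(n_threads="a.numel()", to_extension=True, cache_id="copy_f32")
... def copy_ext(
...     a: pnex.In("f32", (None,)),
...     result: pnex.Out("f32", "a"),
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         result[x] = a[x]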
Returns
decorator : (kernel) -> torch.library.CustomOpDef
    The resulting decorator will transform a Python function (if properly annotated, and the function is a valid Numba-CUDA kernel) into a CustomOpDef, where the signature is such that all parameters annotated with In, InMut or Scalar must be provided as arguments, and all Out parameters are returned.
    All parameters must be annotated with one of In, InMut, Out, Scalar or Unused.
Other Parameters
verbose : bool = False
    Whether to print additional information about the compilation process. Compilation errors are always printed.
threads_per_block : int, tuple[int, int], tuple[int, int, int] = None
    The number of threads within a thread block across the various dimensions. Depending on the dimensionality of n_threads, this defaults to one of:
    - For 1 dimension: 256
    - For 2 dimensions: (16, 16)
    - For 3 dimensions: (8, 8, 4)
    The resulting grid-size arithmetic is illustrated in the sketch after this parameter list.
max_registers : int, optional
    Specify the maximum number of registers to be used by the kernel, with any excess spilling over to local memory. Typically, the compiler is quite good at guessing the number of registers it should use, but limiting this to hit occupancy targets may help in some cases.
    This option is only available with to_extension=False, due to the structure of the Numba-CUDA API.
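
As a rough illustration of the launch-size arithmetic described for n_threads and threads_per_block (the actual launch configuration is computed by the generated wrapper; this sketch only mirrors the ceil-division):

>>> import math
>>> n_threads = (640, 480)        # X, Y: e.g. "result.size(1)", "result.size(0)" after evaluation
>>> threads_per_block = (16, 16)  # the default for 2-dimensional launches
>>> tuple(math.ceil(n / b) for n, b in zip(n_threads, threads_per_block))  # blocks per grid
(40, 30)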
Examples
This is an example implementation of the mymuladd function from the PyTorch Custom C++ and CUDA Operators documentation, where we take 2D inputs instead of flattening. A variety of methods for specifying dtype and shape are used in this example, but sticking to one convention may be better for readability.

>>> import torch
>>> from numba import cuda
>>> import pytorch_numba_extension_jit as pnex
>>> # Can be invoked as mymuladd_2d(A, B, C) to return RESULT
... @pnex.jit(n_threads="result.numel()")
... def mymuladd_2d(
...     a: pnex.In(torch.float32, (None, None)),
...     b: pnex.In("f32", ("a.size(0)", "a.size(1)")),
...     c: float,  # : pnex.Scalar(float)
...     result: pnex.Out("float32", "a"),
... ):
...     idx = cuda.grid(1)
...     y, x = divmod(idx, result.shape[1])
...     if y < result.shape[0]:
...         result[y, x] = a[y, x] * b[y, x] + c
Here, we can see an alternate version that uses multidimensional blocks to achieve the same task, while compiling the result to a C++ operator using to_extension. Note that the n_threads argument is given sizes in the X, Y, Z order (consistent with C++ CUDA kernels), and that numba.cuda.grid also returns indices in this order, even if we might later use indices in e.g. y, x order.

>>> @pnex.jit(n_threads=("result.size(1)", "result.size(0)"), to_extension=True)
... def mymuladd_grid(
...     a: pnex.In("f32", (None, None)),
...     b: pnex.In("f32", ("a.size(0)", "a.size(1)")),
...     c: float,
...     result: pnex.Out("f32", "a"),
... ):
...     # always use this order for names to be consistent with CUDA terminology:
...     x, y = cuda.grid(2)
...
...     if y < result.shape[0] and x < result.shape[1]:
...         result[y, x] = a[y, x] * b[y, x] + c
Notes
This function relies heavily on internals and undocumented behaviour of the Numba-CUDA PTX compiler. However, these internals have not changed in over 3 years, so it is reasonable to assume they will remain similar in future versions as well. Versions 0.9.0 and 0.10.0 of Numba-CUDA have been verified to work as expected.
Additionally, storing the function and compiling it in a different stack frame may cause issues if some annotations use local variables and the module uses from __future__ import annotations. This is because annotations are not considered part of the function proper, so they are not closed over during the construction of a function (no cell is created). Using jit() directly with the decorator syntax @pnex.jit(n_threads=...) has no such problems, or one can selectively disable annotations for the file where the function to be compiled is defined.

See Also
numba.cuda.compile_for_current_device
    Used to compile the Python function into PTX: all functions must therefore also be valid numba.cuda kernels.
numba.cuda.jit
    Used instead, to allow to_extension=False.
torch.utils.cpp_extension.load_inline
    Used to compile the PyTorch C++ extension.
Argument types
class In(dtype: torch.dtype | numpy.dtype | str, shape: tuple[int | str | None, ...] | str)
A type annotation for immutable input tensor parameters in a jit() function.

An input tensor is part of the argument list in the final operator, meaning it must be provided by the caller. This variant is immutable, meaning the kernel must not modify the tensor.

To use this annotation, use the syntax param: In(dtype, shape).

Parameters
dtype : torch.dtype, np.dtype, str
    The data type of the input tensor.
    Some equivalent examples: torch.float32, float, "float32" or "f32"
shape : str, tuple of (int or str or None)
    The shape of the input tensor.
    If shape is a string, it must be the name of a previously defined tensor parameter, and this parameter must have the same shape as the parameter named by shape.
    If shape is a tuple, every element in the tuple corresponds with one axis of the input tensor. For every such element:
    - an int constrains the axis to be exactly the given size.
    - a str is an expression that evaluates to an integer, and constrains the axis to be equal to the result of the expression. If the name of a tensor parameter is provided, this is equivalent to param_name.shape[nth_dim], where nth_dim is the index of the current axis.
    - None does not constrain the size of the axis.
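
As an illustration of these shape forms, a hypothetical kernel might combine a fixed size, an expression, a tensor name and an unconstrained axis (a sketch; the parameter names are made up):

>>> @pnex.jit(n_threads="a.numel()")
... def add_bias(
...     a: pnex.In("f32", (None, 16)),         # any number of rows, exactly 16 columns
...     bias: pnex.In("f32", ("a.size(1)",)),  # length must equal a.size(1)
...     result: pnex.Out("f32", "a"),          # same shape as a
... ):
...     idx = cuda.grid(1)
...     y, x = divmod(idx, a.shape[1])
...     if y < a.shape[0]:
...         result[y, x] = a[y, x] + bias[x]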
class InMut(dtype: torch.dtype | numpy.dtype | str, shape: tuple[int | str | None, ...] | str)
A type annotation for mutable input tensor parameters in a jit() function.

An input tensor is part of the argument list in the final operator, meaning it must be provided by the caller. This variant is mutable, meaning the kernel may modify the tensor.

To use this annotation, use the syntax param: InMut(dtype, shape).

For information on the parameters, see In.
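
A minimal sketch of an in-place update using InMut (illustrative names; this assumes an operator with no Out parameters is acceptable, in which case nothing is returned):

>>> @pnex.jit(n_threads="a.numel()")
... def scale_(
...     a: pnex.InMut("f32", (None,)),
...     alpha: float,  # shorthand for pnex.Scalar(float)
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         a[x] *= alpha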
class Out(dtype: torch.dtype | numpy.dtype | str, shape: tuple[int | str, ...] | str, init: float = None)
A type annotation for output tensor parameters in a jit() function.

An output tensor is not part of the argument list in the final operator, meaning the caller must not attempt to provide it. Instead, parameters marked as Out are created by the wrapper code before being passed to the kernel, and are returned to the caller afterwards as return values from the final operator. Since parameters marked Out are returned, they can receive a gradient and can work with the PyTorch autograd system.

To use this annotation, use the syntax param: Out(dtype, shape[, init=init]).

Parameters
dtype : torch.dtype, np.dtype, str
    The data type of the output tensor.
    Some equivalent examples: torch.float32, float, "float32" or "f32"
shape : str, tuple of (int or str)
    The shape of the output tensor.
    If shape is a string, it must be the name of a previously defined tensor parameter, and this tensor will be constructed to have the same shape as the parameter named by shape.
    If shape is a tuple, every element in the tuple corresponds with one axis of the output tensor. For every such element:
    - an int sets the size of the axis to be exactly the provided value.
    - a str is an expression that evaluates to an integer, and sets the size of the axis to be equal to the result of the expression. If the name of a tensor parameter is provided, this is equivalent to param_name.shape[nth_dim], where nth_dim is the index of the current axis.
init : float or int, optional
    The initial value to fill the output tensor with. If not provided, the output tensor will contain uninitialised memory (in the style of torch.empty).
    Example: gradient tensors for the backward pass should be initialised with 0.
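
For instance, a backward-pass kernel for the earlier copy example might accumulate into a zero-initialised gradient tensor (a sketch; the kernel and parameter names are illustrative):

>>> @pnex.jit(n_threads="grad_out.numel()")
... def copy_backward(
...     grad_out: pnex.In("f32", (None,)),
...     grad_a: pnex.Out("f32", "grad_out", init=0),
... ):
...     x = cuda.grid(1)
...     if x < grad_out.shape[0]:
...         grad_a[x] += grad_out[x]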
class Scalar(dtype: torch.dtype | numpy.dtype | str)

A type annotation for scalar input parameters in a jit() function.

A scalar input is part of the argument list in the final operator, meaning the caller must provide it. It is not returned: for scalar outputs, use Out(dtype, (1,)) instead.

To use this annotation, use the syntax param: Scalar(dtype), or the shorthand param: dtype.

Parameters
dtype : torch.dtype, np.dtype, str
    The data type of the scalar.
    Some equivalent examples: torch.float32, float, "float32" or "f32"
class Unused
A type annotation for ignored parameters in a jit() function.

This is a utility class for marking certain parameters to be skipped during compilation. An example of this would be a kernel which can optionally return an additional output (such as provenance indices for a maximum operation), allowing this output to be skipped programmatically.

Note that all array accesses of a parameter marked Unused must be statically determined to be dead code (e.g. if False), as compilation will otherwise fail.

To use this annotation, use e.g. param: Out(...) if condition else Unused.
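
A sketch of the optional-output pattern described above (illustrative only: it assumes an "i32" dtype string is accepted, and that Numba constant-folds the closed-over flag so the guarded write is dead code when the output is Unused):

>>> def make_rowmax(return_indices: bool):
...     @pnex.jit(n_threads="a.size(0)")
...     def rowmax(
...         a: pnex.In("f32", (None, None)),
...         values: pnex.Out("f32", ("a.size(0)",)),
...         indices: pnex.Out("i32", ("a.size(0)",)) if return_indices else pnex.Unused,
...     ):
...         y = cuda.grid(1)
...         if y < a.shape[0]:
...             best = a[y, 0]
...             best_x = 0
...             for x in range(1, a.shape[1]):
...                 if a[y, x] > best:
...                     best = a[y, x]
...                     best_x = x
...             values[y] = best
...             if return_indices:  # must be statically dead when indices is Unused
...                 indices[y] = best_x
...     return rowmax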