# XLA Custom Calls
This document describes how to write and use XLA "custom calls". Custom calls let you invoke code written in a programming language like C++ or CUDA from an XLA program.
Warning: Custom calls are a low-level power-user feature. It is easy to break your program in difficult-to-debug (and even difficult-to-notice) ways using custom calls. You shouldn't use custom calls unless you're prepared to debug XLA yourself when something goes wrong, and you should expect relatively little assistance from XLA developers if you run into trouble.
Warning: The custom-call API/ABI is not currently stable. We don't intend to change it capriciously, but it may change. Some possible future changes are described below.
## Custom-call on CPU
You can create an HLO instruction that represents a custom call via XLA's client API. This is not exposed via TensorFlow as of this writing.
For example, the following code uses a custom call to compute `A[i] = B[i % 128] + C[i]` on the CPU. (Of course you could -- and should! -- do this with regular HLO.)
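Here's a sketch of what this looks like with XLA's client API. The include paths and the `XLA_REGISTER_CUSTOM_CALL_TARGET` macro follow the XLA source tree, but treat the exact spellings as assumptions rather than a pinned API:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

// The custom-call target: computes out[i] = in0[i % 128] + in1[i].
void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
```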
Notice that the function `do_custom_call` needs to know the dimensions of the buffers it operates over. In this example we hardcode the sizes 128 and 2048. If you don't want to do this, you can pass the dimensions in as parameters to the call.
## Custom-call on GPU
The GPU custom-call framework is somewhat different from that on the CPU. Here is a CUDA example that does the same `A[i] = B[i % 128] + C[i]` computation as the CPU code above.
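A sketch, with the same caveat about exact names; the host function unpacks the device buffers and enqueues a kernel on the stream XLA hands it:

```c++
// Device kernel: computes out[i] = in0[i % 128] + in1[i].
__global__ void custom_call_kernel(const float* in0, const float* in1,
                                   float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

// Host-side custom-call target: runs on the CPU and launches the kernel.
void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
```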
Notice first that the GPU custom-call function is still a function executed on the CPU. Our `do_custom_call` CPU function is responsible for enqueueing work on the GPU. Here it launches a CUDA kernel, but it could also do something else, like call cuBLAS.
`buffers` is an array of pointers which lives on the host, and each element it contains points to device (i.e. GPU) memory. The parameters come first, followed by the output value. This is notably different from the CPU calling convention, which has two params, `in` and `out`. The main reason we diverge is to make it possible to handle tuple-shaped inputs/outputs efficiently; see the section below.
As in the CPU example, we've hardcoded the input and output buffer sizes into our custom call. However, unlike in the CPU case, passing the buffer sizes in as operands to the custom call would not work well. Usually we need the buffer sizes available to us on the CPU; e.g., when launching a kernel, we need to know the block/grid dimensions to use. But if we were to pass the buffer sizes as operands to our custom call, their values would live in GPU memory. We'd then have to do an expensive synchronous device-to-host memcpy at the start of our operation just to read the sizes.
To let you work around this, we provide the `opaque` parameter. You can set this to an arbitrary string of bytes when you create the custom call:
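(A sketch, reusing the builder and parameters from the CPU example; the position of the `opaque` argument follows the `xla::CustomCall` overload used above.)

```c++
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);
```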
Since `xla::Shape` has a protocol buffer representation, you could store this serialized proto inside of `opaque` and deserialize it within your GPU custom call. Note, however, that although `xla::ShapeProto` does not change frequently, it does change. Check the git log to see how it has changed in the past.
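For instance, a minimal sketch, assuming `xla::Shape::ToProto()` and the standard protobuf serialization helpers:

```c++
// Host side, when building the computation: serialize a shape into the
// opaque string.
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2048});
std::string opaque = shape.ToProto().SerializeAsString();

// Inside the GPU custom-call target, where `opaque_data` stands for the
// (opaque, opaque_len) arguments the target receives:
xla::ShapeProto proto;
if (!proto.ParseFromArray(opaque_data, opaque_len)) {
  // Handle the parse failure.
}
xla::Shape deserialized(proto);
```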
## Signaling an error
If your custom call encounters an error, you can signal the error to the XLA runtime (instead of e.g. crashing or returning nonsense in the output buffers) by using the following signature for your function on CPU:
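(A sketch; the include path follows the XLA source layout and should be treated as an assumption.)

```c++
#include "tensorflow/compiler/xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
```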
... and on GPU:
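(Same sketch caveat as the CPU signature.)

```c++
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, XlaCustomCallStatus* status);
```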
You can signal failure by using `XlaCustomCallStatusSetFailure`, e.g.:
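(A sketch; `bad_condition` is a hypothetical stand-in for whatever failure check your call performs.)

```c++
#include <cstring>

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {  // Hypothetical failure check.
    const char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message,
                                  strlen(error_message));
    return;
  }

  // ... continue and fill in the output buffer on success.
}
```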
You can also use `XlaCustomCallStatusSetSuccess` to indicate success, but the `XlaCustomCallStatus` is in a success state by default, so ignoring it completely will also indicate success.
When using custom-call functions with this signature, you must create the corresponding `custom-call` op with the appropriate API version set, e.g.:
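(A sketch; the defaulted arguments and the `API_VERSION_STATUS_RETURNING` enum value follow the `xla::CustomCall` builder API, but treat the exact parameter list as an assumption.)

```c++
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                /*opaque=*/"", /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/
                xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
```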
NOTE: In the future all clients will be required to migrate their custom-call functions to the new API version, and the old one will be deprecated. For custom calls that can't fail, you can simply add the new `XlaCustomCallStatus*` parameter and then ignore it.
On failure, none of the custom call outputs will be used; the XLA runtime will terminate the computation. It is not possible for an HLO computation to recover from the error (e.g. by catching and handling it).
## Passing tuples to custom-calls
Consider the following custom-call.
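A sketch of such a call, with a nested tuple parameter and a tuple output (shape helpers like `ShapeUtil::MakeTupleShape` follow the XLA API; treat exact names as assumptions):

```c++
using xla::ShapeUtil;

xla::Shape p0_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(xla::F32, {32}),
    ShapeUtil::MakeTupleShape({
        ShapeUtil::MakeShape(xla::F32, {64}),
        ShapeUtil::MakeShape(xla::F32, {128}),
    }),
    ShapeUtil::MakeShape(xla::F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(&b, 0, p0_shape, "p0");

xla::Shape out_shape = ShapeUtil::MakeTupleShape({
    ShapeUtil::MakeShape(xla::F32, {512}),
    ShapeUtil::MakeShape(xla::F32, {1024}),
});
xla::XlaOp output_tuple =
    xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
```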
On both CPU and GPU, a tuple is represented in memory as an array of pointers. In C++ pseudocode, parameter 0 above is laid out as follows.
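(Pseudocode only; the `new` allocations stand in for XLA's buffer assignment.)

```c++
// In-memory layout of parameter 0 above. The same layout is used on both
// CPU and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128];
float* subbuf3 = new float[256];

// The inner tuple (F32[64], F32[128]) is itself an array of pointers...
void** subtuple = new void*[2];
subtuple[0] = subbuf1;
subtuple[1] = subbuf2;

// ...and so is param0, whose middle element points at the inner tuple.
void** p0 = new void*[3];
p0[0] = subbuf0;
p0[1] = subtuple;
p0[2] = subbuf3;
```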
Although the in-memory representation of tuples is the same on CPU and GPU, tuples are handled differently by the CPU and GPU custom-call calling conventions.
### Tuple outputs as temp buffers
Tuple inputs to custom calls are a convenience, but they aren't strictly necessary. If we didn't support tuple inputs to custom calls, you could always unpack the tuples using `get-tuple-element` before passing them to the custom call.
On the other hand, tuple outputs do let you do things you couldn't otherwise.
The obvious reason to have tuple outputs is that they're how a custom call (or any other XLA op) returns multiple independent arrays.
But less obviously, a tuple output is also a way to give your custom call temp memory. Yes, an output can represent a temp buffer: an output buffer has the property that the op can write to it, and can read from it after it's been written to. That's exactly what you want from a temp buffer.
In the example above, suppose we wanted to use the `F32[1024]` as a temp buffer. Then we'd write the HLO just as above, and we'd simply never read tuple index 1 of the custom call's output.
### Tuples in CPU custom-calls
In CPU code, we have a function `do_custom_call(void* out, const void** in)`. `in` is an array with just one element, which points to `param0`. The subbuffers of `param0` are accessible by dereferencing that pointer, and the subbuffers of `output_tuple` are accessible by dereferencing `out`.
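A sketch of what that dereferencing can look like for the example above (the casts and variable names are illustrative, not a fixed API):

```c++
void do_custom_call(void* out, const void** in) {
  // in[0] points at param0: an array of three pointers.
  const void* const* param0 = reinterpret_cast<const void* const*>(in[0]);

  const float* subbuf0 = reinterpret_cast<const float*>(param0[0]);    // F32[32]
  const void* const* subtuple =
      reinterpret_cast<const void* const*>(param0[1]);                 // inner tuple
  const float* subbuf1 = reinterpret_cast<const float*>(subtuple[0]);  // F32[64]
  const float* subbuf2 = reinterpret_cast<const float*>(subtuple[1]);  // F32[128]
  const float* subbuf3 = reinterpret_cast<const float*>(param0[2]);    // F32[256]

  // out points at output_tuple: an array of two pointers.
  void* const* output_tuple = reinterpret_cast<void* const*>(out);
  float* out0 = reinterpret_cast<float*>(output_tuple[0]);             // F32[512]
  float* out1 = reinterpret_cast<float*>(output_tuple[1]);             // F32[1024]
}
```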
### Tuples in GPU custom-calls
In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In this case `buffers` is a host array of six device pointers, one for each leaf buffer in the input and output. To generate the flat list, we iterate over the parameters and output, and for each we do a preorder traversal of its shape. Concretely:
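(Illustrative pseudocode; the names refer to the layout sketch earlier.)

```c++
// Layout of the `buffers` parameter for the custom-call above.
buffers[0] == subbuf0         // param0 leaf F32[32]
buffers[1] == subbuf1         // param0 leaf F32[64]
buffers[2] == subbuf2         // param0 leaf F32[128]
buffers[3] == subbuf3         // param0 leaf F32[256]
buffers[4] == output_subbuf0  // output leaf F32[512]
buffers[5] == output_subbuf1  // output leaf F32[1024]
```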