{"id":17191078,"url":"https://github.com/dfm/gpu-limbdark","last_synced_at":"2025-10-28T09:40:54.985Z","repository":{"id":66097613,"uuid":"399162160","full_name":"dfm/gpu-limbdark","owner":"dfm","description":null,"archived":false,"fork":false,"pushed_at":"2021-08-23T16:06:06.000Z","size":118,"stargazers_count":3,"open_issues_count":0,"forks_count":0,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-01-30T05:41:27.186Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/dfm.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-08-23T15:51:29.000Z","updated_at":"2022-09-07T20:46:26.000Z","dependencies_parsed_at":"2023-04-15T04:19:02.129Z","dependency_job_id":null,"html_url":"https://github.com/dfm/gpu-limbdark","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/dfm%2Fgpu-limbdark","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/dfm%2Fgpu-limbdark/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/dfm%2Fgpu-limbdark/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/dfm%2Fgpu-limbdark/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/dfm","download_url":"https://codeload.github.com/dfm/gpu-limbdark/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":245407815,"owners_count":20610238,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-10-15T01:24:49.925Z","updated_at":"2025-10-28T09:40:49.966Z","avatar_url":"https://github.com/dfm.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Extending JAX with custom C++ and CUDA code\n\n[![Tests](https://github.com/dfm/extending-jax/workflows/Tests/badge.svg)](https://github.com/dfm/extending-jax/actions?query=workflow%3ATests)\n\nThis repository is meant as a tutorial demonstrating the infrastructure required\nto provide custom ops in JAX when you have an existing implementation in C++\nand, optionally, CUDA. I originally wanted to write this as a blog post, but\nthere's enough boilerplate code that I ended up deciding that it made more sense\nto just share it as a repo with the tutorial in the README, so here we are!\n\nThe motivation for this is that in my work I want to use libraries like JAX to\nfit models to data in astrophysics. 
In these models, there is often at least one part of the model specification
that is physically motivated, and while there are generally existing
implementations of these model elements, it is often inefficient or impractical
to re-implement them as high-level JAX functions. Instead, I want to expose a
well-tested and optimized implementation in C directly to JAX. In my work, this
often includes things like iterative algorithms or special functions that are
not well suited to implementation in JAX directly.

So, as part of updating my [exoplanet](https://docs.exoplanet.codes) library to
interface with JAX, I had to learn what infrastructure was required to support
this use case, and since I couldn't find a tutorial that covered all the pieces
that I needed in one place, I wanted to put this together. Pretty much
everything that I'll talk about is covered in more detail somewhere else (even
if that somewhere is just a comment in some source code), but hopefully this
summary can point you in the right direction if you have a use case like this.

**A warning**: I'm writing this in January 2021 and much of what I'm talking
about is based on essentially undocumented APIs that are likely to change.
Furthermore, I'm not affiliated with the JAX project and I'm far from an expert,
so I'm sure some of what I say is wrong. I'll try to update this if I notice
things changing or if I learn of issues, but no promises! So, MIT license and
all that: use at your own risk.

## Related reading

As I mentioned previously, this tutorial is built on a lot of existing
literature and I won't reproduce all the details of those documents here, so I
wanted to start by listing the key resources that I found useful:

1. The [How primitives work][jax-primitives] tutorial in the JAX documentation
   includes almost all the details about how to expose a custom op to JAX, and
   spending some quality time with that tutorial is not wasted time. The only
   thing missing from that document is a description of how to use the XLA
   CustomCall interface.

2. Which brings us to the [XLA custom calls][xla-custom] documentation. This
   page is pretty telegraphic, but it includes a description of the interface
   that your custom call functions need to support. In particular, this is where
   the differences in interface between the CPU and GPU are described, including
   things like the "opaque" parameter and how multiple outputs are handled.

3. I originally learned how to write the pybind11 interface for an XLA custom
   call from the [danieljtait/jax_xla_adventures][xla-adventures] repository by
   Dan Tait on GitHub. Again, this doesn't include very many details, but that's
   a benefit here because it distills the infrastructure down to the point where
   I could understand what was going on.

4. Finally, much of what I know about this topic I learned from spelunking in
   the [jaxlib source code][jaxlib] on GitHub. That code is pretty readable and
   includes good comments most of the time, so that's a good place to look if
   you get stuck, since folks there might have already faced the issue.

## What is an "op"

In frameworks like JAX (or Theano, or TensorFlow, or PyTorch, to name a few),
models are defined as a collection of operations or "ops" that can be chained,
fused, or differentiated in clever ways. For our purposes, an op defines a
function that knows:

1. how the input and output parameter shapes and types are related,
2. how to compute the output from a set of inputs, and
3. how to propagate derivatives using the chain rule.

There are a lot of choices about where you draw the lines around a single op,
and there will be tradeoffs in terms of performance, generality, ease of use,
and other factors when making these decisions. In my experience, it is often
best to define minimally scoped ops and then allow your framework of choice to
combine them efficiently with the rest of your model, but there will always be
counterexamples.

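To make those three ingredients concrete before we dive in, here is a toy
sketch (not from this repo) of a pure-Python JAX primitive for squaring a
number, loosely following the [How primitives work][jax-primitives] tutorial.
The numbered comments match the list above; note that to use this op under
`jax.jit` you would additionally need the translation rule that the rest of
this tutorial is about:

```python
import jax
from jax import core
from jax.interpreters import ad

square_p = core.Primitive("square")

@square_p.def_impl
def _square_impl(x):
    # (2) how to compute the output from a set of inputs
    return x * x

@square_p.def_abstract_eval
def _square_abstract_eval(x):
    # (1) how the output shape and dtype relate to the input's
    return core.ShapedArray(x.shape, x.dtype)

def _square_jvp(primals, tangents):
    # (3) how to propagate derivatives using the chain rule
    (x,), (dx,) = primals, tangents
    return square_p.bind(x), 2 * x * dx

ad.primitive_jvps[square_p] = _square_jvp

print(square_p.bind(3.0))                      # 9.0
print(jax.jvp(square_p.bind, (3.0,), (1.0,)))  # (9.0, 6.0)
```
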
## Our example application: solving Kepler's equation

In this section I'll describe the application presented in this project. Feel
free to skip this if you just want to get to the technical details.

This project exposes a single jit-able and differentiable JAX operation to solve
[Kepler's equation][keplers-equation], a tool that is used for computing
gravitational orbits in astronomy. This is basically the "hello world" example
that I use whenever I'm learning about something like this. For example, I have
previously written [about how to expose such an op when using Stan][stan-cpp].
The implementations used in that post and here are not meant to be the most
robust or efficient, but they are relatively simple and they expose some of the
interesting issues that one might face when writing custom JAX ops. If you're
interested in the mathematical details, take a look at [my blog
post][stan-cpp], but the key point for now is that this operation involves
solving a transcendental equation, and in this tutorial we'll use a simple
iterative method that you'll find in the [kepler.h][kepler-h] header file. The
derivatives of this operation can then be evaluated using implicit
differentiation. Unlike in the previously mentioned blog post, our operation
will actually return the sine and cosine of the eccentric anomaly, since that's
what most high-performance versions of this function would return and because
the way XLA handles ops with multiple outputs is a little funky.

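To make the problem concrete, here is a rough NumPy sketch of such an iterative
solver, using plain Newton iteration. The method actually implemented in
[kepler.h][kepler-h] may differ in its starting point and refinement scheme;
this is just for illustration:

```python
import numpy as np

def eccentric_anomaly(mean_anom, ecc, max_iter=50, tol=1e-12):
    """Solve Kepler's equation E - ecc * sin(E) = mean_anom for E."""
    E = np.array(mean_anom, dtype=float)  # E = M is a serviceable starting guess
    for _ in range(max_iter):
        f = E - ecc * np.sin(E) - mean_anom
        E -= f / (1.0 - ecc * np.cos(E))  # Newton step
        if np.max(np.abs(f)) < tol:
            break
    # Return the sine and cosine, matching the convention discussed above
    return np.sin(E), np.cos(E)

sin_E, cos_E = eccentric_anomaly(np.linspace(0.1, 3.0, 5), 0.3)
```
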
## The cost/benefit analysis

One important question to answer first is: "should I actually write a custom JAX
extension?" If you're here, you've probably already thought about that, but I
wanted to emphasize a few points to consider.

1. **Performance**: The main reason why you might want to implement a custom op
   for JAX is performance. JAX's JIT compiler can get great performance in a
   broad range of applications, but for some of the problems I work on,
   finely-tuned C++ can be much faster. In my experience, iterative algorithms,
   other special functions, or code with complicated logic are all examples of
   places where a custom op might greatly improve performance. I'm not always
   good at doing this, but it's probably worth benchmarking a version of your
   code implemented directly in high-level JAX against your custom op.

2. **Autodiff**: One important thing to realize is that the extension we write
   won't magically know how to propagate derivatives. Instead, we'll be required
   to provide a JAX interface for applying the chain rule to our op. In other
   words, if you're setting out to wrap that huge Fortran library that has been
   passed down through the generations, the payoff might not be as great as you
   hoped unless (a) the code already provides operations for propagating
   derivatives (in which case your JAX op probably won't support second- and
   higher-order differentiation), or (b) you can easily compute the
   differentiation rules using the algorithm that you already have (which is the
   case for our example application here). In my work, I try (sometimes
   unsuccessfully) to identify the minimum number and size of ops that I can get
   away with and then implement most of my models directly in JAX. In our demo
   application, for example, I could have chosen to make an XLA op generating a
   full radial velocity model, instead of just solving Kepler's equation, and
   that might (or might not) give better performance. But the differentiation
   rules are _much_ simpler the way it is implemented.

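For the curious, point (b) is what implicit differentiation buys us here: we
never need to differentiate through the iterative solver itself. A minimal
NumPy sketch of the idea (my own illustration, not code from this repo):

```python
import numpy as np

# At a solution E of E - ecc * sin(E) = M, implicit differentiation gives
#   dE = (dM + sin(E) * decc) / (1 - ecc * cos(E)),
# and the chain rule then gives the tangents of the actual outputs:
#   d[sin(E)] = cos(E) * dE  and  d[cos(E)] = -sin(E) * dE.
def solve(M, ecc, n=30):
    E = M
    for _ in range(n):  # Newton iteration, as in the sketch above
        E -= (E - ecc * np.sin(E) - M) / (1.0 - ecc * np.cos(E))
    return E

M, ecc = 0.7, 0.3
E = solve(M, ecc)
dE_dM = 1.0 / (1.0 - ecc * np.cos(E))

# Sanity check against a finite difference
eps = 1e-7
print(dE_dM, (solve(M + eps, ecc) - E) / eps)  # these should agree closely
```
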
## Summary of the relevant files

The files in this repo come in three categories:

1. In the root directory, there are the standard packaging files like a
   `setup.py` and `pyproject.toml`. Most of this setup is pretty standard, but
   I'll highlight some of the unique elements in the packaging section below.
   For example, we'll use a slightly strange combination of PEP-517/518 and
   CMake to build the extensions. This isn't strictly necessary, but it's the
   easiest packaging setup that I've been able to put together.

2. Next, the `src/kepler_jax` directory is a Python module with the definition
   of our JAX primitive, roughly following the JAX [How primitives
   work][jax-primitives] tutorial.

3. Finally, the C++ and CUDA code implementing our XLA op lives in the `lib`
   directory. The `pybind11_kernel_helpers.h` and `kernel_helpers.h` headers are
   boilerplate necessary for building the interface. The rest of the files
   include the code specific to this implementation, but I'll describe them in
   more detail below.

## Defining an XLA custom call on the CPU

The algorithm for our example problem is implemented in the `lib/kepler.h`
header, and I won't go into details about the algorithm here, but the main point
is that this could be an implementation built on any external library that you
can call from C++ and, if you want to support GPU usage, CUDA. That header file
includes a single function `compute_eccentric_anomaly` with the following
signature:

```c++
template <typename T>
void compute_eccentric_anomaly(
   const T& mean_anom, const T& ecc, T* sin_ecc_anom, T* cos_ecc_anom
);
```

This is the function that we want to expose to JAX.

As described in the [XLA documentation][xla-custom], the signature for a CPU XLA
custom call in C++ is:

```c++
void custom_call(void* out, const void** in);
```

where, as you might expect, the elements of `in` point to the input values. So,
in our case, the inputs are an integer giving the dimension of the problem
`size`, an array with the mean anomalies `mean_anom`, and an array of
eccentricities `ecc`. Therefore, we might parse the input as follows:

```c++
#include <cstdint>  // int64_t

template <typename T>
void cpu_kepler(void *out, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);
}
```

Here we have used a template so that we can support both single and double
precision versions of the op.

The output parameter is somewhat more complicated. If your op only has one
output, you would access it using

```c++
T *result = reinterpret_cast<T *>(out);
```

but when you have multiple outputs, things get a little hairy. In our example,
we have two outputs, the sine `sin_ecc_anom` and cosine `cos_ecc_anom` of the
eccentric anomaly. That means that our `out` parameter -- even though it looks
like a `void*` -- is actually a `void**`! Therefore, we will access the output
as follows:

```c++
template <typename T>
void cpu_kepler(void *out_tuple, const void **in) {
  // ...
  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_ecc_anom = reinterpret_cast<T *>(out[0]);
  T *cos_ecc_anom = reinterpret_cast<T *>(out[1]);
}
```

Finally, we actually apply the op, and the full implementation, which you can
find in `lib/cpu_ops.cc`, is:

```c++
// lib/cpu_ops.cc
#include <cstdint>

template <typename T>
void cpu_kepler(void *out_tuple, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);

  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_ecc_anom = reinterpret_cast<T *>(out[0]);
  T *cos_ecc_anom = reinterpret_cast<T *>(out[1]);

  for (std::int64_t n = 0; n < size; ++n) {
    compute_eccentric_anomaly(mean_anom[n], ecc[n], sin_ecc_anom + n, cos_ecc_anom + n);
  }
}
```

and that's it!

## Building & packaging for the CPU

Now that we have an implementation of our XLA custom call target, we need to
expose it to JAX. This is done by compiling a CPython module that wraps this
function as a [`PyCapsule`][capsule] type. This can be done using pybind11,
Cython, SWIG, or the Python C API directly, but for this example we'll use
pybind11 since that's what I'm most familiar with. The [LAPACK ops in
jaxlib][jaxlib-lapack] are implemented using Cython if you'd like to see an
example of how to do that.

Another choice that I've made is to use [CMake](https://cmake.org) to build the
extensions. It would be totally possible (and perhaps preferable if you only
support CPU usage) to stick to just using setuptools directly, but setuptools
doesn't seem to have great support for compiling CUDA extensions, so that's why
I settled on CMake. In the end, it's not too painful since CMake can be included
as a build dependency in `pyproject.toml` so users won't have to install it
separately. Another build option would be to use [bazel](https://bazel.build) to
compile the code, like the JAX project does, but I don't have any experience
with it so I decided to stick with what I know. _The key point is that we're
just compiling a regular old Python module, so you can use whatever
infrastructure you're familiar with!_

With these choices out of the way, the boilerplate code required to define the
interface, using the `cpu_kepler` function defined in the previous section, is
as follows:

```c++
// lib/cpu_ops.cc
#include <pybind11/pybind11.h>

// If you're looking for it, this function is actually implemented in
// lib/pybind11_kernel_helpers.h
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
  return pybind11::capsule((void*)fn, "xla._CUSTOM_CALL_TARGET");
}

pybind11::dict Registrations() {
  pybind11::dict dict;
  dict["cpu_kepler_f32"] = EncapsulateFunction(cpu_kepler<float>);
  dict["cpu_kepler_f64"] = EncapsulateFunction(cpu_kepler<double>);
  return dict;
}

PYBIND11_MODULE(cpu_ops, m) { m.def("registrations", &Registrations); }
```

In this case, we're exporting separate functions for single and double
precision. Another option would be to pass the data type to the function and
perform the dispatch logic directly in C++, but I find it cleaner to do it like
this.

With that out of the way, the actual build routine is defined in the following
files:

- In `./pyproject.toml`, we specify that `pybind11` and `cmake` are required
  build dependencies and that we'll use `setuptools.build_meta` as the build
  backend.

- `setup.py` is a pretty typical setup file with a custom class for building the
  extensions that executes CMake for the actual compilation step. This does
  include some extra configuration arguments for CMake to make sure that it uses
  the correct Python libraries and installs the compiled objects to the right
  place. It might be possible to use something like [scikit-build][scikit-build]
  to replace this step, but I struggled to get it working.

- Finally, `CMakeLists.txt` defines the build process for CMake using
  [pybind11's support for CMake builds][pybind11-cmake]. This will also,
  optionally, build the GPU ops as discussed below.

With these files in place, we can now compile our XLA custom call ops using

```bash
pip install .
```

The final thing that I wanted to reiterate in this section is that
`kepler_jax.cpu_ops` is just a regular old CPython extension module, so anything
that you already know about packaging C extensions or any other resources that
you can find on that topic can be applied. This wasn't obvious when I first
started learning about this, so I definitely went down some rabbit holes that
hopefully you can avoid.

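One nice consequence is that, once installed, the module can be inspected like
any other Python extension. For example, you can list the registered targets;
given the `Registrations` function above, I'd expect something like:

```python
from kepler_jax import cpu_ops

# These names must match the op names used in the translation rule below
print(cpu_ops.registrations().keys())
# dict_keys(['cpu_kepler_f32', 'cpu_kepler_f64'])
```
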
## Exposing this op as a JAX primitive

The main components that are required to now call our custom op from JAX are
well covered by the [How primitives work][jax-primitives] tutorial, so I won't
reproduce all of that here. Instead I'll summarize the key points and then
provide the missing part. If you haven't already, you should definitely read
that tutorial before getting started on this part.

In summary, we will define a `jax.core.Primitive` object with an "abstract
evaluation" rule (see `src/kepler_jax/kepler_jax.py` for all the details)
following the primitives tutorial. Then, we'll add a "translation rule" and a
"JVP rule". We're lucky in this case, and we don't need to add a "transpose
rule". JAX can actually work that out automatically, since our primitive is not
itself used in the calculation of the output tangents. This won't always be
true, and the [How primitives work][jax-primitives] tutorial includes an example
of what to do in that case.

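For orientation, the skeleton of that primitive definition looks roughly like
the following. This is a condensed sketch, not a verbatim copy of
`src/kepler_jax/kepler_jax.py`, and the name of the public wrapper function is
my own invention; check the source for the real names:

```python
# condensed sketch of src/kepler_jax/kepler_jax.py
from functools import partial

from jax import core
from jax.interpreters import xla

_kepler_prim = core.Primitive("kepler")
_kepler_prim.multiple_results = True  # we return both sin(E) and cos(E)
_kepler_prim.def_impl(partial(xla.apply_primitive, _kepler_prim))

@_kepler_prim.def_abstract_eval
def _kepler_abstract_eval(mean_anom, ecc):
    # Both outputs have the same shape and dtype as the inputs
    return (core.ShapedArray(mean_anom.shape, mean_anom.dtype),) * 2

def kepler(mean_anom, ecc):
    # The public wrapper that users actually call
    return _kepler_prim.bind(mean_anom, ecc)
```
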
Before defining these rules, we need to register the custom call target with
JAX. To do that, we import our compiled `cpu_ops` extension module from above
and use the `registrations` dictionary that we defined:

```python
from jax.lib import xla_client
from kepler_jax import cpu_ops

for _name, _value in cpu_ops.registrations().items():
    xla_client.register_cpu_custom_call_target(_name, _value)
```

Then, the **translation rule** is defined roughly as follows (the one you'll
find in the source code is a little more complicated since it supports both CPU
and GPU translation):

```python
# src/kepler_jax/kepler_jax.py
import numpy as np

def _kepler_cpu_translation(c, mean_anom, ecc):
    # The inputs have "shapes" that provide both the shape and the dtype
    mean_anom_shape = c.get_shape(mean_anom)
    ecc_shape = c.get_shape(ecc)

    # Extract the dtype and shape
    dtype = mean_anom_shape.element_type()
    dims = mean_anom_shape.dimensions()
    assert ecc_shape.element_type() == dtype
    assert ecc_shape.dimensions() == dims

    # The total size of the input is the product across dimensions
    size = np.prod(dims).astype(np.int64)

    # The inputs and outputs all have the same shape so let's predefine this
    # specification
    shape = xla_client.Shape.array_shape(
        np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
    )

    # We dispatch a different call depending on the dtype
    if dtype == np.float32:
        op_name = b"cpu_kepler_f32"
    elif dtype == np.float64:
        op_name = b"cpu_kepler_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    # On the CPU, we pass the size of the data as the first input argument
    return xla_client.ops.CustomCallWithLayout(
        c,
        op_name,
        # The inputs:
        operands=(xla_client.ops.ConstantLiteral(c, size), mean_anom, ecc),
        # The input shapes:
        operand_shapes_with_layout=(
            xla_client.Shape.array_shape(np.dtype(np.int64), (), ()),
            shape,
            shape,
        ),
        # The output shapes:
        shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
    )

xla.backend_specific_translations["cpu"][_kepler_prim] = _kepler_cpu_translation
```

There appears to be a lot going on here, but most of it is just typechecking.
The main meat of it is the `CustomCallWithLayout` function which, as far as I
can tell, isn't documented anywhere. Here's a summary of its arguments, as best
as I can tell:

- The first argument is the XLA builder that you were passed when your
  translation rule was called.

- The second argument is the name (as `bytes`!) that you gave your `PyCapsule`
  in the `registrations` dictionary in `lib/cpu_ops.cc`. You can check what
  names your capsules had by looking at `cpu_ops.registrations().keys()`.

- Then, the following arguments give the input arguments, and the "shapes" of
  the input and output arrays. In this context, a "shape" is specified by a data
  type, a tuple defining the size of each dimension (what I would normally call
  the shape), and a tuple defining the dimension order. In this case, we're
  requiring that all of our inputs and outputs are of the same "shape".

It's worth remembering that we're expecting the first argument to our function
to be the size of the arrays, and you'll see that that is included as a
`ConstantLiteral` parameter (explicitly cast to `int64`).

I'm not going to talk about the **JVP rule** here since it's quite problem
specific, but I've tried to comment the code reasonably thoroughly, so check out
the code in `src/kepler_jax/kepler_jax.py` if you're interested, and open an
issue if anything isn't clear.

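With all of these pieces in place, the op can be called like any other JAX
function. A quick usage sketch, assuming the package exposes a wrapper called
`kepler` as in the skeleton above (check `src/kepler_jax` for the actual name):

```python
import jax
import jax.numpy as jnp

from kepler_jax import kepler  # hypothetical wrapper name

mean_anom = jnp.linspace(0.1, 3.0, 5)
ecc = jnp.full_like(mean_anom, 0.3)

# jit exercises the translation rule defined above...
sin_E, cos_E = jax.jit(kepler)(mean_anom, ecc)

# ...and grad exercises the JVP rule (plus the automatic transpose)
g = jax.grad(lambda m: kepler(m, ecc)[0].sum())(mean_anom)
```
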
## Defining an XLA custom call on the GPU

The custom call on the GPU isn't terribly different from the CPU version above,
but the syntax is somewhat different and there's a heck of a lot more
boilerplate required. Since we need to compile and link CUDA code, there are
also a few more packaging steps, but we'll get to that in the next section. The
description in this section is a little all over the place, but the key files to
look at to get more info are (a) `lib/gpu_ops.cc` for the dispatch functions
called from Python, and (b) `lib/kernels.cc.cu` for the CUDA code implementing
the kernel.

The signature for the GPU custom call is:

```c++
// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
);
```

The first parameter is a CUDA stream, which I won't talk about at all because I
don't really know very much about GPU programming and we don't really need to
worry about it for now. Then you'll notice that the inputs and outputs are all
provided as a single `void**` buffer. These will be ordered such that our access
code from above is replaced by:

```c++
// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  const T *mean_anom = reinterpret_cast<const T *>(buffers[0]);
  const T *ecc = reinterpret_cast<const T *>(buffers[1]);
  T *sin_ecc_anom = reinterpret_cast<T *>(buffers[2]);
  T *cos_ecc_anom = reinterpret_cast<T *>(buffers[3]);
}
```

where you might notice that the `size` parameter is no longer one of the inputs.
Instead, the array size is passed using the `opaque` parameter since its value
is required on the CPU and within the GPU kernel (see the [XLA custom
calls][xla-custom] documentation for more details). To use this `opaque`
parameter, we will define a type to hold `size`:

```c++
// lib/kernels.h
struct KeplerDescriptor {
  std::int64_t size;
};
```

and then the following boilerplate to serialize it:

```c++
// lib/kernel_helpers.h
#include <string>

// Note that bit_cast is only available in recent C++ standards so you might need
// to provide a shim like the one in lib/kernel_helpers.h
template <typename T>
std::string PackDescriptorAsString(const T& descriptor) {
  return std::string(bit_cast<const char*>(&descriptor), sizeof(T));
}

// lib/pybind11_kernel_helpers.h
#include <pybind11/pybind11.h>

template <typename T>
pybind11::bytes PackDescriptor(const T& descriptor) {
  return pybind11::bytes(PackDescriptorAsString(descriptor));
}
```

This serialization procedure should then be exposed in the Python module using:

```c++
// lib/gpu_ops.cc
#include <pybind11/pybind11.h>

PYBIND11_MODULE(gpu_ops, m) {
  // ...
  m.def("build_kepler_descriptor",
        [](std::int64_t size) {
          return PackDescriptor(KeplerDescriptor{size});
        });
}
```

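Since `KeplerDescriptor` contains a single `std::int64_t`, a quick way to
convince yourself that the serialization round trip makes sense is to check the
size of the packed bytes from Python (assuming a GPU-enabled build, so that
`gpu_ops` exists):

```python
from kepler_jax import gpu_ops

opaque = gpu_ops.build_kepler_descriptor(100)
print(type(opaque), len(opaque))  # <class 'bytes'> 8, i.e. sizeof(KeplerDescriptor)
```
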
Then, to deserialize this descriptor, we can use the following procedure:

```c++
// lib/kernel_helpers.h
template <typename T>
const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return bit_cast<const T*>(opaque);
}

// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  // ...
  const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len);
  const std::int64_t size = d.size;
}
```

Once we have these parameters, the full procedure for launching the CUDA kernel
is:

```c++
// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  const T *mean_anom = reinterpret_cast<const T *>(buffers[0]);
  const T *ecc = reinterpret_cast<const T *>(buffers[1]);
  T *sin_ecc_anom = reinterpret_cast<T *>(buffers[2]);
  T *cos_ecc_anom = reinterpret_cast<T *>(buffers[3]);
  const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len);
  const std::int64_t size = d.size;

  // Select block sizes, etc., no promises that these numbers are the right choices
  const int block_dim = 128;
  const int grid_dim = std::min<int>(1024, (size + block_dim - 1) / block_dim);

  // Launch the kernel
  kepler_kernel<T>
      <<<grid_dim, block_dim, 0, stream>>>(size, mean_anom, ecc, sin_ecc_anom, cos_ecc_anom);

  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(error));
  }
}
```

Finally, the kernel itself is relatively simple, using a grid-stride loop so
that any problem size can be handled by the fixed grid chosen above:

```c++
// lib/kernels.cc.cu
template <typename T>
__global__ void kepler_kernel(
  std::int64_t size, const T *mean_anom, const T *ecc, T *sin_ecc_anom, T *cos_ecc_anom
) {
  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    compute_eccentric_anomaly<T>(mean_anom[idx], ecc[idx], sin_ecc_anom + idx, cos_ecc_anom + idx);
  }
}
```

## Building & packaging for the GPU

Since we're already using CMake to build our project, it's not too hard to add
support for CUDA. I've chosen to enable GPU builds via the environment variable
`KEPLER_JAX_CUDA=yes`, which you'll see in both `setup.py` and `CMakeLists.txt`.
Other than conditionally adding an `Extension` in `setup.py`, everything else on
the Python side is the same. In `CMakeLists.txt`, we also add a conditional:

```cmake
if (KEPLER_JAX_CUDA)
  enable_language(CUDA)
  # ...
else()
  message(STATUS "Building without CUDA")
endif()
```

Then, to expose this to JAX, we need to update the translation rule from above
as follows:

```python
# src/kepler_jax/kepler_jax.py
import numpy as np
from jax.lib import xla_client
from kepler_jax import gpu_ops

for _name, _value in gpu_ops.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")

def _kepler_gpu_translation(c, mean_anom, ecc):
    # Most of this function is the same as the CPU version above...

    # The name of the op is now prefixed with 'gpu' (our choice, see
    # lib/gpu_ops.cc, not a requirement)
    if dtype == np.float32:
        op_name = b"gpu_kepler_f32"
    elif dtype == np.float64:
        op_name = b"gpu_kepler_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    # We need to serialize the array size using a descriptor
    opaque = gpu_ops.build_kepler_descriptor(size)

    # The syntax is *almost* the same as the CPU version, but we need to pass the
    # size using 'opaque' rather than as an input
    return xla_client.ops.CustomCallWithLayout(
        c,
        op_name,
        operands=(mean_anom, ecc),
        operand_shapes_with_layout=(shape, shape),
        shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
        opaque=opaque,
    )

xla.backend_specific_translations["gpu"][_kepler_prim] = _kepler_gpu_translation
```

Otherwise, nothing else from our CPU implementation needs to change.

## Testing

As usual, you should always test your code, and this repo includes some unit
tests in the `tests` directory for inspiration. You can also see an example of
how to run these tests using the GitHub Actions CI service and the workflow in
`.github/workflows/tests.yml`. I don't know of any public CI servers that
provide GPU support, but I do include a test to confirm that the GPU ops can be
compiled. You can see the infrastructure for that test in the `.github/action`
directory.

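The tests themselves can be simple because Kepler's equation gives us an exact
residual to check. Here's a rough sketch of such a test (my own illustration,
again assuming a hypothetical wrapper called `kepler`; the real tests in
`tests/` are more thorough):

```python
import numpy as np
import jax.numpy as jnp
from kepler_jax import kepler  # hypothetical wrapper name

def test_kepler_solves_the_equation():
    # Restrict to (0, pi) so the eccentric anomaly doesn't wrap
    mean_anom = jnp.linspace(0.1, np.pi - 0.1, 50)
    ecc = jnp.full_like(mean_anom, 0.3)
    sin_E, cos_E = kepler(mean_anom, ecc)

    # The outputs should be a consistent sine/cosine pair...
    np.testing.assert_allclose(sin_E**2 + cos_E**2, 1.0, atol=1e-5)

    # ...and the recovered E should satisfy E - ecc * sin(E) = M
    E = np.arctan2(sin_E, cos_E)
    np.testing.assert_allclose(E - ecc * np.sin(E), mean_anom, atol=1e-5)
```
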
## See this in action

To demo the use of this custom op, I put together a notebook based on [an
example from the exoplanet docs][exoplanet-tutorial]. You can see this notebook
in the `demo.ipynb` file in the root of this repository or open it on Google
Colab:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dfm/extending-jax/blob/main/demo.ipynb)

## References

[jax-primitives]: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html "How primitives work"
[xla-custom]: https://www.tensorflow.org/xla/custom_call "XLA custom calls"
[xla-adventures]: https://github.com/danieljtait/jax_xla_adventures "JAX XLA adventures"
[jaxlib]: https://github.com/google/jax/tree/master/jaxlib "jaxlib source code"
[keplers-equation]: https://en.wikipedia.org/wiki/Kepler%27s_equation "Kepler's equation"
[stan-cpp]: https://dfm.io/posts/stan-c++/ "Using external C++ functions with PyStan & radial velocity exoplanets"
[kepler-h]: https://github.com/dfm/extending-jax/blob/main/lib/kepler.h "kepler.h"
[capsule]: https://docs.python.org/3/c-api/capsule.html "Capsules"
[jaxlib-lapack]: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx "jaxlib/lapack.pyx"
[scikit-build]: https://scikit-build.readthedocs.io/ "scikit-build"
[pybind11-cmake]: https://pybind11.readthedocs.io/en/stable/compiling.html#building-with-cmake "Building with CMake"
[exoplanet-tutorial]: https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets "A more realistic example: radial velocity exoplanets"