{"id":15050731,"url":"https://github.com/smrfeld/pytorch-cpp-metal-tutorial","last_synced_at":"2025-04-30T14:43:37.994Z","repository":{"id":210647445,"uuid":"727094173","full_name":"smrfeld/pytorch-cpp-metal-tutorial","owner":"smrfeld","description":"Tutorial for (PyTorch) + (C++) + (Metal shader)","archived":false,"fork":false,"pushed_at":"2023-12-16T00:38:55.000Z","size":17,"stargazers_count":10,"open_issues_count":0,"forks_count":0,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-03-30T17:07:04.921Z","etag":null,"topics":["apple-silicon","cplusplus","cppextension","metal","mps","objective-c","pybind11","python","pytorch","shader","shaders","tutorial"],"latest_commit_sha":null,"homepage":"","language":"Objective-C++","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/smrfeld.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-12-04T07:06:03.000Z","updated_at":"2025-03-26T15:13:33.000Z","dependencies_parsed_at":"2024-10-12T17:41:00.497Z","dependency_job_id":"6e17828c-18ce-4d07-8dbf-5e7b70065271","html_url":"https://github.com/smrfeld/pytorch-cpp-metal-tutorial","commit_stats":null,"previous_names":["smrfeld/pytorch-cpp-metal-tutorial"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/smrfeld%2Fpytorch-cpp-metal-tutorial","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/smrfeld%2Fpytorch-cpp-metal-tutorial/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/smrfeld%2Fpytorch-cpp-
metal-tutorial/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/smrfeld%2Fpytorch-cpp-metal-tutorial/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/smrfeld","download_url":"https://codeload.github.com/smrfeld/pytorch-cpp-metal-tutorial/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":251723147,"owners_count":21633097,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["apple-silicon","cplusplus","cppextension","metal","mps","objective-c","pybind11","python","pytorch","shader","shaders","tutorial"],"created_at":"2024-09-24T21:29:08.502Z","updated_at":"2025-04-30T14:43:37.963Z","avatar_url":"https://github.com/smrfeld.png","language":"Objective-C++","readme":"# Tutorial for custom Metal shaders using PyTorch \u0026 C++\n\nThis is a minimal example of a Python package calling a custom PyTorch C++ module that is using **Metal** shader (on Mac).\n\nSee also the associated [Medium](https://medium.com/practical-coding/metal-shaders-with-pytorch-from-end-to-end-c95370b3449b) article: [Metal shaders with PyTorch from end to end](https://medium.com/practical-coding/metal-shaders-with-pytorch-from-end-to-end-c95370b3449b)\n\n## Installing \u0026 running\n\n0. (Optional) Create a conda environment:\n\n    ```bash\n    conda create -n test-pytorch-cpp python=3.11\n    conda activate test-pytorch-cpp\n    ```\n\n1. Install requirements:\n    ```bash\n    pip install -r requirements.txt\n    ```\n\n2. 
Install the package using `setup.py`:\n    ```bash\n    pip install -e .\n    ```\n\n3. Run the test:\n    ```bash\n    python main.py\n    ```\n    Expected result:\n    ```\n    tensor([5., 7., 9.])\n    ```\n\n## Other good examples\n\n* [https://github.com/open-mmlab/mmcv/blob/main/setup.py](https://github.com/open-mmlab/mmcv/blob/main/setup.py)\n\n## About\n\nGoal: We will write from scratch a Python library that compiles a Metal shader using `C++`/`Objective-C` and lets you call the method from Python using `pybind11`.\n\n### Project setup\n\nWe will create from scratch a new Python package called `my_extension`. This package will expose a method that adds two `PyTorch` tensors on the `MPS` device using a custom Metal shader. Create a new directory with the following structure:\n```\nmy_extension/\nmy_extension/__init__.py\nmy_extension/add_tensors.metal\nmy_extension/cpp_extension.mm\nmy_extension/wrapper.py\nsetup.py\n```\n\nHere the package is built from the `my_extension` folder. The `wrapper.py` file contains the wrapper code that calls the compiled `C++` library. That library is defined in `cpp_extension.mm`, which mixes `C++` and `Objective-C` to call the shader `add_tensors.metal`.\n\n### setup.py file\n\nLet’s take a look at the [setup.py](setup.py) file. The main action happens in this if statement:\n```python\n    if (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):\n```\nwhere we check if `mps` is available. 
If so, we define how to handle `.mm` files, which mix `Objective-C` and `C++`:\n```python\n        from distutils.unixccompiler import UnixCCompiler\n        if '.mm' not in UnixCCompiler.src_extensions:\n            UnixCCompiler.src_extensions.append('.mm')\n            UnixCCompiler.language_map['.mm'] = 'objc'\n```\nand add the Metal framework:\n```python\n        extra_compile_args = {}\n        extra_compile_args['cxx'] = [\n            '-Wall', \n            '-std=c++17',\n            '-framework', \n            'Metal', \n            '-framework', \n            'Foundation',\n            '-ObjC++'\n            ]\n```\n\nThere are two packages being defined here:\n1. `my_extension` — this is the final Python package that we want to create. It is defined by the setup command in the last line:\n    ```python\n    setup(\n        name='my_extension',\n        version=\"0.0.1\",\n        packages=find_packages(),\n        include_package_data=True,\n        python_requires='\u003e=3.11',\n        ext_modules=get_extensions(),\n        cmdclass={'build_ext': BuildExtension},\n        zip_safe=False\n    )\n    ```\n2. `my_extension_cpp` — this is a `C++` library that will call the Metal shader. 
It is defined through the `ext_modules` argument in the setup method, specifically in this line:\n    ```python\n    ext_ops = CppExtension(\n        name='my_extension_cpp',\n        sources=['my_extension/cpp_extension.mm'],\n        include_dirs=[],\n        extra_objects=[],\n        extra_compile_args=extra_compile_args,\n        library_dirs=[],\n        libraries=[],\n        extra_link_args=[]\n        )\n    ```\n\n### Python Wrapper\n\nWe now have a project structure that creates a `C++` library called `my_extension_cpp` and a Python package called `my_extension`.\n\nNext, let’s look at the Python wrapper defined in `my_extension/wrapper.py`:\n```python\nimport torch\nimport my_extension_cpp\n\n# Define a wrapper function\ndef add_tensors(a: torch.Tensor, b: torch.Tensor) -\u003e torch.Tensor:\n\n    # Find the shader file path\n    import pkg_resources\n    shader_file_path = pkg_resources.resource_filename('my_extension', 'add_tensors.metal')\n\n    # Call the C++ function\n    return my_extension_cpp.add_tensors_metal(a, b, shader_file_path)\n```\n\n(Note that `pkg_resources` is deprecated in recent setuptools releases; `importlib.resources` is the modern way to locate package data files.)\n\nHere we just expose the methods defined in `my_extension_cpp` to the `Python` interface. This adds one extra layer between the `C++` interface and the `Python` interface, which is often very useful as the usage can be quite different. 
For example, here we locate the `.metal` shader file using `Python` and pass it as an argument to the `C++` extension function `add_tensors_metal(...)`.\n\nDon’t forget to also expose this in the `__init__.py`:\n```python\nfrom .wrapper import add_tensors\n```\n\n### Metal shader\n\nLet’s take a look at the actual Metal shader we want to use — `add_tensors.metal`:\n```metal\n#include \u003cmetal_stdlib\u003e\nusing namespace metal;\n\n// Define a simple kernel function to add two tensors\nkernel void addTensors(device float *a [[buffer(0)]],\n                       device float *b [[buffer(1)]],\n                       device float *result [[buffer(2)]],\n                       uint id [[thread_position_in_grid]]) {\n    // Element-wise addition; each thread handles one element\n    result[id] = a[id] + b[id];\n}\n```\n\nWe include the Metal standard library with `#include \u003cmetal_stdlib\u003e`. In the `addTensors` kernel, we have the `[[thread_position_in_grid]]` attribute. Straight from the [Apple docs](https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups):\n\n\u003e [[thread_position_in_grid]] is an attribute qualifier. Attribute qualifiers, identifiable by their double square-bracket syntax, allow kernel parameters to be bound to resources and built-in variables, in this case the thread’s position in the grid to the kernel function.\n\nIt is the position of the thread in the grid (threads make up threadgroups; threadgroups make up the grid).\n\nThe result is written to the output buffer `float *result`. 
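\n\nAs a side note, the kernel as written assumes the dispatch launches exactly one thread per element, which the `C++` code below guarantees by using `dispatchThreads` with the element count as the grid size. A sketch of a guarded variant for cases where the grid may be larger than the tensor (the extra length parameter `n` at `buffer(3)` is an assumption for illustration, not part of this tutorial’s code):\n```metal\nkernel void addTensorsGuarded(device const float *a [[buffer(0)]],\n                              device const float *b [[buffer(1)]],\n                              device float *result [[buffer(2)]],\n                              constant uint \u0026n [[buffer(3)]],\n                              uint id [[thread_position_in_grid]]) {\n    // Skip threads past the end of the tensor\n    if (id \u003c n) {\n        result[id] = a[id] + b[id];\n    }\n}\n```\n\n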
The `device` address-space qualifier indicates that the pointer refers to memory on the GPU.\n\n### Calling the Metal shader from C++\n\nFinally, let’s write the [my_extension/cpp_extension.mm](my_extension/cpp_extension.mm) file in `C++` and `Objective-C` that calls the `.metal` shader.\n\nThere’s a lot to unpack here — let’s start at the very bottom:\n\n```cpp\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"add_tensors_metal\", \u0026add_tensors_metal, \"Add two tensors using Metal\");\n}\n```\n\nThis uses `pybind11` to expose the `add_tensors_metal` function to `Python`, so that we can call it in `wrapper.py`.\n\nIn the actual function, we load the shader file and compile it:\n\n```cpp\n    // Get the default Metal device\n    id\u003cMTLDevice\u003e device = MTLCreateSystemDefaultDevice();\n\n    // Load the Metal shader from the specified path\n    NSError* error = nil;\n    NSString* shaderSource = [\n        NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shaderFilePath.c_str()]\n        encoding:NSUTF8StringEncoding \n        error:\u0026error];\n    if (!shaderSource) {\n        // error is only guaranteed to be populated when the returned string is nil\n        throw std::runtime_error(\"Failed to load Metal shader: \" + std::string(error.localizedDescription.UTF8String));\n    }\n\n    // Compile the Metal shader source\n    id\u003cMTLLibrary\u003e library = [device newLibraryWithSource:shaderSource options:nil error:\u0026error];\n    if (!library) {\n        throw std::runtime_error(\"Error compiling Metal shader: \" + std::string(error.localizedDescription.UTF8String));\n    }\n```\n\nEnsure that the function exists:\n\n```cpp\n    id\u003cMTLFunction\u003e function = [library newFunctionWithName:@\"addTensors\"];\n    if (!function) {\n        throw std::runtime_error(\"Error: Metal function addTensors not found.\");\n    }\n```\n\nCreate the compute pipeline state and convert the torch tensors into Metal buffers:\n\n```cpp\n    // Create a Metal compute pipeline state\n    id\u003cMTLComputePipelineState\u003e pipelineState = [device 
newComputePipelineStateWithFunction:function error:nil];\n\n    // Create Metal buffers for the tensors\n    id\u003cMTLBuffer\u003e aBuffer = [device newBufferWithBytes:a.data_ptr() length:(numElements * sizeof(float)) options:MTLResourceStorageModeShared];\n    id\u003cMTLBuffer\u003e bBuffer = [device newBufferWithBytes:b.data_ptr() length:(numElements * sizeof(float)) options:MTLResourceStorageModeShared];\n    id\u003cMTLBuffer\u003e resultBuffer = [device newBufferWithLength:(numElements * sizeof(float)) options:MTLResourceStorageModeShared];\n\n    // Create a command queue\n    id\u003cMTLCommandQueue\u003e commandQueue = [device newCommandQueue];\n\n    // Create a command buffer\n    id\u003cMTLCommandBuffer\u003e commandBuffer = [commandQueue commandBuffer];\n\n    // Create a compute command encoder\n    id\u003cMTLComputeCommandEncoder\u003e encoder = [commandBuffer computeCommandEncoder];\n\n    // Set the compute pipeline state\n    [encoder setComputePipelineState:pipelineState];\n\n    // Set the buffers\n    [encoder setBuffer:aBuffer offset:0 atIndex:0];\n    [encoder setBuffer:bBuffer offset:0 atIndex:1];\n    [encoder setBuffer:resultBuffer offset:0 atIndex:2];\n```\n\nWe set the grid size to the number of elements, and cap the threadgroup size at the maximum the pipeline state allows:\n\n```cpp\n    // Dispatch the compute kernel\n    MTLSize gridSize = MTLSizeMake(numElements, 1, 1);\n    NSUInteger threadGroupSize = pipelineState.maxTotalThreadsPerThreadgroup;\n    if (threadGroupSize \u003e numElements) {\n        threadGroupSize = numElements;\n    }\n    MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);\n    [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];\n    [encoder endEncoding];\n```\n\nExecute the shader:\n\n```cpp\n    // Commit the command buffer and wait for it to complete\n    [commandBuffer commit];\n    [commandBuffer waitUntilCompleted];\n```\n\nAnd finally copy the result back to a `torch` Tensor:\n\n```cpp\n    // The shader wrote into resultBuffer, whose shared storage is visible to the CPU.\n    // Read it into a CPU tensor (cloning so the data outlives the Metal buffer),\n    // then move the result to the MPS device\n    torch::Tensor result = torch::from_blob(resultBuffer.contents, {numElements}, torch::kFloat)\n        .clone()\n        .to(torch::kMPS);\n\n    return result;\n```\n\n### Test run\n\nLet’s create a test file `main.py` to execute the shader:\n\n```python\nimport torch\nimport my_extension\n\na = torch.tensor([1.0, 2.0, 3.0]).to('mps')\nb = torch.tensor([4.0, 5.0, 6.0]).to('mps')\nprint(f\"Input tensor a: {a}\")\nprint(f\"Input tensor b: {b}\")\nprint(f\"Input device: {a.device}\")\n\nresult = my_extension.add_tensors(a, b)\nprint(f\"Addition result: {result}\")\nprint(f\"Output device {result.device}\")\nassert result.device == torch.device('mps:0'), \"Output tensor is (maybe?) not on the MPS device\"\n```\n\nwhich uses input tensors on the `MPS` device, and should give the following output to verify that the result is on the `MPS` device:\n\n```\nInput tensor a: tensor([1., 2., 3.], device='mps:0')\nInput tensor b: tensor([4., 5., 6.], device='mps:0')\nInput device: mps:0\nAddition result: tensor([5., 7., 9.], device='mps:0')\nOutput device mps:0\n```\n\n### Closing thoughts\n\nThanks for reading! I had a lot of fun learning about Metal shaders and hope you did as well. 
I found this [example](https://github.com/open-mmlab/mmcv/blob/main/setup.py) `setup.py` file pretty useful to look at, as well as the [official Apple docs](https://developer.apple.com/documentation/metal/compute_passes/creating_threads_and_threadgroups) to understand threads, threadgroups and grids.","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsmrfeld%2Fpytorch-cpp-metal-tutorial","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fsmrfeld%2Fpytorch-cpp-metal-tutorial","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsmrfeld%2Fpytorch-cpp-metal-tutorial/lists"}