{"id":15903106,"url":"https://github.com/abeleinin/Metal-Puzzles","last_synced_at":"2025-10-18T06:30:35.016Z","repository":{"id":256658160,"uuid":"855567211","full_name":"abeleinin/Metal-Puzzles","owner":"abeleinin","description":"Solve Puzzles. Learn Metal 🤘","archived":false,"fork":false,"pushed_at":"2024-09-24T05:45:04.000Z","size":4022,"stargazers_count":164,"open_issues_count":0,"forks_count":7,"subscribers_count":2,"default_branch":"main","last_synced_at":"2024-10-06T12:01:59.755Z","etag":null,"topics":["gpu-programming","metal","mlx","puzzles"],"latest_commit_sha":null,"homepage":"","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/abeleinin.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-09-11T04:47:51.000Z","updated_at":"2024-10-05T05:06:17.000Z","dependencies_parsed_at":"2024-09-18T06:58:28.434Z","dependency_job_id":null,"html_url":"https://github.com/abeleinin/Metal-Puzzles","commit_stats":null,"previous_names":["abeleinin/metal-puzzles"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/abeleinin%2FMetal-Puzzles","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/abeleinin%2FMetal-Puzzles/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/abeleinin%2FMetal-Puzzles/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/abeleinin%2FMetal-Puzzles/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/abele
inin","download_url":"https://codeload.github.com/abeleinin/Metal-Puzzles/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":236907708,"owners_count":19223638,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["gpu-programming","metal","mlx","puzzles"],"created_at":"2024-10-06T12:00:59.618Z","updated_at":"2025-10-18T06:30:29.153Z","avatar_url":"https://github.com/abeleinin.png","language":"Jupyter Notebook","funding_links":[],"categories":["Jupyter Notebook","Other"],"sub_categories":[],"readme":"# Metal Puzzles\n\nPort of [srush/GPU-Puzzles](https://github.com/srush/GPU-Puzzles) to [Metal](https://en.wikipedia.org/wiki/Metal_API) using [MLX Custom Kernels](https://ml-explore.github.io/mlx/build/html/dev/custom_metal_kernels.html). Inspired by [@awnihannun](https://x.com/awnihannun/status/1833376670063202536)!\n\n![Metal Puzzles Logo](./imgs/metal_puzzles.png)\n\nGPUs are crucial in machine learning because they can process data on a massively parallel scale. While it's possible to become an expert in machine learning without writing any GPU code, building intuition is challenging when you're only working through layers of abstraction. Additionally, as models grow in complexity, the need for developers to write efficient, high-performance kernels becomes increasingly important to leverage the power of modern hardware.\n\nWhether you're new to GPU programming or have experience with CUDA, the following puzzles provide a straightforward way to learn on an Apple Silicon computer. 
In the following exercises, you'll use the `mx.fast.metal_kernel()` function from Apple's [mlx](https://github.com/ml-explore/mlx) framework, which allows you to write custom Metal kernels through a Python/C++ API. For verification purposes, I've created a wrapper class around `mx.fast.metal_kernel()` called `MetalKernel`, but the interface remains identical.\n\nIf you're interested in more material, check out the [MLX Custom Metal Kernels Documentation](https://ml-explore.github.io/mlx/build/html/dev/custom_metal_kernels.html) and the [Metal Shading Language specification](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf).\n\n```sh\npip install -qqq git+https://github.com/danoneata/chalk@srush-patch-1\npip install mlx\n```\n\n```python\nimport mlx.core as mx\nfrom utils import MetalKernel, MetalProblem\n```\n\n## Puzzle 1: Map\n\nImplement a \"kernel\" (GPU function) that adds 10 to each position of the array `a` and stores it in the array `out`.  You have 1 thread per position.\n\n**Note:** The `source` string below is the body of your Metal kernel, the function signature will be automatically generated for you. Below you'll notice the `input_names` and `output_names` parameters. These define the parameters for your Metal kernel.\n\n**Tip:** If you need a tool for debugging your Kernel read the [Metal Debugger](#metal-debugger) section below. 
Also, you can print out the generated Metal kernel by setting the environment variable `VERBOSE=1`.\n\n```python\ndef map_spec(a: mx.array):\n    return a + 10\n\ndef map_test(a: mx.array):\n    source = \"\"\"\n        uint local_i = thread_position_in_grid.x;\n        // FILL ME IN (roughly 1 line)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"map\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 4\na = mx.arange(SIZE)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"Map\",\n    map_test,\n    [a], \n    output_shape,\n    grid=(SIZE,1,1), \n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Map\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_map.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0], dtype=float32)\nSpec : array([10, 11, 12, 13], dtype=int32)\n```\n\n## Puzzle 2: Zip \n\nImplement a kernel that takes two arrays `a` and `b`, adds each element together, and stores the result in the output array `out`. 
You have 1 thread per position.\n\n```python\ndef zip_spec(a: mx.array, b: mx.array):\n    return a + b\n\ndef zip_test(a: mx.array, b: mx.array):\n    source = \"\"\"\n        uint local_i = thread_position_in_grid.x;\n        // FILL ME IN (roughly 1 line)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"zip\",\n        input_names=[\"a\", \"b\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 4\na = mx.arange(SIZE)\nb = mx.arange(SIZE)\noutput_shapes = (SIZE,)\n\nproblem = MetalProblem(\n    \"Zip\",\n    zip_test,\n    [a, b],\n    output_shapes,\n    grid=(SIZE,1,1),\n    spec=zip_spec\n)\nproblem.show()\n```\n\n```\n# Zip\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_zip.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0], dtype=float32)\nSpec : array([0, 2, 4, 6], dtype=int32)\n```\n\n## Puzzle 3: Guard\n\nImplement a kernel that adds 10 to each position of `a` and stores it in `out`. You have more threads than positions.\n\n**Warning:** Be careful of out-of-bounds access.\n\n**Note:** You can append `_shape`, `_strides`, or `_ndim` to any input parameter to automatically add that data as a parameter to your kernel. 
So, in the following puzzle you could use `a_shape`, `a_strides`, or `a_ndim`.\n\n```python\ndef map_guard_test(a: mx.array):\n    source = \"\"\"\n        uint local_i = thread_position_in_grid.x;\n        // FILL ME IN (roughly 1-3 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"guard\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 4\na = mx.arange(SIZE)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"Guard\",\n    map_guard_test,\n    [a], \n    output_shape,\n    grid=(8,1,1), \n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Guard\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_guard.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0], dtype=float32)\nSpec : array([10, 11, 12, 13], dtype=int32)\n```\n\n## Puzzle 4: Map 2D\n\nImplement a kernel that adds 10 to each position of `a` and stores it in `out`. Input `a` is 2D and square. 
You have more threads than positions.\n\n**Note:** All memory in Metal is represented as a 1D array, so direct 2D indexing is not supported.\n\n```python\ndef map_2D_test(a: mx.array):\n    source = \"\"\"\n        uint thread_x = thread_position_in_grid.x;\n        uint thread_y = thread_position_in_grid.y;\n        // FILL ME IN (roughly 4 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"map_2D\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 2\na = mx.arange(SIZE * SIZE).reshape((SIZE, SIZE))\noutput_shape = (SIZE,SIZE)\n\nproblem = MetalProblem(\n    \"Map 2D\",\n    map_2D_test,\n    [a], \n    output_shape,\n    grid=(3,3,1), \n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Map 2D\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_map_2D.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0, 0],\n       [0, 0]], dtype=float32)\nSpec : array([[10, 11],\n       [12, 13]], dtype=int32)\n```\n\n## Puzzle 5: Broadcast\n\nImplement a kernel that adds `a` and `b` and stores it in `out`. Inputs `a` and `b` are arrays. 
You have more threads than positions.\n\n```python\ndef broadcast_test(a: mx.array, b: mx.array):\n    source = \"\"\"\n        uint thread_x = thread_position_in_grid.x;\n        uint thread_y = thread_position_in_grid.y;\n        // FILL ME IN (roughly 4 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"broadcast\",\n        input_names=[\"a\", \"b\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 2\na = mx.arange(SIZE).reshape(SIZE, 1)\nb = mx.arange(SIZE).reshape(1, SIZE)\noutput_shape = (SIZE,SIZE)\n\nproblem = MetalProblem(\n    \"Broadcast\",\n    broadcast_test,\n    [a, b], \n    output_shape,\n    grid=(3,3,1), \n    spec=zip_spec\n)\nproblem.show()\n```\n\n```\n# Broadcast\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_broadcast.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0, 0],\n       [0, 0]], dtype=float32)\nSpec : array([[0, 1],\n       [1, 2]], dtype=int32)\n```\n\n## Puzzle 6: Threadgroups\n\nImplement a kernel that adds 10 to each position of `a` and stores it in `out`. You have fewer threads per threadgroup than the size of `a`, but more threads than positions.\n\n**Note:** A threadgroup is simply a group of threads within the thread grid. The number of threads per threadgroup is limited to a defined number, but we can have multiple different threadgroups. 
The Metal parameter `threadgroup_position_in_grid` tells us what threadgroup we are in.\n\n```python\ndef map_threadgroup_test(a: mx.array):\n    source = \"\"\"\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 1-3 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"threadgroups\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 9\na = mx.arange(SIZE)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"Threadgroups\",\n    map_threadgroup_test,\n    [a], \n    output_shape,\n    grid=(12,1,1), \n    threadgroup=(4,1,1),\n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Threadgroups\n\n    Score (Max Per Thread):\n    |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n    |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_threadgroups.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=float32)\nSpec : array([10, 11, 12, 13, 14, 15, 16, 17, 18], dtype=int32)\n```\n\n## Puzzle 7: Threadgroups 2D\n\nImplement the same kernel in 2D. 
You have fewer threads per threadgroup than the size of `a` in both directions, but more threads than positions in the grid.\n\n```python\ndef map_threadgroup_2D_test(a: mx.array):\n    source = \"\"\"\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 5 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"threadgroups_2D\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 5\na = mx.ones((SIZE, SIZE))\noutput_shape = (SIZE, SIZE)\n\nproblem = MetalProblem(\n    \"Threadgroups 2D\",\n    map_threadgroup_2D_test,\n    [a], \n    output_shape,\n    grid=(6,6,1), \n    threadgroup=(3,3,1),\n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Threadgroups 2D\n\n    Score (Max Per Thread):\n    |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n    |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_threadgroup_2D.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0, 0, 0, 0, 0],\n    [0, 0, 0, 0, 0],\n    [0, 0, 0, 0, 0],\n    [0, 0, 0, 0, 0],\n    [0, 0, 0, 0, 0]], dtype=float32)\nSpec : array([[11, 11, 11, 11, 11],\n    [11, 11, 11, 11, 11],\n    [11, 11, 11, 11, 11],\n    [11, 11, 11, 11, 11],\n    [11, 11, 11, 11, 11]], dtype=float32)\n```\n\n## Puzzle 8: Threadgroup Memory\n\nImplement a kernel that adds 10 to each position of `a` and stores it in `out`. You have fewer threads per threadgroup than the size of `a`.\n\n**Warning**: Each threadgroup can only have a *constant* amount of threadgroup memory that the threads can read and write to. After writing to threadgroup memory, you need to call `threadgroup_barrier(mem_flags::mem_threadgroup)` to ensure that threads are synchronized. 
In this puzzle we add the `header` variable as a new parameter to the `MetalKernel` object, which simply defines values outside of the kernel body (often used for header imports).\n\nFor more information read section [4.4 Threadgroup Address Space](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf#page=86) and section [6.9 Synchronization and SIMD-Group Functions](https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf#page=177) in the Metal Shading Language Specification.\n\n(This example does not really need threadgroup memory or synchronization, but it's a demo.)\n\n```python\ndef shared_test(a: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 4;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float shared[THREADGROUP_MEM_SIZE];\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n\n        if (i \u003c a_shape[0]) {\n            shared[local_i] = a[i];\n            threadgroup_barrier(mem_flags::mem_threadgroup);\n        }\n\n        // FILL ME IN (roughly 1-3 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"threadgroup_memory\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 8\na = mx.ones(SIZE)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"Threadgroup Memory\",\n    shared_test,\n    [a], \n    output_shape,\n    grid=(SIZE,1,1), \n    threadgroup=(4,1,1),\n    spec=map_spec\n)\nproblem.show()\n```\n\n```\n# Threadgroup Memory\n\n    Score (Max Per Thread):\n    |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n    |             1 |             0 |             0 |             1 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_threadgroup_memory.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed 
Tests.\nYours: array([0, 0, 0, 0, 0, 0, 0, 0], dtype=float32)\nSpec : array([11, 11, 11, 11, 11, 11, 11, 11], dtype=float32)\n```\n\n## Puzzle 9: Pooling\n\nImplement a kernel that sums together the last 3 positions of `a` and stores it in `out`. You have 1 thread per position. \n\n**Note:** `threadgroup` memory is often faster than sharing data in `device` memory because it is located closer to the GPU's compute units. Be careful of unnecessary reads and writes from global parameters (`a` and `out`), since their data is stored in `device` memory. You only need 1 global read and 1 global write per thread.\n\n**Tip:** Remember to be careful about syncing.\n\n```python\ndef pooling_spec(a: mx.array):\n    out = mx.zeros(*a.shape)\n    for i in range(a.shape[0]):\n        out[i] = a[max(i - 2, 0) : i + 1].sum()\n    return out\n\ndef pooling_test(a: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 8;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float shared[THREADGROUP_MEM_SIZE];\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 11 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"pooling\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 8\na = mx.arange(SIZE)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"Pooling\",\n    pooling_test,\n    [a], \n    output_shape,\n    grid=(SIZE,1,1), \n    threadgroup=(SIZE,1,1),\n    spec=pooling_spec\n)\nproblem.show()\n```\n\n```\n# Pooling\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_pooling.png\" 
height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0, 0, 0, 0, 0], dtype=float32)\nSpec : array([0, 1, 3, 6, 9, 12, 15, 18], dtype=float32)\n```\n\n## Puzzle 10: Dot Product\n\nImplement a kernel that computes the [dot product](https://en.wikipedia.org/wiki/Dot_product#Coordinate_definition) of `a` and `b` and stores it in `out`. You have 1 thread per position. You only need 2 global reads and 1 global write per thread.\n\n**Note**: For this problem you don't need to worry about the number of reads from `threadgroup` memory. We will handle that challenge later.\n\n```python\ndef dot_spec(a: mx.array, b: mx.array):\n    return a @ b\n\ndef dot_test(a: mx.array, b: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 8;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float shared[THREADGROUP_MEM_SIZE];\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 11 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"dot_product\",\n        input_names=[\"a\", \"b\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\nSIZE = 8\na = mx.arange(SIZE, dtype=mx.float32)\nb = mx.arange(SIZE, dtype=mx.float32)\noutput_shape = (1,)\n\nproblem = MetalProblem(\n    \"Dot Product\",\n    dot_test,\n    [a, b], \n    output_shape,\n    grid=(SIZE,1,1), \n    threadgroup=(SIZE,1,1),\n    spec=dot_spec\n)\nproblem.show()\n```\n\n```\n# Dot Product\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_dot_product.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n## Puzzle 11: 1D Convolution\n\nImplement a kernel that 
computes a 1D convolution between `a` and `b` and stores it in `out`. You need to handle the general case. You only need 2 global reads and 1 global write per thread.\n\n```python\ndef conv_spec(a: mx.array, b: mx.array):\n    out = mx.zeros(*a.shape)\n    len = b.shape[0]\n    for i in range(a.shape[0]):\n        out[i] = sum([a[i + j] * b[j] for j in range(len) if i + j \u003c a.shape[0]])\n    return out\n\ndef conv_test(a: mx.array, b: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MAX_CONV_SIZE = 12;\n        constant uint MAX_CONV = 4;\n    \"\"\"\n\n    source = \"\"\"\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 24 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"1D_conv\",\n        input_names=[\"a\", \"b\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\n# Test 1\nSIZE = 6\nCONV = 3\na = mx.arange(SIZE, dtype=mx.float32)\nb = mx.arange(CONV, dtype=mx.float32)\noutput_shape = (SIZE,)\n\nproblem = MetalProblem(\n    \"1D Conv (Simple)\",\n    conv_test,\n    [a, b], \n    output_shape,\n    grid=(8,1,1), \n    threadgroup=(8,1,1),\n    spec=conv_spec\n)\nproblem.show()\n```\n\n```\n# 1D Conv (Simple)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_1D_conv_simple.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0, 0, 0], dtype=float32)\nSpec : array([5, 8, 11, 14, 5, 0], dtype=float32)\n```\n\n```python\n# Test 2\na = mx.arange(15, dtype=mx.float32)\nb = mx.arange(4, dtype=mx.float32)\noutput_shape = (15,)\n\nproblem = MetalProblem(\n    \"1D Conv (Full)\",\n    conv_test,\n    [a, 
b], \n    output_shape,\n    grid=(16,1,1), \n    threadgroup=(8,1,1),\n    spec=conv_spec\n)\nproblem.show()\n```\n\n```\n# 1D Conv (Full)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_1D_conv_full.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=float32)\nSpec : array([14, 20, 26, 32, 38, 44, 50, 56, 62, 68, 74, 80, 41, 14, 0], dtype=float32)\n```\n\n## Puzzle 12: Prefix Sum\n\nImplement a kernel that computes a sum over `a` and stores it in `out`. If the size of `a` is greater than the threadgroup size, only store the sum of each threadgroup.\n\nWe will do this using the [parallel prefix sum](https://en.wikipedia.org/wiki/Prefix_sum#Parallel_algorithms) algorithm in `threadgroup` memory. In each step, the algorithm will sum half of the remaining elements together.\n\n```python\nTHREADGROUP_MEM_SIZE = 8\ndef prefix_sum_spec(a: mx.array):\n    out = mx.zeros((a.shape[0] + THREADGROUP_MEM_SIZE - 1) // THREADGROUP_MEM_SIZE)\n    for j, i in enumerate(range(0, a.shape[-1], THREADGROUP_MEM_SIZE)):\n        out[j] = a[i : i + THREADGROUP_MEM_SIZE].sum()\n    return out\n\ndef prefix_sum_test(a: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 8;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float cache[THREADGROUP_MEM_SIZE];\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n        // FILL ME IN (roughly 14 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"prefix_sum\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\n# Test 1\nSIZE = 8\na 
= mx.arange(SIZE)\noutput_shape = (1,)\n\nproblem = MetalProblem(\n    \"Prefix Sum (Simple)\",\n    prefix_sum_test,\n    [a], \n    output_shape,\n    grid=(8,1,1), \n    threadgroup=(8,1,1),\n    spec=prefix_sum_spec\n)\nproblem.show()\n```\n\n```\n# Prefix Sum (Simple)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_prefix_sum_simple.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0], dtype=float32)\nSpec : array([28], dtype=float32)\n```\n\n```python\n# Test 2\nSIZE = 15\na = mx.arange(SIZE)\noutput_shape = (2,)\n\nproblem = MetalProblem(\n    \"Prefix Sum (Full)\",\n    prefix_sum_test,\n    [a], \n    output_shape,\n    grid=(16,1,1), \n    threadgroup=(8,1,1),\n    spec=prefix_sum_spec\n)\nproblem.show()\n```\n\n```\n# Prefix Sum (Full)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_prefix_sum_full.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([0, 0], dtype=float32)\nSpec : array([28, 77], dtype=float32)\n```\n\n## Puzzle 13: Axis Sum\n\nImplement a kernel that computes the sum over each column in the input array `a` and stores it in `out`.\n\n```python\nTHREADGROUP_MEM_SIZE = 8\ndef axis_sum_spec(a: mx.array):\n    out = mx.zeros((a.shape[0], (a.shape[1] + THREADGROUP_MEM_SIZE - 1) // THREADGROUP_MEM_SIZE))\n    for j, i in enumerate(range(0, a.shape[-1], THREADGROUP_MEM_SIZE)):\n        out[..., j] = a[..., i : i + THREADGROUP_MEM_SIZE].sum(-1)\n    return out\n\ndef axis_sum_test(a: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 8;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float 
cache[THREADGROUP_MEM_SIZE];\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint local_i = thread_position_in_threadgroup.x;\n        uint batch = threadgroup_position_in_grid.y;\n        // FILL ME IN (roughly 16 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"axis_sum\",\n        input_names=[\"a\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\nBATCH = 4\nSIZE = 6\na = mx.arange(BATCH * SIZE).reshape((BATCH, SIZE))\noutput_shape = (BATCH, 1)\n\nproblem = MetalProblem(\n    \"Axis Sum\",\n    axis_sum_test,\n    [a], \n    output_shape,\n    grid=(8,BATCH,1), \n    threadgroup=(8,1,1),\n    spec=axis_sum_spec\n)\nproblem.show()\n```\n\n```\n# Axis Sum\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_axis_sum.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0],\n       [0],\n       [0],\n       [0]], dtype=float32)\nSpec : array([[15],\n       [51],\n       [87],\n       [123]], dtype=float32)\n```\n\n## Puzzle 14: Matrix Multiply!\n\nImplement a kernel that multiplies square matrices `a` and `b` and stores the result in `out`.\n\n**Tip**: The most efficient algorithm will copy a block of data into `threadgroup` memory before computing each of the individual row-column dot products. This is straightforward if the matrix fits entirely in `threadgroup` memory (start by implementing that case first). Then, modify your code to compute partial dot products and iteratively move portions of the matrix into `threadgroup` memory. 
You should be able to handle the hard test in 6 device memory reads.\n\n```python\ndef matmul_spec(a: mx.array, b: mx.array):\n    return a @ b\n\ndef matmul_test(a: mx.array, b: mx.array):\n    header = \"\"\"\n        constant uint THREADGROUP_MEM_SIZE = 3;\n    \"\"\"\n\n    source = \"\"\"\n        threadgroup float a_shared[THREADGROUP_MEM_SIZE][THREADGROUP_MEM_SIZE];\n        threadgroup float b_shared[THREADGROUP_MEM_SIZE][THREADGROUP_MEM_SIZE];\n\n        uint i = threadgroup_position_in_grid.x * threads_per_threadgroup.x + thread_position_in_threadgroup.x;\n        uint j = threadgroup_position_in_grid.y * threads_per_threadgroup.y + thread_position_in_threadgroup.y;\n\n        uint local_i = thread_position_in_threadgroup.x;\n        uint local_j = thread_position_in_threadgroup.y;\n        // FILL ME IN (roughly 19 lines)\n    \"\"\"\n\n    kernel = MetalKernel(\n        name=\"matmul\",\n        input_names=[\"a\", \"b\"],\n        output_names=[\"out\"],\n        header=header,\n        source=source,\n    )\n\n    return kernel\n\n# Test 1\nSIZE = 2\na = mx.arange(SIZE * SIZE, dtype=mx.float32).reshape((SIZE, SIZE))\nb = mx.arange(SIZE * SIZE, dtype=mx.float32).reshape((SIZE, SIZE)).T\noutput_shape = (SIZE, SIZE)\n\nproblem = MetalProblem(\n    \"Matmul (Simple)\",\n    matmul_test,\n    [a, b], \n    output_shape,\n    grid=(3,3,1), \n    threadgroup=(3,3,1),\n    spec=matmul_spec\n)\nproblem.show()\n```\n\n```\n# Matmul (Simple)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_matmul_simple.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0, 0],\n       [0, 0]], dtype=float32)\nSpec : array([[1, 3],\n       [3, 13]], dtype=float32)\n```\n\n```python\n# Test 2\nSIZE = 8\na = mx.arange(SIZE * SIZE, dtype=mx.float32).reshape((SIZE, 
SIZE))\nb = mx.arange(SIZE * SIZE, dtype=mx.float32).reshape((SIZE, SIZE)).T\noutput_shape = (SIZE, SIZE)\n\nproblem = MetalProblem(\n    \"Matmul (Full)\",\n    matmul_test,\n    [a, b], \n    output_shape,\n    grid=(9,9,1), \n    threadgroup=(3,3,1),\n    spec=matmul_spec\n)\nproblem.show()\n```\n\n```\n# Matmul (Full)\n \n   Score (Max Per Thread):\n   |  Global Reads | Global Writes |  Shared Reads | Shared Writes |\n   |             0 |             0 |             0 |             0 | \n```\n\n\u003cimg src=\"imgs/metal_puzzles_matmul_full.png\" height=\"500\"\u003e\n\n```python\nproblem.check()\n```\n\n```\nFailed Tests.\nYours: array([[0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]\n       [0, 0, 0, 0, 0, 0, 0, 0]], dtype=float32)\nSpec : array([[  140,   364,   588,   812,  1036,  1260,  1484,  1708]\n       [  364,  1100,  1836,  2572,  3308,  4044,  4780,  5516]\n       [  588,  1836,  3084,  4332,  5580,  6828,  8076,  9324]\n       [  812,  2572,  4332,  6092,  7852,  9612, 11372, 13132]\n       [ 1036,  3308,  5580,  7852, 10124, 12396, 14668, 16940]\n       [ 1260,  4044,  6828,  9612, 12396, 15180, 17964, 20748]\n       [ 1484,  4780,  8076, 11372, 14668, 17964, 21260, 24556]\n       [ 1708,  5516,  9324, 13132, 16940, 20748, 24556, 28364]], dtype=float32)\n```\n\n## Metal Debugger\n\nA useful resource when writing Metal code is the Metal Debugger in Xcode. You can capture GPU work from any kernel by setting the environment variable `MTL_CAPTURE_ENABLED=1`. This will generate a `.gputrace` file, which you can open in Xcode by running:\n\n```sh\nopen custom_kernel.gputrace\n```\n\nOnce opened you'll be able to profile the GPU trace to view its performance. Here is a basic guide to locate the kernel debugger and view kernel statistics. 
\n\nFirst select `Group By Pipeline State` on the left sidebar, which will simplify locating the custom kernel's `Compute Pipeline`.\n\n![](/imgs/metal_debugger_1.png)\n\nNext, locate the `Compute Pipeline` that contains your custom kernel (all generated kernels will be prefixed with `custom_kernel_{name}`).\n\n![](/imgs/metal_debugger_2.png)\n\nIf you click on the kernel name on the left sidebar, you'll be shown your kernel code. From this page, you can select the bug icon to begin a step debugger for each GPU thread or view statistics for different parts of your kernel.\n\n![](/imgs/metal_debugger_3.png)\n\nIf you hover over one of the orange circles, you can view its `Runtime Statistics`.\n\n![](/imgs/metal_debugger_4.png)\n\nMore information about the debugger can be found in the [MLX Metal Debugger](https://ml-explore.github.io/mlx/build/html/dev/metal_debugger.html) documentation or in the [Metal Debugger Apple Developer](https://developer.apple.com/documentation/xcode/metal-debugger) documentation.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fabeleinin%2FMetal-Puzzles","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fabeleinin%2FMetal-Puzzles","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fabeleinin%2FMetal-Puzzles/lists"}