{"id":13807165,"url":"https://github.com/ASEM000/kernex","last_synced_at":"2025-05-14T00:31:02.231Z","repository":{"id":44404718,"uuid":"512400616","full_name":"ASEM000/kernex","owner":"ASEM000","description":"Stencil computations in JAX","archived":false,"fork":false,"pushed_at":"2023-10-01T17:58:34.000Z","size":1947,"stargazers_count":70,"open_issues_count":8,"forks_count":3,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-04-13T10:11:32.528Z","etag":null,"topics":["jax","kernel","stencil"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ASEM000.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null}},"created_at":"2022-07-10T10:01:41.000Z","updated_at":"2025-01-16T06:49:17.000Z","dependencies_parsed_at":"2024-01-07T10:51:53.522Z","dependency_job_id":"718868b8-7311-48cf-bea4-0687b6a93500","html_url":"https://github.com/ASEM000/kernex","commit_stats":{"total_commits":65,"total_committers":2,"mean_commits":32.5,"dds":0.1384615384615384,"last_synced_commit":"1fb2c249073fd17047786dd9326a5e1cbaef2041"},"previous_names":[],"tags_count":11,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fkernex","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fkernex/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fkernex/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fkernex/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHu
b/owners/ASEM000","download_url":"https://codeload.github.com/ASEM000/kernex/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254046241,"owners_count":22005559,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["jax","kernel","stencil"],"created_at":"2024-08-04T01:01:21.873Z","updated_at":"2025-05-14T00:30:57.171Z","avatar_url":"https://github.com/ASEM000.png","language":"Python","funding_links":[],"categories":["Libraries"],"sub_categories":["Inactive Libraries","New Libraries"],"readme":"\u003cdiv align = \"center\"\u003e\n\u003cimg  width=400 src=\"assets/kernexlogo.svg\" align=\"center\"\u003e\n\n\u003ch3 align=\"center\"\u003eDifferentiable Stencil computations in JAX\u003c/h3\u003e\n\n[**Installation**](#Installation)\n|[**Description**](#Description)\n|[**Quick Example**](#QuickExample)\n|[**More Examples**](#MoreExamples)\n\n![Tests](https://github.com/ASEM000/kernex/actions/workflows/tests.yml/badge.svg)\n![pyver](https://img.shields.io/badge/python-3.8%203.9%203.11-red)\n![codestyle](https://img.shields.io/badge/codestyle-black-black)\n[![Downloads](https://static.pepy.tech/badge/kernex)](https://pepy.tech/project/kernex)\n[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14UEqKzIyZsDzQ9IMeanvztXxbbbatTYV?usp=sharing)\n[![codecov](https://codecov.io/gh/ASEM000/kernex/branch/main/graph/badge.svg?token=3KLL24Z94I)](https://codecov.io/gh/ASEM000/kernex)\n[![DOI](https://zenodo.org/badge/512400616.svg)](https://zenodo.org/badge/latestdoi/512400616)\n\n\u003c/div\u003e\n\n## 🛠️ Installation\u003ca id=\"Installation\"\u003e\u003c/a\u003e\n\n```bash\npip install kernex\n```\n\n## 📖 Description\u003ca id=\"Description\"\u003e\u003c/a\u003e\n\nKernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap`, and `jax.lax.scan` with `kscan`, for general stencil computations.\n\n## ⏩ Quick Example\u003ca id=\"QuickExample\"\u003e\u003c/a\u003e\n\n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n\u003ctr\u003e\n\u003ctd width=\"50%\" align=\"center\" \u003e kmap \u003c/td\u003e \u003ctd align=\"center\" \u003e kscan \u003c/td\u003e\n\u003c/tr\u003e\n\u003ctr\u003e\n\u003ctd\u003e\n\n```python\nimport kernex as kex\nimport jax.numpy as jnp\n\n@kex.kmap(kernel_size=(3,))\ndef sum_all(x):\n    return jnp.sum(x)\n\nx = jnp.array([1,2,3,4,5])\nprint(sum_all(x))\n# [ 6  9 12]\n```\n\n\u003c/td\u003e\n\u003ctd\u003e\n\n```python\nimport kernex as kex\nimport jax.numpy as jnp\n\n@kex.kscan(kernel_size=(3,))\ndef sum_all(x):\n    return jnp.sum(x)\n\nx = jnp.array([1,2,3,4,5])\nprint(sum_all(x))\n# [ 6 13 22]\n```\n\n\u003c/td\u003e\n\u003c/tr\u003e\n\u003c/table\u003e\n\n\u003ctable\u003e\n\u003ctr\u003e\n\u003ctd width=\"50%\"\u003e\n`jax.vmap` is used to sum the contents of each window independently.\n\u003cimg src=\"assets/kmap_sum.png\" width=400px\u003e\n\u003c/td\u003e\n\u003ctd\u003e\n`jax.lax.scan` is used to update the array in place, so the window sums are computed sequentially and each step sees the results of the previous ones.\nThe first three rows show the three sequential steps used to get the solution in the last row.\n\n\u003cimg align=\"center\" src=\"assets/kscan_sum.png\" 
width=400px\u003e\n\u003c/td\u003e\n\u003c/tr\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\n## 🔢 More examples\u003ca id=\"MoreExamples\"\u003e\u003c/a\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003e1️⃣ Convolution operation\u003c/summary\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\n\n@jax.jit\n@kex.kmap(\n    kernel_size=(3,3,3),\n    padding=('valid','same','same'))\ndef kernex_conv2d(x,w):\n    # channel-first conv2d with a 3x3x3 kernel\n    return jnp.sum(x*w)\n```\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003e2️⃣ Laplacian operation\u003c/summary\u003e\n\n```python\n# see also\n# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\n\n@kex.kmap(\n    kernel_size=(3,3),\n    padding='valid',\n    relative=True) # `relative=True` indexes the window relative to its center\ndef laplacian(x):\n    return ( 0*x[1,-1]  + 1*x[1,0]   + 0*x[1,1] +\n             1*x[0,-1]  +-4*x[0,0]   + 1*x[0,1] +\n             0*x[-1,-1] + 1*x[-1,0]  + 0*x[-1,1] )\n\nprint(laplacian(jnp.ones([10,10])))\n# [[0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.],\n#  [0., 0., 0., 0., 0., 0., 0., 0.]]\n```\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\u003csummary\u003e3️⃣ Get Patches of an array\u003c/summary\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\n\n@kex.kmap(kernel_size=(3,3),relative=True)\ndef identity(x):\n    # similar to numba.stencil:\n    # x[0,0] is the top-left cell of the (padded/unpadded) window by default,\n    # or the center cell when `relative=True`, as here\n    return x[0,0]\n\n# unlike numba.stencil, vector output is allowed in kernex\n# this function is similar to\n# 
`jax.lax.conv_general_dilated_patches(x,(3,3),(1,1),padding='same')`\n@jax.jit\n@kex.kmap(kernel_size=(3,3),padding='same')\ndef get_3x3_patches(x):\n    # returns a 5x5x3x3 array (one 3x3 window per input cell)\n    return x\n\nmat = jnp.arange(1,26).reshape(5,5)\nprint(mat)\n# [[ 1  2  3  4  5]\n#  [ 6  7  8  9 10]\n#  [11 12 13 14 15]\n#  [16 17 18 19 20]\n#  [21 22 23 24 25]]\n\n# get the zero-padded window centered at array index (0,0)\nprint(get_3x3_patches(mat)[0,0])\n# [[0 0 0]\n#  [0 1 2]\n#  [0 6 7]]\n```\n\u003c/details\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003e4️⃣ Linear convection\u003c/summary\u003e\n\n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n\u003ctr\u003e\n\u003ctd\u003e Problem setup \u003c/td\u003e \u003ctd\u003e Stencil view \u003c/td\u003e\n\u003c/tr\u003e\n\u003ctr\u003e\n\u003ctd\u003e\n\n\u003cimg src=\"assets/linear_convection_init.png\" width=\"500px\"\u003e\n\n\u003c/td\u003e\n\u003ctd\u003e\n\n\u003cimg src=\"assets/linear_convection_view.png\" width=\"500px\"\u003e\n\n\u003c/td\u003e\n\u003c/tr\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\nimport matplotlib.pyplot as plt\n\n# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb\n\ntmax, xmax = 0.5, 2.0\nnt, nx = 151, 51\ndt, dx = tmax/(nt-1), xmax/(nx-1)\nu = jnp.ones([nt,nx])\nc = 0.5\n\n# kscan moves sequentially in row-major order and updates in-place using lax.scan.\n\nF = kex.kscan(\n        kernel_size=(3,3),\n        padding=((1,1),(1,1)),\n        # n for the time axis, i for the spatial axis (naming is optional)\n        named_axis={0:'n',1:'i'},\n        relative=True\n    )\n\n\n# boundary condition as a function\ndef bc(u):\n    return 1\n\n# initial condition as a function\ndef ic1(u):\n    return 1\n\ndef ic2(u):\n    return 2\n\ndef linear_convection(u):\n    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )\n\n\nF[:,0] = F[:,-1] = bc # assign 1 to the left and right boundaries for all 
t\n\n# square wave initial condition\nF[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1\nF[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2\n\n# apply the linear convection update at interior spatial points [1:-1]\n# for every time step after the initial one, t \u003e 0 [1:]\nF[1:,1:-1] = linear_convection\n\nkx_solution = F(jnp.array(u))\n\nplt.figure(figsize=(20,7))\nfor line in kx_solution[::20]:\n    plt.plot(jnp.linspace(0,xmax,nx),line)\n```\n\n\u003cimg src=\"assets/linear_convection.svg\"\u003e\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\u003csummary\u003e5️⃣ Gaussian blur\u003c/summary\u003e\n\n```python\nimport jax.numpy as jnp\nimport kernex as kex\n\ndef gaussian_blur(image, sigma, kernel_size):\n    # 1D Gaussian weights on a grid centered at 0, with standard deviation sigma\n    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)\n    w = jnp.exp(-0.5 * jnp.square(x) / jnp.square(sigma))\n    # separable 2D kernel, normalized to sum to 1\n    w = jnp.outer(w, w)\n    w = w / w.sum()\n\n    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding=\"same\")\n    def conv(window):\n        return jnp.sum(window * w)\n\n    return conv(image)\n```\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\u003csummary\u003e6️⃣ Depthwise convolution\u003c/summary\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\n\n@jax.jit\n@jax.vmap  # map over the leading channel axis of both x and w\n@kex.kmap(\n    kernel_size=(3,3),\n    padding=('same','same'))\ndef kernex_depthwise_conv2d(x,w):\n    return jnp.sum(x*w)\n\nheight, width, channels = 5, 5, 2\nk = 3\n\nx = jnp.arange(1,height*width*channels+1).reshape(channels,height,width)\nw = jnp.arange(1,k*k*channels+1).reshape(channels,k,k)\nprint(kernex_depthwise_conv2d(x,w))\n```\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\u003csummary\u003e7️⃣ Average pooling 2D\u003c/summary\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport kernex as kex\n\n@jax.vmap # vectorize over the channel dimension\n@kex.kmap(kernel_size=(3,3), strides=(2,2))\ndef avgpool_2d(x):\n    # average each 3x3 spatial window; strides=(2,2) moves the window two cells at a time\n    return jnp.mean(x)\n```\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\u003csummary\u003e8️⃣ Runge-Kutta 
integration\u003c/summary\u003e\n\n```python\n# let's solve dy/dt = y with y(0) = 1, whose exact solution is y(t) = e^t,\n# using the 4th-order Runge-Kutta method with f(t, y) = y\nimport jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport kernex as kex\n\n\nt = jnp.linspace(0, 1, 5)\ny = jnp.zeros(5)\nx = jnp.stack([y, t], axis=0)\ndt = t[1] - t[0]  # 0.25\nf = lambda tn, yn: yn\n\n\ndef ic(x):\n    \"\"\" initial condition y0 = 1 \"\"\"\n    return 1.\n\n\ndef rk4(x):\n    \"\"\" Runge-Kutta 4th-order integration step \"\"\"\n    # ┌────┬────┬────┐      ┌──────┬──────┬──────┐\n    # │ y0 │*y1*│ y2 │      │[0,-1]│[0, 0]│[0, 1]│\n    # ├────┼────┼────┤ ==\u003e  ├──────┼──────┼──────┤\n    # │ t0 │ t1 │ t2 │      │[1,-1]│[1, 0]│[1, 1]│\n    # └────┴────┴────┘      └──────┴──────┴──────┘\n    t0 = x[1, -1]\n    y0 = x[0, -1]\n    k1 = dt * f(t0, y0)\n    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)\n    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)\n    k4 = dt * f(t0 + dt, y0 + k3)\n    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)\n    return yn_1\n\n\n# 2x3 window: no padding along axis 0 (the y/t rows), padding of 1 along axis 1 (time)\nF = kex.kscan(kernel_size=(2, 3), relative=True, padding=(0, 1))\n\nF[0:1, 1:] = rk4\nF[0, 0] = ic\n# compile the solver\nsolver = jax.jit(F.__call__)\ny = solver(x)[0, :]\n\nplt.plot(t, y, '-o', label='rk4')\nplt.plot(t, jnp.exp(t), '-o', label='analytical')\nplt.legend()\n```\n\n![RK4 vs. analytical solution](assets/rk4.svg)\n\n\u003c/details\u003e\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FASEM000%2Fkernex","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FASEM000%2Fkernex","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FASEM000%2Fkernex/lists"}