{"id":31825741,"url":"https://github.com/alexshtf/torchcurves","last_synced_at":"2026-04-06T20:04:42.883Z","repository":{"id":316928714,"uuid":"989577719","full_name":"alexshtf/torchcurves","owner":"alexshtf","description":"Parametric differentiable curves with PyTorch for continuous embeddings, shape-restricted models, or KANs","archived":false,"fork":false,"pushed_at":"2026-04-04T14:13:11.000Z","size":4222,"stargazers_count":57,"open_issues_count":1,"forks_count":2,"subscribers_count":0,"default_branch":"master","last_synced_at":"2026-04-04T16:09:57.661Z","etag":null,"topics":["b-splines","curve-fitting","deep-learning","embeddings","interpolation","kan","kolmogorov-arnold-networks","legendre-polynomials","machine-learning","parametric-curves","python","pytorch","spline"],"latest_commit_sha":null,"homepage":"https://torchcurves.readthedocs.io/en/stable/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/alexshtf.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2025-05-24T11:40:35.000Z","updated_at":"2026-04-04T13:57:56.000Z","dependencies_parsed_at":"2025-09-27T17:40:24.850Z","dependency_job_id":"5af3bfe1-e60d-4934-bac9-824f417333ac","html_url":"https://github.com/alexshtf/torchcurves","commit_stats":null,"previous_names":["alexshtf/torchcurves"],"tags_count":3,"template":false,"template_full_name":null,"purl":"pkg:github/alexshtf/torchcurves","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/r
epositories/alexshtf%2Ftorchcurves","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/alexshtf%2Ftorchcurves/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/alexshtf%2Ftorchcurves/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/alexshtf%2Ftorchcurves/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/alexshtf","download_url":"https://codeload.github.com/alexshtf/torchcurves/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/alexshtf%2Ftorchcurves/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":31487543,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-04-06T17:22:55.647Z","status":"ssl_error","status_checked_at":"2026-04-06T17:22:54.741Z","response_time":112,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["b-splines","curve-fitting","deep-learning","embeddings","interpolation","kan","kolmogorov-arnold-networks","legendre-polynomials","machine-learning","parametric-curves","python","pytorch","spline"],"created_at":"2025-10-11T16:24:14.611Z","updated_at":"2026-04-06T20:04:42.877Z","avatar_url":"https://github.com/alexshtf.png","language":"Python","funding_links":[],"categories":["Library"],"sub_categories":["Theorem"],"readme":"\u003cp align=\"center\"\u003e\n\u003cpicture\u003e\n   
 \u003csource media=\"(prefers-color-scheme: dark)\" srcset=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/logo_dark.png\"\u003e\n    \u003csource media=\"(prefers-color-scheme: light)\" srcset=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/logo_light.png\"\u003e\n    \u003cimg width=\"30%\" alt=\"Torchcurves Logo\" src=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/logo_light.png\"\u003e\n\u003c/picture\u003e\n\u003c/p\u003e\n\n\n\u003cdiv align=\"center\"\u003e\n\n[![torchcurves-backend](https://github.com/alexshtf/torchcurves/actions/workflows/tests.yml/badge.svg)](https://github.com/alexshtf/torchcurves/actions/workflows/tests.yml)\n[![PyPI downloads](https://img.shields.io/pypi/dm/torchcurves)](https://pypi.org/project/torchcurves/)\n[![PyPI](https://img.shields.io/pypi/v/torchcurves)](https://pypi.org/project/torchcurves/)\n![Python version](https://img.shields.io/badge/python-3.9+-important)\n\n\u003c/div\u003e\n\nA PyTorch module for _vectorized_ and _differentiable_ parametric curves with learnable coefficients, such as a B-spline curve with learnable control points, for KANs, continuous embeddings, and shape constraints.\n\n\u003cdiv align=\"center\"\u003e\n    \u003cp\u003e\u003cb\u003eUse cases\u003c/b\u003e\u003c/p\u003e\n    \u003cpicture\u003e\n        \u003csource media=\"(prefers-color-scheme: dark)\" srcset=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_dark.png\"\u003e\n        \u003csource media=\"(prefers-color-scheme: light)\" srcset=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_light.png\"\u003e\n        \u003cimg width=\"100%\" alt=\"Torchcurves Usecases\" src=\"https://raw.githubusercontent.com/alexshtf/torchcurves/master/assets/usecases_light.png\"\u003e\n    \u003c/picture\u003e\n\u003c/div\u003e\n\nTurns out all the above use cases have one thing in common: they can all be expressed using learnable 
parametric curves, and that is exactly what this library provides.\n\n## Learn\nA simple \"hello world\" example: evaluate three two-dimensional B-spline curves at four points each:\n```python\nimport torch\nimport torchcurves as tc\n\nu = torch.rand(4, 3)        # (B, C)\ncurve = tc.BSplineCurve(\n    num_curves=3,           # C\n    dim=2,                  # D\n)\ny = curve(u)                # (B, C, D)\n\nprint(u.shape, \"-\u003e\", y.shape)            # torch.Size([4, 3]) -\u003e torch.Size([4, 3, 2])\n```\n\nIf the coefficients are produced by another network rather than stored as parameters of the module,\nuse `tc.BSplineBasis` and pass the coefficients explicitly to `forward`.\n\nFor more information:\n- [Documentation site](https://torchcurves.readthedocs.io/en/latest/).\n- [Example notebooks](https://torchcurves.readthedocs.io/en/latest/example_notebooks.html) for you to try out.\n\n## Features\n\n- **Differentiable**: a custom autograd function propagates gradients through curve evaluation.\n- **Vectorized**: many curves are evaluated at a whole mini-batch of points in a single batched call.\n- **Efficient numerics**: Clenshaw recursion for polynomials and Cox-de Boor recursion for B-splines.\n\n## Installation\nWith pip:\n```bash\npip install torchcurves\n```\nWith [uv](https://github.com/astral-sh/uv):\n```bash\nuv add torchcurves\n```\n\n## Use cases\n\nThere are examples in the `doc/source/examples` directory showing how to build models using\nthis library. 
Here we show a few short snippets that give a feel for the library.\n\n## Use case 1 - continuous embeddings\n\n```python\nimport torchcurves as tc\nfrom torch import nn\nimport torch\n\n\nclass Net(nn.Module):\n    def __init__(self, num_categorical, num_numerical, dim, num_knots=10):\n        super().__init__()\n        self.cat_emb = nn.Embedding(num_categorical, dim)  # num_categorical = vocabulary size\n        self.num_emb = tc.BSplineCurve(num_numerical, dim, knots_config=num_knots)  # one curve per numerical feature\n        self.embedding_based_model = MySuperDuperModel()  # placeholder for your encoder model\n\n    def forward(self, x_categorical, x_numerical):\n        embeddings = torch.cat([\n            self.cat_emb(x_categorical),  # (B, num categorical features, dim)\n            self.num_emb(x_numerical)     # (B, num_numerical, dim)\n        ], dim=-2)\n        return self.embedding_based_model(embeddings)\n```\n\n`MySuperDuperModel` is a placeholder for your downstream architecture.\n\n## Use case 2 - monotone functions\nWorking on online advertising and want to model the probability of winning an ad auction given the bid? A higher bid\nshould never lower the win probability, so we need a monotone function. It turns out that a B-spline is monotone whenever its coefficient sequence is monotone. Want an increasing function? 
Ensure the spline coefficients are increasing, and the resulting spline will be monotone increasing as well.\n\nBelow is an example with an auction encoder that encodes the auction into a vector; we then transform that vector into an increasing one\nand use it as the coefficient vector of a B-spline curve.\n```python\nimport torch\nfrom torch import nn\nimport torchcurves as tc\n\n\nclass AuctionWinModel(nn.Module):\n    def __init__(self, num_auction_features, num_bid_coefficients):\n        super().__init__()\n        self.auction_encoder = make_auction_encoder(  # placeholder: an MLP, a transformer, etc.\n            input_features=num_auction_features,\n            output_features=num_bid_coefficients,\n        )\n        self.bid_basis = tc.BSplineBasis(\n            degree=3,\n            knots_config=num_bid_coefficients,\n            input_map=tc.maps.Nonneg.rational(),\n        )\n\n    def forward(self, auction_features, bids):\n        # map auction features to increasing spline coefficients\n        spline_coeffs = self._make_increasing(self.auction_encoder(auction_features))\n\n        # each mini-batch sample is treated as its own curve\n        return self.bid_basis(\n            bids.unsqueeze(0),           # 1 x B (one point for each of the B curves)\n            spline_coeffs.unsqueeze(-1), # B x C x 1 (B curves, C coefficients each, 1-D values)\n        ).squeeze(0).squeeze(-1)         # 1 x B x 1 -\u003e B\n\n    def _make_increasing(self, x):\n        # transform a mini-batch of vectors to a mini-batch of increasing vectors:\n        # the first entry is free, softplus makes the remaining increments positive,\n        # and cumsum accumulates them\n        initial = x[..., :1]\n        increments = nn.functional.softplus(x[..., 1:])\n        concatenated = torch.concat((initial, increments), dim=-1)\n        return torch.cumsum(concatenated, dim=-1)\n```\n\n`make_auction_encoder` is a placeholder for your encoder architecture.\n\nNow we can train the model to predict the probability of winning an auction given the auction features and the bid. The model outputs logits; `model`, `optimizer`, and `train_loader` are assumed to be set up as usual:\n```python\nimport torch.nn.functional as F\n\nfor auction_features, bids, win_labels in 
train_loader:\n    win_logits = model(auction_features, bids)\n    loss = F.binary_cross_entropy_with_logits(  # or any loss we desire\n        win_logits,\n        win_labels.float()  # BCE targets must be floating point\n    )\n\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n```\n\n## Use case 3 - Kolmogorov-Arnold networks\n\nA KAN [1] based on the B-spline basis, along the lines of the original paper:\n\n```python\nimport torchcurves as tc\nfrom torch import nn\n\ninput_dim = 2\nintermediate_dim = 5\nnum_control_points = 10\n\nkan = nn.Sequential(\n    # layer 1: (B, input_dim) -\u003e (B, input_dim, intermediate_dim)\n    tc.BSplineCurve(input_dim, intermediate_dim, knots_config=num_control_points),\n    tc.Sum(dim=-2),  # sum over the inputs -\u003e (B, intermediate_dim)\n    # layer 2\n    tc.BSplineCurve(intermediate_dim, intermediate_dim, knots_config=num_control_points),\n    tc.Sum(dim=-2),\n    # layer 3\n    tc.BSplineCurve(intermediate_dim, 1, knots_config=num_control_points),\n    tc.Sum(dim=-2),  # -\u003e (B, 1)\n)\n```\nYes, we know the original KAN paper parametrized its activations differently,\nas a B-spline plus a SiLU residual term, but the whole point of this library is to show that KAN\nactivations can be parametrized in arbitrary ways.\n\nFor example, here is a KAN based on Legendre polynomials of degree 5:\n\n```python\nimport torchcurves as tc\nfrom torch import nn\n\ninput_dim = 2\nintermediate_dim = 5\ndegree = 5\n\nkan = nn.Sequential(\n    # layer 1\n    tc.LegendreCurve(input_dim, intermediate_dim, degree=degree),\n    tc.Sum(dim=-2),\n    # layer 2\n    tc.LegendreCurve(intermediate_dim, intermediate_dim, degree=degree),\n    tc.Sum(dim=-2),\n    # layer 3\n    tc.LegendreCurve(intermediate_dim, 1, degree=degree),\n    tc.Sum(dim=-2),\n)\n```\n\nSince KANs are the primary use case for the `tc.Sum()` layer, its `dim` argument defaults to `-2` and can be omitted; we pass it here for clarity.\n\n## Advanced features\n\nThe curves in this library are evaluated on compact parameter intervals, and the\n`input_map` argument maps raw inputs into that interval:\n\n- `LegendreCurve` always maps to `[-1, 1]`.\n- `BSplineBasis` 
and `BSplineCurve` map to their effective knot interval.\n- When `BSplineBasis` or `BSplineCurve` receives `knots_config` as an int, use\n  `parameter_range=(a, b)` to choose that interval explicitly.\n\n### Dotted preset strings\n\nUse dotted preset strings for the default built-in input maps:\n\n```python\ntc.BSplineCurve(num_curves, curve_dim, input_map=\"real.rational\")\ntc.BSplineCurve(num_curves, curve_dim, input_map=\"real.arctan\")\ntc.BSplineCurve(num_curves, curve_dim, input_map=\"real.clamp\")\ntc.BSplineBasis(knots_config=num_control_points, parameter_range=(0, 1), input_map=\"nonneg.rational\")\ntc.BSplineBasis(knots_config=num_control_points, parameter_range=(0, 1), input_map=\"nonneg.arctan\")\n```\n\n### Configured map objects\n\nUse `tc.maps` objects when you want a non-default scale `s`:\n\n```python\ntc.BSplineCurve(num_curves, curve_dim, input_map=tc.maps.Real.rational(scale=s))\ntc.BSplineCurve(num_curves, curve_dim, input_map=tc.maps.Real.arctan(scale=s))\ntc.BSplineCurve(num_curves, curve_dim, input_map=tc.maps.Real.clamp(scale=s))\ntc.BSplineBasis(knots_config=num_control_points, parameter_range=(0, 1), input_map=tc.maps.Nonneg.arctan(scale=s))\n```\n\nThe default rational map computes\n\n```math\nx \\to \\frac{x}{\\sqrt{s^2 + x^2}},\n```\n\nand is based on the paper\n\u003eWang, Z.Q. and Guo, B.Y., 2004. Modified Legendre rational spectral method for the whole line. Journal of Computational Mathematics, pp.457-474.\n\nThe arctan map computes\n\n```math\nx \\to \\frac{2}{\\pi} \\arctan(x / s).\n```\n\nThe `nonneg.arctan` map uses the same formula after clamping negative inputs to `0`,\nso `0` maps to the left boundary of the interval and large values approach the right boundary.\n\nThe clamp map clips `x / s` to the target interval.\n\n### Custom input maps\n\nProvide a callable with signature `f(x, out_min, out_max)`. 
Example:\n\n```python\nimport torch\nimport torchcurves as tc\n\ndef erf_map(scale: float = 1.0):\n    def input_map(x, out_min: float = -1, out_max: float = 1) -\u003e torch.Tensor:\n        mapped = torch.special.erf(x / scale)  # values in (-1, 1)\n        # affinely rescale (-1, 1) to (out_min, out_max)\n        return ((mapped + 1) * (out_max - out_min)) / 2 + out_min\n\n    return input_map\n\ntc.BSplineCurve(num_curves, curve_dim, input_map=erf_map(scale=s))\n```\n\n### Gradient checkpointing for Legendre curves\n\nFor large degrees, the backward pass can be memory-intensive. Use\n`checkpoint_segments` to trade compute for memory. Larger values create more\nsegments (lower memory, higher compute). Set to `None` to disable. Checkpointing\nis applied only when gradients are enabled.\n\n```python\n# Functional API\ntc.functional.legendre_curves(x, coeffs, checkpoint_segments=4)\n\n# Module API\ntc.LegendreCurve(num_curves, curve_dim, degree=degree, checkpoint_segments=4)\n```\n\n### Example: B-spline KAN with clamping\n\nA B-spline KAN that uses the clamp input map with the default scale $s=1$:\n\n```python\nimport torchcurves as tc\nfrom torch import nn\n\ninput_dim = 2\nintermediate_dim = 5\nnum_control_points = 10\n\nconfig = dict(knots_config=num_control_points, input_map=\"real.clamp\")\nspline_kan = nn.Sequential(\n    # layer 1\n    tc.BSplineCurve(input_dim, intermediate_dim, **config),\n    tc.Sum(),\n    # layer 2\n    tc.BSplineCurve(intermediate_dim, intermediate_dim, **config),\n    tc.Sum(),\n    # layer 3\n    tc.BSplineCurve(intermediate_dim, 1, **config),\n    tc.Sum(),\n)\n```\n\n### Example: Legendre KAN with clamping\n\n```python\nimport torchcurves as tc\nfrom torch import nn\n\ninput_dim = 2\nintermediate_dim = 5\ndegree = 5\n\nconfig = dict(degree=degree, input_map=\"real.clamp\")\nkan = nn.Sequential(\n    # layer 1\n    tc.LegendreCurve(input_dim, intermediate_dim, **config),\n    tc.Sum(),\n    # layer 2\n    tc.LegendreCurve(intermediate_dim, intermediate_dim, **config),\n    tc.Sum(),\n    # layer 3\n    tc.LegendreCurve(intermediate_dim, 1, 
**config),\n    tc.Sum(),\n)\n```\n\n## Development\n\n### Development Installation\n\nUsing [uv](https://github.com/astral-sh/uv) (recommended):\n\n```bash\n# Clone the repository\ngit clone https://github.com/alexshtf/torchcurves.git\ncd torchcurves\n\n# Create a virtual environment and install all dependency groups\nuv venv\nuv sync --all-groups\n```\n\n### Running Tests\n\n```bash\n# Run all tests\nuv run pytest\n\n# Run with coverage\nuv run pytest --cov=torchcurves\n\n# Run a specific test file\nuv run pytest tests/test_bspline.py -v\n```\n\n### Performance Benchmarks\n\nThis project includes opt-in performance benchmarks of the forward and backward passes, built on `pytest-benchmark`. They live in `benchmarks/`.\n\nRun benchmarks:\n\n```bash\n# Run all benchmarks\nuv run pytest benchmarks -q\n\n# Or select only perf-marked tests if you mix them into tests/\nuv run pytest -m perf -q\n```\n\nCUDA timing note: the benchmarks synchronize the GPU before and after each timed region, so the measured GPU times are accurate.\n\nCompare runs and fail CI on regressions:\n\n```bash\n# Save a baseline\nuv run pytest benchmarks --benchmark-save=legendre_baseline\n\n# Compare the current run to the baseline (fail if the mean is 10% or more slower)\nuv run pytest benchmarks --benchmark-compare --benchmark-compare-fail=mean:10%\n```\n\nExport results:\n\n```bash\nuv run pytest benchmarks --benchmark-json=bench.json\n```\n\n### Building the docs\n\n```bash\n# Build the HTML documentation\ncd doc\nmake html\n```\n\n## Citation\n\nIf you use this package in your research, please cite:\n\n```bibtex\n@software{torchcurves,\n  author = {Shtoff, Alex},\n  title = {torchcurves: Differentiable Parametric Curves in PyTorch},\n  year = {2025},\n  publisher = {GitHub},\n  url = {https://github.com/alexshtf/torchcurves}\n}\n```\n\n## Related software\n\nSeveral well-maintained PyTorch libraries use splines in practice. 
They mostly target *interpolation/resampling* or *geometric warping* rather than providing a generic, drop-in learnable parametric curve layer.\n\n### ND interpolation and resampling\n- **[torch-interpol](https://github.com/balbasty/torch-interpol)** (also on **[PyPI](https://pypi.org/project/torch-interpol/)**) implements high-order spline interpolation for **ND tensors** (e.g., 2D/3D images), with TorchScript acceleration and explicit forward/backward implementations. It is primarily designed for resampling workflows driven by sampling grids or deformation fields, and supports per-dimension interpolation orders and boundary handling (`bound`). *Best suited for resampling tensor data on fixed grids.*\n\n- **[xitorch – `Interp1D`](https://xitorch.readthedocs.io/en/latest/api/xitorch_interpolate/Interp1D.html)** (repo: **[xitorch/xitorch](https://github.com/xitorch/xitorch)**) provides differentiable **1D interpolation** including cubic splines (`method=\"cspline\"`) for non-uniform sample locations with configurable boundary conditions and extrapolation options. This is an interpolation primitive: you provide `(x, y)` samples and query at `xq`. *Designed as a functional primitive for data interpolation.*\n\n### Learnable continuous fields via grids\n- **[torch-cubic-spline-grids](https://github.com/alisterburt/torch-cubic-spline-grids)** (also on **[PyPI](https://pypi.org/project/torch-cubic-spline-grids/)**) provides learnable, continuous parametrisations of **1–4D spaces** using **uniform grids** whose coordinate system spans `[0, 1]` along each dimension. It supports both cubic **B-spline** grids (C2, not interpolating) and cubic **Catmull–Rom** grids (C1, interpolating), which are well suited to learning smooth spatial/temporal fields (e.g., deformation fields). 
*Targets dense continuous fields rather than curve trajectories.*\n\n### Thin-plate / polyharmonic spline warping\n- **[torch-tps](https://github.com/raphaelreme/torch-tps)** (also on **[PyPI](https://pypi.org/project/torch-tps/)**) implements generalized **polyharmonic spline** interpolation (thin-plate splines in 2D) for learning smooth mappings between Euclidean spaces from control point correspondences, with configurable spline order and regularization. *Specializes in spatial warping and point-set registration.*\n\n- **[Kornia](https://github.com/kornia/kornia)** includes TPS utilities such as `get_tps_transform` and `warp_image_tps` (see **[kornia.geometry.transform docs](https://kornia.readthedocs.io/en/latest/geometry.transform.html)**) as part of a larger differentiable computer vision and geometry toolkit, mainly targeting point/image warping operations. *Focuses on image geometry transforms.*\n\n## References\n\n[1]: Ziming Liu, Yixuan Wang, Sachin Vaidya, Fabian Ruehle, James Halverson, Marin Soljacic, Thomas Y. Hou, and Max Tegmark. \"KAN: Kolmogorov–Arnold Networks.\" *ICLR* (2025). \\\n[2]: Juergen Schmidhuber. \"Learning to control fast-weight memories: An alternative to dynamic recurrent networks.\" *Neural Computation* 4(1), pp. 131-139 (1992). \\\n[3]: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. \"Attention is all you need.\" *Advances in Neural Information Processing Systems* 30 (2017). \\\n[4]: Alex Shtoff, Elie Abboud, Rotem Stram, and Oren Somekh. \"Function Basis Encoding of Numerical Features in Factorization Machines.\" *Transactions on Machine Learning Research*. \\\n[5]: David Rügamer. \"Scalable Higher-Order Tensor Product Spline Models.\" In *International Conference on Artificial Intelligence and Statistics*, pp. 1-9. PMLR, 2024. \\\n[6]: Steffen Rendle. \"Factorization machines.\" In *2010 IEEE International Conference on Data Mining*, pp. 995-1000. 
IEEE, 2010.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Falexshtf%2Ftorchcurves","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Falexshtf%2Ftorchcurves","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Falexshtf%2Ftorchcurves/lists"}