https://github.com/deepmodeling/jax-fem
Differentiable Finite Element Method with JAX
differentiable-programming finite-element-methods jax topology-optimization
- Host: GitHub
- URL: https://github.com/deepmodeling/jax-fem
- Owner: deepmodeling
- License: gpl-3.0
- Created: 2023-10-02T15:20:54.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-03-18T07:49:13.000Z (3 months ago)
- Last Synced: 2025-04-07T04:09:07.440Z (2 months ago)
- Topics: differentiable-programming, finite-element-methods, jax, topology-optimization
- Language: Python
- Homepage:
- Size: 79.1 MB
- Stars: 363
- Watchers: 12
- Forks: 58
- Open Issues: 31
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- deepmodeling-projects - GitHub (Others)
- awesome-jax - jax-fem - Differentiable Finite Element Method with JAX. (Libraries)
README
A GPU-accelerated differentiable finite element analysis package based on [JAX](https://github.com/google/jax). It used to be part of [JAX-AM](https://github.com/tianjuxue/jax-am), a suite of open-source Python packages for additive manufacturing (AM) research.
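To make "differentiable finite element analysis" concrete, here is a minimal sketch in plain JAX. It does **not** use the jax-fem API; for that, see the `inverse` and `topology_optimization` demos listed in the Tutorial table. It solves a 1D Poisson problem with linear elements and uses `jax.grad` to differentiate a scalar output with respect to the source amplitude, with no hand-derived sensitivity. The function names `solve_poisson_1d` and `objective`, the mesh size, the nodal-quadrature load vector, and the Gaussian source (modeled on the Poisson quick test) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def solve_poisson_1d(amplitude, num_elements=64):
    """Hypothetical helper: solve -u'' = amplitude * b(x) on (0, 1), u(0) = u(1) = 0.

    Uses linear (two-node) elements on a uniform mesh; for brevity the load
    vector is approximated with nodal quadrature (b_i * h).
    """
    h = 1.0 / num_elements
    x = jnp.linspace(0.0, 1.0, num_elements + 1)
    n = num_elements - 1  # number of interior (unknown) nodes

    # Assembled stiffness matrix of linear elements: (1/h) * tridiag(-1, 2, -1)
    K = (2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)) / h

    # Gaussian source centered at x = 0.5, scaled by `amplitude`
    b = amplitude * 10.0 * jnp.exp(-((x[1:-1] - 0.5) ** 2) / 0.02)
    F = b * h

    u_interior = jnp.linalg.solve(K, F)
    # Re-attach the homogeneous Dirichlet values at both ends
    return jnp.concatenate([jnp.zeros(1), u_interior, jnp.zeros(1)])


def objective(amplitude):
    """Hypothetical quantity of interest: mean nodal value of the solution."""
    return jnp.mean(solve_poisson_1d(amplitude))


# Reverse-mode AD differentiates through assembly and the linear solve
d_objective = jax.grad(objective)

if __name__ == "__main__":
    print("J(1.0)     =", float(objective(1.0)))
    print("dJ/da(1.0) =", float(d_objective(1.0)))
```

Because the solution is linear in the amplitude here, `d_objective(a)` should equal `objective(1.0)` for any `a`, which makes a convenient sanity check. The same idea, applied to nonlinear residuals and design variables such as element densities, underlies the inverse and design problems described below.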
## Finite Element Method (FEM)


FEM is a powerful tool for numerically solving partial differential equations. JAX-FEM supports the following features:
- 2D quadrilateral/triangle elements
- 3D hexahedron/tetrahedron elements
- First- and second-order elements
- Dirichlet/Neumann/Robin boundary conditions
- Linear and nonlinear analysis, including
  - Heat equation
  - Linear elasticity
  - Hyperelasticity
  - Plasticity (macro and crystal plasticity)
- Differentiable simulation for solving inverse/design problems __without__ deriving sensitivities by hand, e.g.,
  - Topology optimization
  - Optimal thermal control
- Integration with PETSc for solver choices

**Updates** (Dec 11, 2023):
- We now support multi-physics problems in the sense that multiple variables can be solved monolithically. For example, consider running `python -m applications.stokes.example`.
- The weak form is now defined through a volume integral and a surface integral. Body force, the "mass kernel", and the "Laplace kernel" are treated in a unified way through the volume integral, while Neumann and Robin boundary conditions are treated in a unified way through the surface integral.
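Schematically, for a scalar field $u$ with test function $v$, the residual described in the update above can be read as one volume integral plus one surface integral. This is a reading of the update, not a transcription of the library's internals. Here $\boldsymbol{\sigma}$ plays the role of the map returned by `get_tensor_map` (the "Laplace kernel"), $m$ that of a volume map such as the one returned by `get_mass_map` (which can also absorb a body force through its dependence on $\boldsymbol{x}$), and $t$ a surface map covering Neumann ($t$ independent of $u$) and Robin ($t$ depending on $u$) terms. Writing every term on the left-hand side and setting the sum to zero is an assumed sign convention, chosen to match the negative source returned by `get_mass_map` in the quick test:

$$
\int_\Omega \boldsymbol{\sigma}(\nabla u) \cdot \nabla v \, d\Omega
+ \int_\Omega m(u, \boldsymbol{x}) \, v \, d\Omega
+ \int_{\Gamma_N} t(u, \boldsymbol{x}) \, v \, d\Gamma = 0 \qquad \forall v, \; v = 0 \text{ on } \Gamma_D
$$

For the Poisson quick test, $\boldsymbol{\sigma}(\nabla u) = \nabla u$, $m(u, \boldsymbol{x}) = -b(\boldsymbol{x})$ with $b(\boldsymbol{x}) = 10 \exp\left(-\left[(x_0 - 0.5)^2 + (x_1 - 0.5)^2\right] / 0.02\right)$, and there is no surface term because all four edges carry Dirichlet conditions. The expression then reduces to $\int_\Omega \nabla u \cdot \nabla v \, d\Omega = \int_\Omega b \, v \, d\Omega$, the weak form of $-\Delta u = b$.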
Thermal profile in direct energy deposition.

Linear static analysis of a bracket.

Crystal plasticity: grain structure (left) and stress-xx (right).

Stokes flow: velocity (left) and pressure (right).

Topology optimization with differentiable simulation.

## Installation
Create a conda environment from the given [`environment.yml`](https://github.com/deepmodeling/jax-fem/blob/main/environment.yml) file and activate it:
```bash
conda env create -f environment.yml
conda activate jax-fem-env
```

Install JAX:
- See the JAX installation [instructions](https://github.com/jax-ml/jax?tab=readme-ov-file#installation). Depending on your hardware, you may install the CPU or GPU version of JAX. Both work, although the GPU version usually gives better performance.

Then there are two options to continue:
### Option 1
Clone the repository:
```bash
git clone https://github.com/deepmodeling/jax-fem.git
cd jax-fem
```

Then install the package locally:

```bash
pip install -e .
```

**Quick tests**: You can check `demos/` for a variety of FEM cases. For example, run the hyperelasticity demo with

```bash
python -m demos.hyperelasticity.example
```

Also, running

```bash
python -m tests.benchmarks
```

executes a set of test cases.
### Option 2
Install the package from the [PyPI release](https://pypi.org/project/jax-fem/) directly:
```bash
pip install jax-fem
```

**Quick tests**: You can create an `example.py` file and run it:

```bash
python example.py
```

where `example.py` contains:

```python
import jax
import jax.numpy as np
import os

from jax_fem.problem import Problem
from jax_fem.solver import solver
from jax_fem.utils import save_sol
from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh

# Poisson problem with a Gaussian source term centered at (0.5, 0.5)
class Poisson(Problem):
    def get_tensor_map(self):
        return lambda x: x

    def get_mass_map(self):
        def mass_map(u, x):
            val = -np.array([10*np.exp(-(np.power(x[0] - 0.5, 2) + np.power(x[1] - 0.5, 2)) / 0.02)])
            return val
        return mass_map

ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)
Lx, Ly = 1., 1.
meshio_mesh = rectangle_mesh(Nx=32, Ny=32, domain_x=Lx, domain_y=Ly)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])

def left(point):
    return np.isclose(point[0], 0., atol=1e-5)

def right(point):
    return np.isclose(point[0], Lx, atol=1e-5)

def bottom(point):
    return np.isclose(point[1], 0., atol=1e-5)

def top(point):
    return np.isclose(point[1], Ly, atol=1e-5)

def dirichlet_val(point):
    return 0.

# Homogeneous Dirichlet B.C. (u = 0) on all four edges, applied to solution component 0
location_fns = [left, right, bottom, top]
value_fns = [dirichlet_val]*4
vecs = [0]*4
dirichlet_bc_info = [location_fns, vecs, value_fns]

problem = Poisson(mesh=mesh, vec=1, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info)
sol = solver(problem)

data_dir = os.path.join(os.path.dirname(__file__), 'data')
vtk_path = os.path.join(data_dir, f'vtk/u.vtu')
save_sol(problem.fes[0], sol[0], vtk_path)
```

After running the code above and visualizing the result with [ParaView](https://www.paraview.org/), you should see the following solution.
Solution to Poisson's equation due to a source term.

## Tutorial
| Example | Highlight |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [poisson](https://github.com/deepmodeling/jax-fem/tree/main/demos/poisson) | $${\color{green}Basics:}$$ Poisson's equation in a unit square domain with Dirichlet and Neumann boundary conditions, as well as a source term. |
| [linear_elasticity](https://github.com/deepmodeling/jax-fem/tree/main/demos/linear_elasticity) | $${\color{green}Basics:}$$ Bending of a linear elastic beam due to Dirichlet and Neumann boundary conditions. The second-order tetrahedral element (TET10) is used. |
| [hyperelasticity](https://github.com/deepmodeling/jax-fem/tree/main/demos/hyperelasticity) | $${\color{blue}Nonlinear \space Constitutive \space Law:}$$ Deformation of a hyperelastic cube due to Dirichlet boundary conditions. |
| [plasticity](https://github.com/deepmodeling/jax-fem/tree/main/demos/plasticity) | $${\color{blue}Nonlinear \space Constitutive \space Law:}$$ A perfect J2-plasticity model is implemented under small-deformation theory. |
| [phase_field_fracture](https://github.com/deepmodeling/jax-fem/tree/main/demos/phase_field_fracture) | $${\color{orange}Multi-physics \space Coupling:}$$ A phase field fracture model is implemented. A staggered scheme is used for two-way coupling of the displacement and damage fields. Miehe's spectral decomposition model is implemented for a 3D case. |
| [thermal_mechanical](https://github.com/deepmodeling/jax-fem/tree/main/demos/thermal_mechanical) | $${\color{orange}Multi-physics \space Coupling:}$$ Thermal-mechanical modeling of the metal additive manufacturing process. One-way coupling is implemented (temperature affects displacement). |
| [thermal_mechanical_full](https://github.com/deepmodeling/jax-fem/tree/main/demos/thermal_mechanical_full) | $${\color{orange}Multi-physics \space Coupling:}$$ Thermal-mechanical modeling of a 2D plate. Two-way coupling (temperature and displacement) is implemented with a monolithic scheme. |
| [wave](https://github.com/deepmodeling/jax-fem/tree/main/demos/wave) | $${\color{lightblue}Time \space Dependent \space Problem:}$$ The scalar wave equation is solved with a backward difference scheme. |
| [topology_optimization](https://github.com/deepmodeling/jax-fem/tree/main/demos/topology_optimization) | $${\color{red}Inverse \space Problem:}$$ SIMP topology optimization for a 2D beam. Sensitivity analysis is performed automatically by the program rather than derived by hand. |
| [inverse](https://github.com/deepmodeling/jax-fem/tree/main/demos/inverse) | $${\color{red}Inverse \space Problem:}$$ Sanity check of how automatic differentiation works. |
| [plasticity_gradient](https://github.com/deepmodeling/jax-fem/tree/main/applications/plasticity_gradient) | $${\color{red}Inverse \space Problem:}$$ Automatic sensitivity analysis involving history variables such as plasticity. |

## License
This project is licensed under the GNU General Public License v3 - see the [LICENSE](https://www.gnu.org/licenses/) for details.
## Citations
If you found this library useful in academic or industry work, we would appreciate your support: please consider 1) starring the project on GitHub and 2) citing the relevant papers:
```bibtex
@article{xue2023jax,
title={JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science},
author={Xue, Tianju and Liao, Shuheng and Gan, Zhengtao and Park, Chanwook and Xie, Xiaoyu and Liu, Wing Kam and Cao, Jian},
journal={Computer Physics Communications},
pages={108802},
year={2023},
publisher={Elsevier}
}
```