{"id":20962232,"url":"https://github.com/lartpang/mssim.pytorch","last_synced_at":"2026-03-09T19:40:57.660Z","repository":{"id":37639563,"uuid":"505836504","full_name":"lartpang/mssim.pytorch","owner":"lartpang","description":"A better pytorch-based implementation for the mean structural similarity. Differentiable simpler SSIM and MS-SSIM.","archived":false,"fork":false,"pushed_at":"2024-12-04T09:10:54.000Z","size":52,"stargazers_count":23,"open_issues_count":0,"forks_count":1,"subscribers_count":1,"default_branch":"main","last_synced_at":"2024-12-28T23:32:37.428Z","etag":null,"topics":["loss-function","loss-functions","ssim","ssim-loss","ssim-metric","ssim-metrics","ssim-pytorch","structure-similarity"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lartpang.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2022-06-21T12:36:51.000Z","updated_at":"2024-12-23T18:06:02.000Z","dependencies_parsed_at":"2022-07-17T08:46:30.419Z","dependency_job_id":null,"html_url":"https://github.com/lartpang/mssim.pytorch","commit_stats":null,"previous_names":[],"tags_count":2,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lartpang%2Fmssim.pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lartpang%2Fmssim.pytorch/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lartpang%2Fmssim.pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lartpang%2Fmssim.pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lartpang","download_url":"https://codeload.github.com/lartpang/mssim.pytorch/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":233536889,"owners_count":18690812,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["loss-function","loss-functions","ssim","ssim-loss","ssim-metric","ssim-metrics","ssim-pytorch","structure-similarity"],"created_at":"2024-11-19T02:25:06.631Z","updated_at":"2025-09-19T00:32:29.356Z","avatar_url":"https://github.com/lartpang.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# mssim.pytorch\n\n$$\n\\begin{align}\nl(\\mathbf{x}, \\mathbf{y}) \u0026 = \\frac{2\\mu_x\\mu_y+C_1}{\\mu_x^2+\\mu_y^2+C_1}, C_1=(K_1L)^2, K_1=0.01, \\\\\nc(\\mathbf{x}, \\mathbf{y}) \u0026 = \\frac{2\\sigma_{x}\\sigma_{y}+C_2}{\\sigma_x^2+\\sigma_y^2+C_2}, C_2=(K_2L)^2, K_2=0.03, \\\\\ns(\\mathbf{x}, \\mathbf{y}) \u0026 = \\frac{\\sigma_{xy}+C_3}{\\sigma_x\\sigma_y+C_3}, C_3=C_2/2,  \\\\\n\\text{SSIM}(\\mathbf{x}, \\mathbf{y}) \u0026 = [l(\\mathbf{x}, \\mathbf{y})]^\\alpha \\cdot [c(\\mathbf{x}, \\mathbf{y})]^\\beta \\cdot [s(\\mathbf{x}, \\mathbf{y})]^\\gamma \\\\\n\u0026 = \\frac{(2\\mu_x\\mu_y+C_1)(2\\sigma_{xy}+C_2)}{(\\mu_x^2+\\mu_y^2+C_1)(\\sigma_x^2+\\sigma_y^2+C_2)}, \\\\\n\u0026 \\alpha=\\beta=\\gamma=1, \\\\\n\\text{MS-SSIM}(\\mathbf{x}, \\mathbf{y}) \u0026 = [l_M(\\mathbf{x}, \\mathbf{y})]^{\\alpha_{M}} \\cdot \\prod^{M}_{j=1} [c_j(\\mathbf{x}, \\mathbf{y})]^{\\beta_j} \\cdot [s_j(\\mathbf{x}, \\mathbf{y})]^{\\gamma_j}, (M=5) \\\\\n\u0026 
\\beta_1=\\gamma_1=0.0448, \\\\\n\u0026 \\beta_2=\\gamma_2=0.2856, \\\\\n\u0026 \\beta_3=\\gamma_3=0.3001, \\\\\n\u0026 \\beta_4=\\gamma_4=0.2363, \\\\\n\u0026 \\alpha_5=\\beta_5=\\gamma_5=0.1333.\n\\end{align}\n$$\n\nHere, $L$ is the dynamic range of the pixel values (e.g., 1 for data in [0, 1] or 255 for 8-bit images), which corresponds to the `L` argument in the table below.\n\nA better PyTorch-based implementation of the mean structural similarity (MSSIM).\n\nCompared with the widely used implementation at \u003chttps://github.com/Po-Hsun-Su/pytorch-ssim\u003e, I have further optimized and refactored the code.\n\nThis implementation also resolves the inconsistency between results computed in fp16 mode and those computed in fp32 mode: typecasting ensures that the computation is always carried out in fp32. This may also help avoid unexpected results when SSIM is used as a loss.\n\n\u003e [!note]\n\u003e 2024-12-04: SSIM is now supported for 1D, 2D, and 3D data, and MS-SSIM for 2D and 3D data.\n\n| Argument        | SSIM1d         | SSIM2d                | SSIM3d                       | MS-SSIM2d             | MS-SSIM3d (**only pooling in the spatial domain**) |\n| --------------- | -------------- | --------------------- | ---------------------------- | --------------------- | -------------------------------------------------- |\n| data_dim        | 1              | 2 (Default)           | 3                            | 2                     | 3                                                  |\n| return_msssim   | `False`        | `False`               | `False`                      | `True`                | `True`                                             |\n| window_size     | int, [int]     | int, [int, int]       | int, [int, int, int]         | int, [int, int]       | int, [int, int, int]                               |\n| padding         | int, [int]     | int, [int, int]       | int, [int, int, int]         | int, [int, int]       | int, [int, int, int]                               |\n| sigma           | float, [float] | float, [float, float] | float, [float, float, float] | float, [float, float] | float, [float, float, float]                       |\n| in_channels     | int            | int                   | int                          | int                   | int                                                |\n| L               | 1, 255         | 1, 255                | 1, 255                       | 1, 255                | 1, 255                                             |\n| keep_batch_dim  | ✅              | ✅                     | ✅                            | ✅                     | ✅                                                  |\n| return_log      | ✅              | ✅                     | ✅                            | ❌                     | ❌                                                  |\n| ensemble_kernel | ✅              | ✅                     | ✅                            | ✅                     | ✅                                                  |\n\n## Structural similarity index\n\n\u003e When comparing images, the mean squared error (MSE)–while simple to implement–is not highly indicative of perceived similarity. Structural similarity aims to address this shortcoming by taking texture into account. 
More details are available at \u003chttps://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity\u003e\n\n![results](https://user-images.githubusercontent.com/26847524/175031400-92426661-4536-43c7-8f6e-5c470fb9ccb5.png)\n\n```python\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom lartpang_ssim import SSIM\nfrom po_hsun_su_ssim import SSIM as PoHsunSuSSIM\nfrom vainf_ssim import MS_SSIM as VainFMSSSIM\nfrom vainf_ssim import SSIM as VainFSSIM\nfrom skimage import data, img_as_float\n\nimg = img_as_float(data.camera())\nrows, cols = img.shape\n\nnoise = np.ones_like(img) * 0.3 * (img.max() - img.min())\nrng = np.random.default_rng()\nnoise[rng.random(size=noise.shape) \u003e 0.5] *= -1\n\nimg_noise = img + noise\nimg_const = np.zeros_like(img)\n\nimg_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()\nimg_noise_tensor = torch.from_numpy(img_noise).unsqueeze(0).unsqueeze(0).float()\nimg_const_tensor = torch.from_numpy(img_const).unsqueeze(0).unsqueeze(0).float()\n\nfig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 7))\nax = axes.ravel()\n\nmse_none = F.mse_loss(img_tensor, img_tensor, reduction=\"mean\")\nmse_noise = F.mse_loss(img_tensor, img_noise_tensor, reduction=\"mean\")\nmse_const = F.mse_loss(img_tensor, img_const_tensor, reduction=\"mean\")\n\n# https://github.com/VainF/pytorch-msssim\nvainf_ssim_none = VainFSSIM(channel=1, data_range=1)(img_tensor, img_tensor)\nvainf_ssim_noise = VainFSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor)\nvainf_ssim_const = VainFSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor)\nvainf_ms_ssim_none = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_tensor)\nvainf_ms_ssim_noise = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_noise_tensor)\nvainf_ms_ssim_const = VainFMSSSIM(channel=1, data_range=1)(img_tensor, img_const_tensor)\n\n# use the settings of https://github.com/VainF/pytorch-msssim\nssim_none_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor)\nssim_noise_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor)\nssim_const_0 = SSIM(return_msssim=False, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor)\nms_ssim_none_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_tensor)\nms_ssim_noise_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_noise_tensor)\nms_ssim_const_0 = SSIM(return_msssim=True, L=1, padding=0, ensemble_kernel=False)(img_tensor, img_const_tensor)\n\n# https://github.com/Po-Hsun-Su/pytorch-ssim\npohsunsu_ssim_none = PoHsunSuSSIM()(img_tensor, img_tensor)\npohsunsu_ssim_noise = PoHsunSuSSIM()(img_tensor, img_noise_tensor)\npohsunsu_ssim_const = PoHsunSuSSIM()(img_tensor, img_const_tensor)\n\n# use the settings of https://github.com/Po-Hsun-Su/pytorch-ssim\nssim_none_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_tensor)\nssim_noise_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_noise_tensor)\nssim_const_1 = SSIM(return_msssim=False, L=1, padding=None, ensemble_kernel=True)(img_tensor, img_const_tensor)\n\n\nax[0].imshow(img, cmap=plt.cm.gray, vmin=0, vmax=1)\nax[0].set_xlabel(\n    f\"MSE: {mse_none:.6f}\\n\"\n    f\"SSIM {ssim_none_0:.6f}, MS-SSIM {ms_ssim_none_0:.6f}\\n\"\n    f\"(VainF) SSIM: {vainf_ssim_none:.6f}, MS-SSIM {vainf_ms_ssim_none:.6f}\\n\"\n    f\"SSIM {ssim_none_1:.6f}\\n\"\n    f\"(PoHsunSu) SSIM: {pohsunsu_ssim_none:.6f}\\n\"\n)\nax[0].set_title(\"Original image\")\n\nax[1].imshow(img_noise, cmap=plt.cm.gray, vmin=0, vmax=1)\nax[1].set_xlabel(\n    f\"MSE: {mse_noise:.6f}\\n\"\n    f\"SSIM {ssim_noise_0:.6f}, MS-SSIM {ms_ssim_noise_0:.6f}\\n\"\n    f\"(VainF) SSIM: {vainf_ssim_noise:.6f}, MS-SSIM {vainf_ms_ssim_noise:.6f}\\n\"\n    f\"SSIM {ssim_noise_1:.6f}\\n\"\n    f\"(PoHsunSu) SSIM: {pohsunsu_ssim_noise:.6f}\\n\"\n)\nax[1].set_title(\"Image with noise\")\n\nax[2].imshow(img_const, cmap=plt.cm.gray, vmin=0, vmax=1)\nax[2].set_xlabel(\n    f\"MSE: {mse_const:.6f}\\n\"\n    f\"SSIM {ssim_const_0:.6f}, MS-SSIM {ms_ssim_const_0:.6f}\\n\"\n    f\"(VainF) SSIM: {vainf_ssim_const:.6f}, MS-SSIM {vainf_ms_ssim_const:.6f}\\n\"\n    f\"SSIM {ssim_const_1:.6f}\\n\"\n    f\"(PoHsunSu) SSIM: {pohsunsu_ssim_const:.6f}\\n\"\n)\nax[2].set_title(\"Constant (all-zero) image\")\n\n# hide ticks and tick labels on every panel\nfor a in ax:\n    a.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n\nplt.tight_layout()\nplt.savefig(\"results.png\")\n```\n\n## More Examples\n\n```python\nimport torch\nfrom pytorch_ssim import SSIM  # this repository's SSIM module; adjust the import path to your setup\n\n# 4D float tensors with the data range [0, 1] (L=1) and 1 channel; return the logarithmic form and keep the batch dim\nssim_caller = SSIM(return_log=True, keep_batch_dim=True).cuda()\n\n# two 4D tensors of shape (B, C, H, W)\nx = torch.randn(3, 1, 100, 100).cuda()\ny = torch.randn(3, 1, 100, 100).cuda()\nssim_score_0 = ssim_caller(x, y)\n# or under fp16 autocast: the computation is cast to float32 internally, so the result matches the fp32 one\nwith torch.autocast(device_type=\"cuda\"):\n    ssim_score_1 = ssim_caller(x, y)\nassert torch.allclose(ssim_score_0, ssim_score_1)\nprint(ssim_score_0.shape, ssim_score_1.shape)\n```\n\n## As A Loss\n\nAs the stopping thresholds of the two cases below show (0.999 versus -0.94), optimizing towards MSSIM=1 is easier than optimizing towards MSSIM=-1.\n\n### Optimize towards MSSIM=1\n\n![prediction](https://user-images.githubusercontent.com/26847524/174930091-9d7f7505-1752-423a-b7c3-d4dbfeb8d336.png)\n\n```python\nimport matplotlib.pyplot as plt\nimport torch\nfrom pytorch_ssim import SSIM\nfrom skimage import data\nfrom torch import optim\n\noriginal_image = data.moon() / 255\ntarget_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()\npredicted_image = torch.zeros_like(\n    
target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True\n)\ninitial_image = predicted_image.clone()\n\nssim = SSIM().cuda()\ninitial_ssim_value = ssim(predicted_image, target_image)\n\nssim_value = initial_ssim_value\noptimizer = optim.Adam([predicted_image], lr=0.01)\nloss_curves = []\nwhile ssim_value \u003c 0.999:\n    ssim_out = 1 - ssim(predicted_image, target_image)\n    loss_curves.append(ssim_out.item())\n    ssim_value = 1 - ssim_out.item()\n    print(ssim_value)\n    ssim_out.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n\nfig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))\nax = axes.ravel()\n\nax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)\nax[0].set_title(\"Original Image\")\n\nax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)\nax[1].set_xlabel(f\"SSIM: {initial_ssim_value:.5f}\")\nax[1].set_title(\"Initial Image\")\n\nax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)\nax[2].set_xlabel(f\"SSIM: {ssim_value:.5f}\")\nax[2].set_title(\"Predicted Image\")\n\nax[3].plot(loss_curves)\nax[3].set_title(\"SSIM Loss Curve\")\n\nax[4].set_title(\"Original Image\")\nax[4].hist(original_image.ravel(), bins=256)\nax[4].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[4].set_xlabel(\"Pixel Intensity\")\n\nax[5].set_title(\"Initial Image\")\nax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), bins=256)\nax[5].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[5].set_xlabel(\"Pixel Intensity\")\n\nax[6].set_title(\"Predicted Image\")\nax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)\nax[6].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[6].set_xlabel(\"Pixel Intensity\")\n\nplt.tight_layout()\nplt.savefig(\"prediction.png\")\n```\n\n### Optimize towards MSSIM=-1\n\n![prediction](https://user-images.githubusercontent.com/26847524/174929574-5332cab2-104f-4aab-a4e5-35e7635a793f.png)\n\n```python\nimport matplotlib.pyplot as plt\nimport torch\nfrom pytorch_ssim import SSIM\nfrom skimage import data\nfrom torch import optim\n\noriginal_image = data.moon() / 255\ntarget_image = torch.from_numpy(original_image).unsqueeze(0).unsqueeze(0).float().cuda()\npredicted_image = torch.zeros_like(\n    target_image, device=target_image.device, dtype=target_image.dtype, requires_grad=True\n)\ninitial_image = predicted_image.clone()\n\nssim = SSIM(L=original_image.max() - original_image.min()).cuda()\ninitial_ssim_value = ssim(predicted_image, target_image)\n\nssim_value = initial_ssim_value\noptimizer = optim.Adam([predicted_image], lr=0.01)\nloss_curves = []\nwhile ssim_value \u003e -0.94:\n    ssim_out = ssim(predicted_image, target_image)\n    loss_curves.append(ssim_out.item())\n    ssim_value = ssim_out.item()\n    print(ssim_value)\n    ssim_out.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n\nfig, axes = plt.subplots(nrows=2, ncols=4, figsize=(8, 4))\nax = axes.ravel()\n\nax[0].imshow(original_image, cmap=plt.cm.gray, vmin=0, vmax=1)\nax[0].set_title(\"Original Image\")\n\nax[1].imshow(initial_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)\nax[1].set_xlabel(f\"SSIM: {initial_ssim_value:.5f}\")\nax[1].set_title(\"Initial Image\")\n\nax[2].imshow(predicted_image.squeeze().detach().cpu().numpy(), cmap=plt.cm.gray, vmin=0, vmax=1)\nax[2].set_xlabel(f\"SSIM: {ssim_value:.5f}\")\nax[2].set_title(\"Predicted Image\")\n\nax[3].plot(loss_curves)\nax[3].set_title(\"SSIM Loss Curve\")\n\nax[4].set_title(\"Original Image\")\nax[4].hist(original_image.ravel(), bins=256)\nax[4].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[4].set_xlabel(\"Pixel Intensity\")\n\nax[5].set_title(\"Initial Image\")\nax[5].hist(initial_image.squeeze().detach().cpu().numpy().ravel(), 
bins=256)\nax[5].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[5].set_xlabel(\"Pixel Intensity\")\n\nax[6].set_title(\"Predicted Image\")\nax[6].hist(predicted_image.squeeze().detach().cpu().numpy().ravel(), bins=256)\nax[6].ticklabel_format(axis=\"y\", style=\"scientific\", scilimits=(0, 0))\nax[6].set_xlabel(\"Pixel Intensity\")\n\nplt.tight_layout()\nplt.savefig(\"prediction.png\")\n```\n\n## Reference\n\n* \u003chttps://github.com/Po-Hsun-Su/pytorch-ssim\u003e\n* \u003chttps://github.com/VainF/pytorch-msssim\u003e\n* \u003chttps://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html?highlight=structure+similarity\u003e\n* Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, “Image quality assessment: From error visibility to structural similarity,” IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004.\n\n## Cite\n\nIf you find this library useful, please cite it using the following BibTeX entry:\n\n```bibtex\n@online{mssim.pytorch,\n    author=\"lartpang\",\n    title=\"{A better pytorch-based implementation for the mean structural similarity. Differentiable simpler SSIM and MS-SSIM.}\",\n    url=\"https://github.com/lartpang/mssim.pytorch\",\n    note=\"(Jun 21, 2022)\",\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flartpang%2Fmssim.pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flartpang%2Fmssim.pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flartpang%2Fmssim.pytorch/lists"}