{"id":43021681,"url":"https://github.com/usuyama/pytorch-unet","last_synced_at":"2026-01-31T06:37:14.040Z","repository":{"id":55836986,"uuid":"138813086","full_name":"usuyama/pytorch-unet","owner":"usuyama","description":"Simple PyTorch implementations of U-Net/FullyConvNet (FCN) for image segmentation","archived":false,"fork":false,"pushed_at":"2020-08-21T20:31:23.000Z","size":365,"stargazers_count":775,"open_issues_count":8,"forks_count":228,"subscribers_count":10,"default_branch":"master","last_synced_at":"2023-11-07T13:15:59.042Z","etag":null,"topics":["fully-convolutional-networks","image-segmentation","semantic-segmentation","unet"],"latest_commit_sha":null,"homepage":"https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/usuyama.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2018-06-27T01:24:14.000Z","updated_at":"2023-11-07T09:29:07.000Z","dependencies_parsed_at":"2022-08-15T07:40:35.298Z","dependency_job_id":null,"html_url":"https://github.com/usuyama/pytorch-unet","commit_stats":null,"previous_names":[],"tags_count":0,"template":null,"template_full_name":null,"purl":"pkg:github/usuyama/pytorch-unet","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/usuyama%2Fpytorch-unet","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/usuyama%2Fpytorch-unet/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/usuyama%2Fpytorch-unet/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/usuy
ama%2Fpytorch-unet/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/usuyama","download_url":"https://codeload.github.com/usuyama/pytorch-unet/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/usuyama%2Fpytorch-unet/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":28931363,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-01-31T04:05:25.756Z","status":"ssl_error","status_checked_at":"2026-01-31T04:02:35.005Z","response_time":128,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["fully-convolutional-networks","image-segmentation","semantic-segmentation","unet"],"created_at":"2026-01-31T06:37:13.403Z","updated_at":"2026-01-31T06:37:14.032Z","avatar_url":"https://github.com/usuyama.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"\n# UNet/FCN PyTorch [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/usuyama/pytorch-unet/blob/master/pytorch_unet_resnet18_colab.ipynb) \n\nThis repository contains simple PyTorch implementations of U-Net and FCN, which are deep learning segmentation methods proposed by Ronneberger et al. 
and Long et al.\n\n- [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)\n- [Fully Convolutional Networks for Semantic Segmentation](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf)\n\n# Synthetic images/masks for training\n\nFirst clone the repository and cd into the project directory.\n\n```python\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport helper\nimport simulation\n\n# Generate some random images\ninput_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n\nfor x in [input_images, target_masks]:\n    print(x.shape)\n    print(x.min(), x.max())\n\n# Change channel-order and make 3 channels for matplot\ninput_images_rgb = [x.astype(np.uint8) for x in input_images]\n\n# Map each channel (i.e. class) to each color\ntarget_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]\n\n# Left: Input image (black and white), Right: Target mask (6ch)\nhelper.plot_side_by_side([input_images_rgb, target_masks_rgb])\n```\n\n## Left: Input image (black and white), Right: Target mask (6ch)\n![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_0_1.png)\n\n\n## Prepare Dataset and DataLoader\n```python\nfrom torch.utils.data import Dataset, DataLoader\nfrom torchvision import transforms, datasets, models\n\nclass SimDataset(Dataset):\n    def __init__(self, count, transform=None):\n        self.input_images, self.target_masks = simulation.generate_random_data(192, 192, count=count)\n        self.transform = transform\n\n    def __len__(self):\n        return len(self.input_images)\n\n    def __getitem__(self, idx):\n        image = self.input_images[idx]\n        mask = self.target_masks[idx]\n        if self.transform:\n            image = self.transform(image)\n\n        return [image, mask]\n\n# use the same transformations for train/val in this example\ntrans = transforms.Compose([\n    
transforms.ToTensor(),\n    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # imagenet\n])\n\ntrain_set = SimDataset(2000, transform = trans)\nval_set = SimDataset(200, transform = trans)\n\nimage_datasets = {\n    'train': train_set, 'val': val_set\n}\n\nbatch_size = 25\n\ndataloaders = {\n    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),\n    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)\n}\n```\n\n## Check the outputs from DataLoader\n```python\nimport torchvision.utils\n\ndef reverse_transform(inp):\n    inp = inp.numpy().transpose((1, 2, 0))\n    mean = np.array([0.485, 0.456, 0.406])\n    std = np.array([0.229, 0.224, 0.225])\n    inp = std * inp + mean\n    inp = np.clip(inp, 0, 1)\n    inp = (inp * 255).astype(np.uint8)\n\n    return inp\n\n# Get a batch of training data\ninputs, masks = next(iter(dataloaders['train']))\n\nprint(inputs.shape, masks.shape)\n\nplt.imshow(reverse_transform(inputs[3]))\n```\n\n    torch.Size([25, 3, 192, 192]) torch.Size([25, 6, 192, 192])\n\n\n![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_2_2.png)\n\n\n\n# Create the UNet module\n\n```python\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\n\ndef convrelu(in_channels, out_channels, kernel, padding):\n    return nn.Sequential(\n        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),\n        nn.ReLU(inplace=True),\n    )\n\n\nclass ResNetUNet(nn.Module):\n    def __init__(self, n_class):\n        super().__init__()\n\n        self.base_model = models.resnet18(pretrained=True)\n        self.base_layers = list(self.base_model.children())\n\n        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)\n        self.layer0_1x1 = convrelu(64, 64, 1, 0)\n        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)\n        self.layer1_1x1 = convrelu(64, 64, 1, 0)\n       
 self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)\n        self.layer2_1x1 = convrelu(128, 128, 1, 0)\n        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)\n        self.layer3_1x1 = convrelu(256, 256, 1, 0)\n        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)\n        self.layer4_1x1 = convrelu(512, 512, 1, 0)\n\n        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)\n\n        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)\n        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)\n        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)\n        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)\n\n        self.conv_original_size0 = convrelu(3, 64, 3, 1)\n        self.conv_original_size1 = convrelu(64, 64, 3, 1)\n        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)\n\n        self.conv_last = nn.Conv2d(64, n_class, 1)\n\n    def forward(self, input):\n        x_original = self.conv_original_size0(input)\n        x_original = self.conv_original_size1(x_original)\n\n        layer0 = self.layer0(input)\n        layer1 = self.layer1(layer0)\n        layer2 = self.layer2(layer1)\n        layer3 = self.layer3(layer2)\n        layer4 = self.layer4(layer3)\n\n        layer4 = self.layer4_1x1(layer4)\n        x = self.upsample(layer4)\n        layer3 = self.layer3_1x1(layer3)\n        x = torch.cat([x, layer3], dim=1)\n        x = self.conv_up3(x)\n\n        x = self.upsample(x)\n        layer2 = self.layer2_1x1(layer2)\n        x = torch.cat([x, layer2], dim=1)\n        x = self.conv_up2(x)\n\n        x = self.upsample(x)\n        layer1 = self.layer1_1x1(layer1)\n        x = torch.cat([x, layer1], dim=1)\n        x = self.conv_up1(x)\n\n        x = self.upsample(x)\n        layer0 = self.layer0_1x1(layer0)\n        x = torch.cat([x, layer0], dim=1)\n        x = self.conv_up0(x)\n\n        x = self.upsample(x)\n        x = torch.cat([x, x_original], dim=1)\n        x = 
self.conv_original_size2(x)\n\n        out = self.conv_last(x)\n\n        return out\n```\n\n## Model summary\n```python\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = ResNetUNet(n_class=6)\nmodel = model.to(device)\n\n# check keras-like model summary using torchsummary\nfrom torchsummary import summary\nsummary(model, input_size=(3, 224, 224))\n```\n\n    ----------------------------------------------------------------\n            Layer (type)               Output Shape         Param #\n    ================================================================\n                Conv2d-1         [-1, 64, 224, 224]           1,792\n                  ReLU-2         [-1, 64, 224, 224]               0\n                Conv2d-3         [-1, 64, 224, 224]          36,928\n                  ReLU-4         [-1, 64, 224, 224]               0\n                Conv2d-5         [-1, 64, 112, 112]           9,408\n           BatchNorm2d-6         [-1, 64, 112, 112]             128\n                  ReLU-7         [-1, 64, 112, 112]               0\n             MaxPool2d-8           [-1, 64, 56, 56]               0\n                Conv2d-9           [-1, 64, 56, 56]           4,096\n          BatchNorm2d-10           [-1, 64, 56, 56]             128\n                 ReLU-11           [-1, 64, 56, 56]               0\n               Conv2d-12           [-1, 64, 56, 56]          36,864\n          BatchNorm2d-13           [-1, 64, 56, 56]             128\n                 ReLU-14           [-1, 64, 56, 56]               0\n               Conv2d-15          [-1, 256, 56, 56]          16,384\n          BatchNorm2d-16          [-1, 256, 56, 56]             512\n               Conv2d-17          [-1, 256, 56, 56]          16,384\n          BatchNorm2d-18          [-1, 256, 56, 56]             512\n                 ReLU-19          [-1, 256, 56, 56]               0\n           Bottleneck-20          [-1, 256, 56, 56]               0\n               
Conv2d-21           [-1, 64, 56, 56]          16,384\n          BatchNorm2d-22           [-1, 64, 56, 56]             128\n                 ReLU-23           [-1, 64, 56, 56]               0\n               Conv2d-24           [-1, 64, 56, 56]          36,864\n          BatchNorm2d-25           [-1, 64, 56, 56]             128\n                 ReLU-26           [-1, 64, 56, 56]               0\n               Conv2d-27          [-1, 256, 56, 56]          16,384\n          BatchNorm2d-28          [-1, 256, 56, 56]             512\n                 ReLU-29          [-1, 256, 56, 56]               0\n           Bottleneck-30          [-1, 256, 56, 56]               0\n               Conv2d-31           [-1, 64, 56, 56]          16,384\n          BatchNorm2d-32           [-1, 64, 56, 56]             128\n                 ReLU-33           [-1, 64, 56, 56]               0\n               Conv2d-34           [-1, 64, 56, 56]          36,864\n          BatchNorm2d-35           [-1, 64, 56, 56]             128\n                 ReLU-36           [-1, 64, 56, 56]               0\n               Conv2d-37          [-1, 256, 56, 56]          16,384\n          BatchNorm2d-38          [-1, 256, 56, 56]             512\n                 ReLU-39          [-1, 256, 56, 56]               0\n           Bottleneck-40          [-1, 256, 56, 56]               0\n               Conv2d-41          [-1, 128, 56, 56]          32,768\n          BatchNorm2d-42          [-1, 128, 56, 56]             256\n                 ReLU-43          [-1, 128, 56, 56]               0\n               Conv2d-44          [-1, 128, 28, 28]         147,456\n          BatchNorm2d-45          [-1, 128, 28, 28]             256\n                 ReLU-46          [-1, 128, 28, 28]               0\n               Conv2d-47          [-1, 512, 28, 28]          65,536\n          BatchNorm2d-48          [-1, 512, 28, 28]           1,024\n               Conv2d-49          [-1, 512, 28, 28]         131,072\n          
BatchNorm2d-50          [-1, 512, 28, 28]           1,024\n                 ReLU-51          [-1, 512, 28, 28]               0\n           Bottleneck-52          [-1, 512, 28, 28]               0\n               Conv2d-53          [-1, 128, 28, 28]          65,536\n          BatchNorm2d-54          [-1, 128, 28, 28]             256\n                 ReLU-55          [-1, 128, 28, 28]               0\n               Conv2d-56          [-1, 128, 28, 28]         147,456\n          BatchNorm2d-57          [-1, 128, 28, 28]             256\n                 ReLU-58          [-1, 128, 28, 28]               0\n               Conv2d-59          [-1, 512, 28, 28]          65,536\n          BatchNorm2d-60          [-1, 512, 28, 28]           1,024\n                 ReLU-61          [-1, 512, 28, 28]               0\n           Bottleneck-62          [-1, 512, 28, 28]               0\n               Conv2d-63          [-1, 128, 28, 28]          65,536\n          BatchNorm2d-64          [-1, 128, 28, 28]             256\n                 ReLU-65          [-1, 128, 28, 28]               0\n               Conv2d-66          [-1, 128, 28, 28]         147,456\n          BatchNorm2d-67          [-1, 128, 28, 28]             256\n                 ReLU-68          [-1, 128, 28, 28]               0\n               Conv2d-69          [-1, 512, 28, 28]          65,536\n          BatchNorm2d-70          [-1, 512, 28, 28]           1,024\n                 ReLU-71          [-1, 512, 28, 28]               0\n           Bottleneck-72          [-1, 512, 28, 28]               0\n               Conv2d-73          [-1, 128, 28, 28]          65,536\n          BatchNorm2d-74          [-1, 128, 28, 28]             256\n                 ReLU-75          [-1, 128, 28, 28]               0\n               Conv2d-76          [-1, 128, 28, 28]         147,456\n          BatchNorm2d-77          [-1, 128, 28, 28]             256\n                 ReLU-78          [-1, 128, 28, 28]               0\n         
      Conv2d-79          [-1, 512, 28, 28]          65,536\n          BatchNorm2d-80          [-1, 512, 28, 28]           1,024\n                 ReLU-81          [-1, 512, 28, 28]               0\n           Bottleneck-82          [-1, 512, 28, 28]               0\n               Conv2d-83          [-1, 256, 28, 28]         131,072\n          BatchNorm2d-84          [-1, 256, 28, 28]             512\n                 ReLU-85          [-1, 256, 28, 28]               0\n               Conv2d-86          [-1, 256, 14, 14]         589,824\n          BatchNorm2d-87          [-1, 256, 14, 14]             512\n                 ReLU-88          [-1, 256, 14, 14]               0\n               Conv2d-89         [-1, 1024, 14, 14]         262,144\n          BatchNorm2d-90         [-1, 1024, 14, 14]           2,048\n               Conv2d-91         [-1, 1024, 14, 14]         524,288\n          BatchNorm2d-92         [-1, 1024, 14, 14]           2,048\n                 ReLU-93         [-1, 1024, 14, 14]               0\n           Bottleneck-94         [-1, 1024, 14, 14]               0\n               Conv2d-95          [-1, 256, 14, 14]         262,144\n          BatchNorm2d-96          [-1, 256, 14, 14]             512\n                 ReLU-97          [-1, 256, 14, 14]               0\n               Conv2d-98          [-1, 256, 14, 14]         589,824\n          BatchNorm2d-99          [-1, 256, 14, 14]             512\n                ReLU-100          [-1, 256, 14, 14]               0\n              Conv2d-101         [-1, 1024, 14, 14]         262,144\n         BatchNorm2d-102         [-1, 1024, 14, 14]           2,048\n                ReLU-103         [-1, 1024, 14, 14]               0\n          Bottleneck-104         [-1, 1024, 14, 14]               0\n              Conv2d-105          [-1, 256, 14, 14]         262,144\n         BatchNorm2d-106          [-1, 256, 14, 14]             512\n                ReLU-107          [-1, 256, 14, 14]               0\n        
      Conv2d-108          [-1, 256, 14, 14]         589,824\n         BatchNorm2d-109          [-1, 256, 14, 14]             512\n                ReLU-110          [-1, 256, 14, 14]               0\n              Conv2d-111         [-1, 1024, 14, 14]         262,144\n         BatchNorm2d-112         [-1, 1024, 14, 14]           2,048\n                ReLU-113         [-1, 1024, 14, 14]               0\n          Bottleneck-114         [-1, 1024, 14, 14]               0\n              Conv2d-115          [-1, 256, 14, 14]         262,144\n         BatchNorm2d-116          [-1, 256, 14, 14]             512\n                ReLU-117          [-1, 256, 14, 14]               0\n              Conv2d-118          [-1, 256, 14, 14]         589,824\n         BatchNorm2d-119          [-1, 256, 14, 14]             512\n                ReLU-120          [-1, 256, 14, 14]               0\n              Conv2d-121         [-1, 1024, 14, 14]         262,144\n         BatchNorm2d-122         [-1, 1024, 14, 14]           2,048\n                ReLU-123         [-1, 1024, 14, 14]               0\n          Bottleneck-124         [-1, 1024, 14, 14]               0\n              Conv2d-125          [-1, 256, 14, 14]         262,144\n         BatchNorm2d-126          [-1, 256, 14, 14]             512\n                ReLU-127          [-1, 256, 14, 14]               0\n              Conv2d-128          [-1, 256, 14, 14]         589,824\n         BatchNorm2d-129          [-1, 256, 14, 14]             512\n                ReLU-130          [-1, 256, 14, 14]               0\n              Conv2d-131         [-1, 1024, 14, 14]         262,144\n         BatchNorm2d-132         [-1, 1024, 14, 14]           2,048\n                ReLU-133         [-1, 1024, 14, 14]               0\n          Bottleneck-134         [-1, 1024, 14, 14]               0\n              Conv2d-135          [-1, 256, 14, 14]         262,144\n         BatchNorm2d-136          [-1, 256, 14, 14]             512\n       
         ReLU-137          [-1, 256, 14, 14]               0\n              Conv2d-138          [-1, 256, 14, 14]         589,824\n         BatchNorm2d-139          [-1, 256, 14, 14]             512\n                ReLU-140          [-1, 256, 14, 14]               0\n              Conv2d-141         [-1, 1024, 14, 14]         262,144\n         BatchNorm2d-142         [-1, 1024, 14, 14]           2,048\n                ReLU-143         [-1, 1024, 14, 14]               0\n          Bottleneck-144         [-1, 1024, 14, 14]               0\n              Conv2d-145          [-1, 512, 14, 14]         524,288\n         BatchNorm2d-146          [-1, 512, 14, 14]           1,024\n                ReLU-147          [-1, 512, 14, 14]               0\n              Conv2d-148            [-1, 512, 7, 7]       2,359,296\n         BatchNorm2d-149            [-1, 512, 7, 7]           1,024\n                ReLU-150            [-1, 512, 7, 7]               0\n              Conv2d-151           [-1, 2048, 7, 7]       1,048,576\n         BatchNorm2d-152           [-1, 2048, 7, 7]           4,096\n              Conv2d-153           [-1, 2048, 7, 7]       2,097,152\n         BatchNorm2d-154           [-1, 2048, 7, 7]           4,096\n                ReLU-155           [-1, 2048, 7, 7]               0\n          Bottleneck-156           [-1, 2048, 7, 7]               0\n              Conv2d-157            [-1, 512, 7, 7]       1,048,576\n         BatchNorm2d-158            [-1, 512, 7, 7]           1,024\n                ReLU-159            [-1, 512, 7, 7]               0\n              Conv2d-160            [-1, 512, 7, 7]       2,359,296\n         BatchNorm2d-161            [-1, 512, 7, 7]           1,024\n                ReLU-162            [-1, 512, 7, 7]               0\n              Conv2d-163           [-1, 2048, 7, 7]       1,048,576\n         BatchNorm2d-164           [-1, 2048, 7, 7]           4,096\n                ReLU-165           [-1, 2048, 7, 7]               0\n      
    Bottleneck-166           [-1, 2048, 7, 7]               0\n              Conv2d-167            [-1, 512, 7, 7]       1,048,576\n         BatchNorm2d-168            [-1, 512, 7, 7]           1,024\n                ReLU-169            [-1, 512, 7, 7]               0\n              Conv2d-170            [-1, 512, 7, 7]       2,359,296\n         BatchNorm2d-171            [-1, 512, 7, 7]           1,024\n                ReLU-172            [-1, 512, 7, 7]               0\n              Conv2d-173           [-1, 2048, 7, 7]       1,048,576\n         BatchNorm2d-174           [-1, 2048, 7, 7]           4,096\n                ReLU-175           [-1, 2048, 7, 7]               0\n          Bottleneck-176           [-1, 2048, 7, 7]               0\n              Conv2d-177           [-1, 1024, 7, 7]       2,098,176\n                ReLU-178           [-1, 1024, 7, 7]               0\n            Upsample-179         [-1, 1024, 14, 14]               0\n              Conv2d-180          [-1, 512, 14, 14]         524,800\n                ReLU-181          [-1, 512, 14, 14]               0\n              Conv2d-182          [-1, 512, 14, 14]       7,078,400\n                ReLU-183          [-1, 512, 14, 14]               0\n            Upsample-184          [-1, 512, 28, 28]               0\n              Conv2d-185          [-1, 512, 28, 28]         262,656\n                ReLU-186          [-1, 512, 28, 28]               0\n              Conv2d-187          [-1, 512, 28, 28]       4,719,104\n                ReLU-188          [-1, 512, 28, 28]               0\n            Upsample-189          [-1, 512, 56, 56]               0\n              Conv2d-190          [-1, 256, 56, 56]          65,792\n                ReLU-191          [-1, 256, 56, 56]               0\n              Conv2d-192          [-1, 256, 56, 56]       1,769,728\n                ReLU-193          [-1, 256, 56, 56]               0\n            Upsample-194        [-1, 256, 112, 112]               0\n     
         Conv2d-195         [-1, 64, 112, 112]           4,160\n                ReLU-196         [-1, 64, 112, 112]               0\n              Conv2d-197        [-1, 128, 112, 112]         368,768\n                ReLU-198        [-1, 128, 112, 112]               0\n            Upsample-199        [-1, 128, 224, 224]               0\n              Conv2d-200         [-1, 64, 224, 224]         110,656\n                ReLU-201         [-1, 64, 224, 224]               0\n              Conv2d-202          [-1, 6, 224, 224]             390\n    ================================================================\n    Total params: 40,549,382\n    Trainable params: 40,549,382\n    Non-trainable params: 0\n    ----------------------------------------------------------------\n\n\n# Define the main training loop\n\n```python\nfrom collections import defaultdict\nimport torch.nn.functional as F\nfrom loss import dice_loss\n\ndef calc_loss(pred, target, metrics, bce_weight=0.5):\n    bce = F.binary_cross_entropy_with_logits(pred, target)\n\n    pred = F.sigmoid(pred)\n    dice = dice_loss(pred, target)\n\n    loss = bce * bce_weight + dice * (1 - bce_weight)\n\n    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)\n    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)\n    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)\n\n    return loss\n\ndef print_metrics(metrics, epoch_samples, phase):\n    outputs = []\n    for k in metrics.keys():\n        outputs.append(\"{}: {:4f}\".format(k, metrics[k] / epoch_samples))\n\n    print(\"{}: {}\".format(phase, \", \".join(outputs)))\n\ndef train_model(model, optimizer, scheduler, num_epochs=25):\n    best_model_wts = copy.deepcopy(model.state_dict())\n    best_loss = 1e10\n\n    for epoch in range(num_epochs):\n        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n        print('-' * 10)\n\n        since = time.time()\n\n        # Each epoch has a training and validation phase\n        for phase in 
['train', 'val']:\n            if phase == 'train':\n                for param_group in optimizer.param_groups:\n                    print(\"LR\", param_group['lr'])\n\n                model.train()  # Set model to training mode\n            else:\n                model.eval()   # Set model to evaluate mode\n\n            metrics = defaultdict(float)\n            epoch_samples = 0\n\n            for inputs, labels in dataloaders[phase]:\n                inputs = inputs.to(device)\n                labels = labels.to(device)\n\n                # zero the parameter gradients\n                optimizer.zero_grad()\n\n                # forward\n                # track gradients only in the training phase\n                with torch.set_grad_enabled(phase == 'train'):\n                    outputs = model(inputs)\n                    loss = calc_loss(outputs, labels, metrics)\n\n                    # backward + optimize only if in training phase\n                    if phase == 'train':\n                        loss.backward()\n                        optimizer.step()\n\n                # statistics\n                epoch_samples += inputs.size(0)\n\n            print_metrics(metrics, epoch_samples, phase)\n            epoch_loss = metrics['loss'] / epoch_samples\n\n            # step the LR scheduler once per epoch, after the optimizer updates\n            if phase == 'train':\n                scheduler.step()\n\n            # keep a copy of the weights with the best validation loss\n            if phase == 'val' and epoch_loss \u003c best_loss:\n                print(\"saving best model\")\n                best_loss = epoch_loss\n                best_model_wts = copy.deepcopy(model.state_dict())\n\n        time_elapsed = time.time() - since\n        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n\n    print('Best val loss: {:4f}'.format(best_loss))\n\n    # load best model weights\n    model.load_state_dict(best_model_wts)\n    return model\n```\n\n## Training\n```python\nimport torch\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nimport time\nimport copy\n\ndevice = 
torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nprint(device)\n\nnum_class = 6\nmodel = ResNetUNet(num_class).to(device)\n\n# freeze backbone layers\n#for l in model.base_layers:\n#    for param in l.parameters():\n#        param.requires_grad = False\n\noptimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)\n\nexp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)\n\nmodel = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=60)\n```\n\n    cuda:0\n    Epoch 0/59\n    ----------\n    LR 0.0001\n    train: bce: 0.070256, dice: 0.856320, loss: 0.463288\n    val: bce: 0.014897, dice: 0.515814, loss: 0.265356\n    saving best model\n    0m 51s\n    Epoch 1/59\n    ----------\n    LR 0.0001\n    train: bce: 0.011369, dice: 0.309445, loss: 0.160407\n    val: bce: 0.003790, dice: 0.113682, loss: 0.058736\n    saving best model\n    0m 51s\n    Epoch 2/59\n    ----------\n    LR 0.0001\n    train: bce: 0.003480, dice: 0.089928, loss: 0.046704\n    val: bce: 0.002525, dice: 0.067604, loss: 0.035064\n    saving best model\n    0m 51s\n\n    (Omitted)\n\n    Epoch 57/59\n    ----------\n    LR 1e-05\n    train: bce: 0.000523, dice: 0.010289, loss: 0.005406\n    val: bce: 0.001558, dice: 0.030965, loss: 0.016261\n    0m 51s\n    Epoch 58/59\n    ----------\n    LR 1e-05\n    train: bce: 0.000518, dice: 0.010209, loss: 0.005364\n    val: bce: 0.001548, dice: 0.031034, loss: 0.016291\n    0m 51s\n    Epoch 59/59\n    ----------\n    LR 1e-05\n    train: bce: 0.000518, dice: 0.010168, loss: 0.005343\n    val: bce: 0.001566, dice: 0.030785, loss: 0.016176\n    0m 50s\n    Best val loss: 0.016171\n\n\n## Use the trained model\n\n```python\nimport math\n\nmodel.eval()   # Set model to the evaluation mode\n\n# Create another simulation dataset for test\ntest_dataset = SimDataset(3, transform = trans)\ntest_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)\n\n# Get the 
first batch\ninputs, labels = next(iter(test_loader))\ninputs = inputs.to(device)\nlabels = labels.to(device)\n\n# Predict\npred = model(inputs)\n# The model outputs raw logits (the sigmoid is applied inside the loss), so apply it here.\npred = torch.sigmoid(pred)\npred = pred.detach().cpu().numpy()\nprint(pred.shape)\n\n# Undo the normalization and convert to HWC uint8 images for matplotlib\ninput_images_rgb = [reverse_transform(x) for x in inputs.cpu()]\n\n# Map each channel (i.e. class) to each color\ntarget_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]\npred_rgb = [helper.masks_to_colorimg(x) for x in pred]\n\nhelper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])\n```\n\n    (3, 6, 192, 192)\n\n### Left: Input image, Middle: Correct mask (Ground-truth), Right: Predicted mask\n\n![png](https://raw.githubusercontent.com/usuyama/pytorch-unet/master/images/output_9_1.png)\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fusuyama%2Fpytorch-unet","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fusuyama%2Fpytorch-unet","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fusuyama%2Fpytorch-unet/lists"}