Building the U-Net architecture in PyTorch

Key takeaways:

  • U-Net is a pivotal architecture in biomedical image segmentation, designed to work well when annotated training data is limited.

  • The architecture of U-Net combines a contracting path for feature extraction and an expansive path for reconstructing segmentation masks.

  • The contracting path uses repeated convolutions and max-pooling to downsample the feature maps while increasing the number of feature channels, capturing increasingly complex patterns.

  • The expansive path employs transposed convolutions for upsampling and, through skip connections, integrates features from the contracting path to recover spatial detail and improve accuracy.

  • Key functions include cropping for dimension matching and repeated convolutions for feature refinement.

  • U-Net is widely used in medical imaging and has been extended to satellite imagery and industrial quality control, making it a major advance in image segmentation methods.

U-Net, introduced by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015, is a pivotal architecture in biomedical image segmentation. It was developed in response to the need for accurate segmentation of biomedical images, a setting where annotated training data is scarce and traditional methods struggled. U-Net revolutionized the field by providing a robust solution for semantic segmentation tasks, particularly in medical image analysis.

This architecture gained prominence because of its distinctive design, which pairs a contracting path with an expansive path. The contracting path, reminiscent of a typical convolutional neural network (CNN), extracts hierarchical features from input images. The expansive path then reconstructs the segmentation mask, combining these features with high-resolution information passed through skip connections to achieve precise pixel-wise classification. Together with extensive data augmentation, this design enables accurate segmentation even with limited training data, making U-Net a cornerstone of many medical imaging applications.

Figure: U-Net architecture

Implementation of the U-Net architecture

The U-Net architecture consists of a contracting path (left side) and an expansive path (right side). The contracting path employs repeated convolutions and max-pooling operations to extract features and reduce spatial dimensions. Conversely, the expansive path utilizes transposed convolutions for upsampling and reconstructing the segmentation map, incorporating skip connections to preserve spatial information.

Development of key functions

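The cropping function is essential for integrating feature maps from the contracting path into the expansive path. Because the 3x3 convolutions are unpadded, each one loses a one-pixel border, so every encoder feature map is larger than the upsampled decoder map it is paired with. The function center-crops the encoder map to the decoder map's spatial size so that the two can be concatenated along the channel dimension.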

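def cropping(input, target):
    # Center-crop `input` (N, C, H, W) to the spatial size of `target`
    # (square feature maps assumed, as throughout this article)
    size_in = input.shape[2]
    size_target = target.shape[2]
    margin = (size_in - size_target) // 2
    # margin:margin + size_target keeps exactly size_target pixels,
    # even when size_in - size_target is odd
    return input[:, :, margin:margin + size_target, margin:margin + size_target]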
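
The slice bounds used by the cropping function can be checked in plain Python. The minimal sketch below uses a hypothetical helper, center_crop_bounds, and applies it to the four skip connections of a 572x572 input, where the encoder maps (568, 280, 136, and 64 pixels) are paired with upsampled decoder maps (392, 200, 104, and 56 pixels). It writes the slice as margin:margin + size_target, which always keeps exactly size_target pixels; the form margin:size_in - margin gives the same result only when size_in - size_target is even, as it is for all four pairs here.

```python
def center_crop_bounds(size_in, size_target):
    """Return (start, stop) for a centered slice of length size_target.

    Hypothetical helper that mirrors the index arithmetic of center cropping.
    """
    if size_target > size_in:
        raise ValueError("target must not be larger than the input")
    margin = (size_in - size_target) // 2
    return margin, margin + size_target


# (encoder size, upsampled decoder size) pairs for a 572x572 input
skip_pairs = [(568, 392), (280, 200), (136, 104), (64, 56)]
for size_in, size_target in skip_pairs:
    start, stop = center_crop_bounds(size_in, size_target)
    print(f"{size_in} -> {size_target}: slice {start}:{stop}")

# With an odd difference, margin:size_in - margin keeps one extra pixel.
margin = (65 - 56) // 2
print(len(range(margin, 65 - margin)))          # 57, not 56
print(len(range(*center_crop_bounds(65, 56))))  # 56
```

For the 572x572 pairs, the margins come out as 88, 40, 16, and 4 pixels per side.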

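The repeated convolution function stacks two 3x3 convolutions, each followed by a ReLU activation. This double convolution is the basic building block used throughout the U-Net architecture. The convolutions use no padding, so each one shrinks the height and width of the feature map by 2 pixels.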

import torch.nn as nn

def repeated_conv(in_c, out_c):
    # Two unpadded 3x3 convolutions, each followed by a ReLU activation
    double_conv = nn.Sequential(
        nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return double_conv
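
Because these convolutions use the nn.Conv2d default padding=0, the size reduction can be worked out with the standard Conv2d output-size formula (for dilation 1): out = floor((n + 2 * padding - kernel_size) / stride) + 1. The sketch below uses hypothetical helper names and shows that one double convolution shrinks each spatial dimension by 4. The padding=1 case is included only for comparison; it is a common variant, not the one used in this article, that keeps the size unchanged and removes the need for cropping.

```python
def conv2d_out(n, kernel_size=3, stride=1, padding=0):
    """Spatial output size of nn.Conv2d along one dimension (dilation=1)."""
    return (n + 2 * padding - kernel_size) // stride + 1


def double_conv_out(n, padding=0):
    """Size after two stacked 3x3 convolutions, as in repeated_conv."""
    return conv2d_out(conv2d_out(n, padding=padding), padding=padding)


print(double_conv_out(572))             # 568: unpadded, as in this article
print(double_conv_out(572, padding=1))  # 572: padded variant keeps the size
```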

Encoder

The encoder, also known as the contracting path, plays a crucial role in extracting meaningful features from the input image. Let’s break down how each component works:

  • Repeated convolutional layers: In the encoder, each stage applies two 3x3 convolutions, each followed by a Rectified Linear Unit (ReLU) activation function. Because the convolutions are unpadded, each one trims one pixel from every border of the feature map. This combination helps capture important patterns and structures in the input data.

  • Max pooling for downsampling: After each of the first four double-convolution blocks, we apply a 2x2 max-pooling operation with a stride of 2. This halves the height and width of the feature maps while retaining the strongest activations. It also enlarges the receptive field of the subsequent layers, enabling the network to capture a larger context.

  • Increasing feature channels: As we go deeper into the encoder, the number of feature channels doubles at each downsampling step (64, 128, 256, 512, and finally 1024). The additional channels increase the network's capacity to represent intricate details of the input image.
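
These three rules determine every encoder shape. The plain-Python sketch below (the helper name encoder_trace is hypothetical) traces the (channels, size) pair after each double convolution for a 572x572 input and raises an error if a max-pooling step would receive an odd size, since floor division would then silently drop a row and a column. The original U-Net paper similarly recommends choosing the input tile size so that every 2x2 max-pooling operation sees an even size.

```python
def encoder_trace(size, base_channels=64, depth=5):
    """Return [(channels, size), ...] after each double convolution.

    Raises ValueError if the input is too small or if a 2x2 max-pooling
    step would receive an odd spatial size.
    """
    trace = []
    channels = base_channels
    for level in range(depth):
        size -= 4  # two unpadded 3x3 convolutions
        if size <= 0:
            raise ValueError("input too small for this depth")
        trace.append((channels, size))
        if level < depth - 1:  # no pooling after the last block
            if size % 2:
                raise ValueError(f"odd size {size} before pooling")
            size //= 2     # 2x2 max pooling with stride 2
            channels *= 2  # the next block doubles the channels
    return trace


for channels, size in encoder_trace(572):
    print(channels, size)
```

The final entry, (1024, 28), corresponds to the encoded shape (1, 1024, 28, 28) produced by the encoder code below.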

import torch
import torch.nn as nn


class U_NET(nn.Module):
    def __init__(self):
        super(U_NET, self).__init__()
        self.DC1 = self.repeated_conv(in_c=1, out_c=64)
        self.DC2 = self.repeated_conv(in_c=64, out_c=128)
        self.DC3 = self.repeated_conv(in_c=128, out_c=256)
        self.DC4 = self.repeated_conv(in_c=256, out_c=512)
        self.DC5 = self.repeated_conv(in_c=512, out_c=1024)
        self.MP_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def repeated_conv(self, in_c, out_c):
        double_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3),
            nn.ReLU(inplace=True)
        )
        return double_conv

    def forward(self, x):
        # Encoder / Contraction(Left-Path)
        x1 = self.DC1(x)
        x2 = self.MP_2x2(x1)
        x3 = self.DC2(x2)
        x4 = self.MP_2x2(x3)
        x5 = self.DC3(x4)
        x6 = self.MP_2x2(x5)
        x7 = self.DC4(x6)
        x8 = self.MP_2x2(x7)
        x9 = self.DC5(x8)
        return x9


input = torch.rand(1, 1, 572, 572)
model = U_NET()
output = model(input)
print(f'Input shape: {input.shape}')
print(f'Output shape: {output.shape}')

Code explanation

  • __init__: The U_NET class creates DC1 to DC5, the repeated convolutional blocks of the contracting path, and MP_2x2, the max-pooling layer used for downsampling. At this stage, only the encoder of the U-Net is defined.

  • repeated_conv: This method builds two 3x3 convolutions, each followed by a ReLU activation. Every DC block is created with it.

  • forward: The input tensor passes through the contracting path, alternating double convolutions with max-pooling, and the method returns the encoded representation x9.

  • Test code: A random input tensor of shape (1, 1, 572, 572) is passed through the model, and the shapes of the input and the output are printed. The encoded output has shape torch.Size([1, 1024, 28, 28]).

Decoder

The decoder, also known as the expansive path, plays a critical role in reconstructing the segmentation mask from the encoded feature representation. Let’s delve into its functionality:

  • Upsampling with transposed convolution: The decoder starts by performing upsampling using transposed convolutional layers (UTC1, UTC2, UTC3, UTC4). Each uses a 2x2 kernel with a stride of 2, which doubles the height and width of the feature maps and halves the number of channels, allowing the network to reconstruct the segmentation mask with finer details.

  • Cropping and concatenation: After each upsampling step, we center-crop the corresponding feature map from the contracting path (x7, x5, x3, x1) so that its height and width match those of the upsampled map. Cropping is needed because the unpadded convolutions lose border pixels, which leaves each encoder feature map larger than its decoder counterpart. We then concatenate the cropped encoder map with the upsampled map along the channel dimension. This merges the fine-grained spatial detail preserved in the contracting path with the high-level semantic information carried by the expansive path, which enables precise localization.

  • Repeated convolutional layers: Following concatenation, the network passes the concatenated feature map through repeated convolutional layers (UC1, UC2, UC3, UC4). These layers refine the merged information, extracting intricate patterns and fine-grained details crucial for accurate segmentation.

  • Output layer: Finally, a 1x1 convolutional layer (out) maps the 64-channel feature vector at each pixel to one score per class (two classes here, as in the original paper). Taking the argmax over the class dimension of this score map yields the final segmentation mask.
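
The decoder shapes follow from similar arithmetic. For nn.ConvTranspose2d with padding=0, dilation=1, and output_padding=0, the output size is (n - 1) * stride + kernel_size, so a 2x2 kernel with stride 2 maps n to 2n. The plain-Python sketch below (hypothetical helper names) starts from the 28x28 encoder output of a 572x572 input, uses the skip sizes 64, 136, 280, and 568 for x7, x5, x3, and x1, and records the channels, crop margin, and spatial size after each decoder stage.

```python
def conv_transpose2d_out(n, kernel_size=2, stride=2):
    """Spatial output size of nn.ConvTranspose2d with padding=0, dilation=1,
    output_padding=0: (n - 1) * stride + kernel_size."""
    return (n - 1) * stride + kernel_size


def decoder_trace(bottleneck_size, skip_sizes, channels=1024, num_classes=2):
    """Trace the expansive path; skip_sizes lists the x7, x5, x3, x1 sizes."""
    size = bottleneck_size
    stages = []
    for skip in skip_sizes:
        size = conv_transpose2d_out(size)  # UTCk doubles the size...
        channels //= 2                     # ...and halves the channels
        margin = (skip - size) // 2        # center-crop offset for the skip map
        size -= 4                          # UCk: two unpadded 3x3 convolutions
        stages.append((channels, margin, size))
    # The 1x1 output convolution keeps the size and maps channels to classes.
    return stages, (num_classes, size, size)


stages, out_shape = decoder_trace(28, skip_sizes=[64, 136, 280, 568])
for channels, margin, size in stages:
    print(f"channels {channels}, crop margin {margin}, size {size}")
print("output (classes, H, W):", out_shape)
```

The resulting 388x388 output for a 572x572 input matches the tile sizes reported in the original U-Net paper.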

import torch
import torch.nn as nn

class U_NET(nn.Module):
    def __init__(self):
        super(U_NET, self).__init__()

        self.DC1 = self.repeated_conv(in_c=1, out_c=64)
        self.DC2 = self.repeated_conv(in_c=64, out_c=128)
        self.DC3 = self.repeated_conv(in_c=128, out_c=256)
        self.DC4 = self.repeated_conv(in_c=256, out_c=512)
        self.DC5 = self.repeated_conv(in_c=512, out_c=1024)
        
        self.MP_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.UTC1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.UC1 = self.repeated_conv(in_c=1024, out_c=512)
        self.UTC2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.UC2 = self.repeated_conv(in_c=512, out_c=256)
        self.UTC3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.UC3 = self.repeated_conv(in_c=256, out_c=128)
        self.UTC4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.UC4 = self.repeated_conv(in_c=128, out_c=64)

        self.out = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    def repeated_conv(self, in_c, out_c):
        double_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3),
            nn.ReLU(inplace=True)
        )
        return double_conv

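    def cropping(self, input, target):
        # Center-crop `input` to the spatial size of `target` (square maps assumed)
        size_in = input.shape[2]
        size_target = target.shape[2]
        margin = (size_in - size_target) // 2
        # margin:margin + size_target keeps exactly size_target pixels,
        # even when size_in - size_target is odd
        return input[:, :, margin:margin + size_target, margin:margin + size_target]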
        

    def forward(self, x):
        # Encoder / Contraction(Left-Path)
        x1 = self.DC1(x)
        x2 = self.MP_2x2(x1)
        x3 = self.DC2(x2)
        x4 = self.MP_2x2(x3)
        x5 = self.DC3(x4)
        x6 = self.MP_2x2(x5)
        x7 = self.DC4(x6)
        x8 = self.MP_2x2(x7)
        x9 = self.DC5(x8)
        print("Encoded input shape: ",x9.shape)


        # Decoder / Expansion(Right-Path)
        x = self.UTC1(x9)
        x_cropped = self.cropping(x7, x)
        x = self.UC1(torch.cat([x_cropped, x],1))

        x = self.UTC2(x)
        x_cropped = self.cropping(x5, x)
        x = self.UC2(torch.cat([x_cropped, x],1))
        
        x = self.UTC3(x)
        x_cropped = self.cropping(x3, x)
        x = self.UC3(torch.cat([x_cropped, x],1))

        x = self.UTC4(x)
        x_cropped = self.cropping(x1, x)
        x = self.UC4(torch.cat([x_cropped, x],1))
        
        x = self.out(x)
        
        return x


input = torch.rand(1, 1, 572, 572)
print(f"Input shape: {input.shape}")
model = U_NET()
output = model(input)
print(f"Final output shape: {output.shape}")
The complete U-Net architecture implementation

Code explanation

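  • DC1 to DC5: The double-convolution blocks of the contracting path, with 64, 128, 256, 512, and 1024 output channels, used for feature extraction.

  • MP_2x2: The 2x2 max-pooling layer (stride 2) that halves the spatial size between encoder blocks.

  • UTC1 to UTC4: The transposed convolutional layers responsible for upsampling in the decoder; each doubles the spatial size and halves the number of channels.

  • UC1 to UC4: The decoder's double-convolution blocks. Each takes the concatenation of a cropped encoder feature map and an upsampled map, which is why its input has twice as many channels as its output (for example, 512 + 512 = 1024 for UC1).

  • out: The 1x1 convolution that maps the 64 feature channels to 2 class scores per pixel.

  • cropping and forward: cropping center-crops each skip connection to the size of the upsampled map; forward runs the encoder, prints the encoded shape torch.Size([1, 1024, 28, 28]), and then runs the decoder. For the 572x572 test input, the final output shape is torch.Size([1, 2, 388, 388]).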
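
As a sanity check on these layer definitions, the parameter count can be worked out by hand: an nn.Conv2d or nn.ConvTranspose2d layer with a bias has in_c * out_c * k * k + out_c parameters. The plain-Python sketch below (hypothetical helper names) applies this to every layer of the U_NET class above, assuming 1 input channel and 2 output classes as in that code.

```python
def conv_params(in_c, out_c, k):
    """Weights plus bias of a Conv2d / ConvTranspose2d layer with bias=True."""
    return in_c * out_c * k * k + out_c


def repeated_conv_params(in_c, out_c):
    """Two 3x3 convolutions, as built by repeated_conv."""
    return conv_params(in_c, out_c, 3) + conv_params(out_c, out_c, 3)


# Contracting path: DC1 to DC5
encoder = sum(repeated_conv_params(i, o) for i, o in
              [(1, 64), (64, 128), (128, 256), (256, 512), (512, 1024)])

# Expansive path: UTC1/UC1 to UTC4/UC4, then the 1x1 output layer
decoder = 0
for c in [1024, 512, 256, 128]:
    decoder += conv_params(c, c // 2, 2)        # UTCk (2x2 transposed conv)
    decoder += repeated_conv_params(c, c // 2)  # UCk on the concatenated input
decoder += conv_params(64, 2, 1)                # out (1x1 conv, 2 classes)

print(f"encoder: {encoder:,}  decoder: {decoder:,}  total: {encoder + decoder:,}")
# encoder: 18,842,048  decoder: 12,188,610  total: 31,030,658
```

With the model instantiated, sum(p.numel() for p in model.parameters()) should report the same total of 31,030,658.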

Closing insights

The U-Net architecture has revolutionized image segmentation, particularly in biomedical imaging, with a design that pairs a contracting path for feature extraction with an expansive path for precise localization. Its ability to segment accurately even with limited training data makes it invaluable in medical imaging applications, from tumor detection to organ segmentation. Beyond the medical field, U-Net is also applied to satellite image analysis, industrial quality control, and autonomous driving, and its design continues to shape the development of image segmentation methods in computer vision and machine learning.

Copyright ©2025 Educative, Inc. All rights reserved