Quick Start
1. Create segmentation model
A segmentation model is just a PyTorch nn.Module, which can be created as easily as:
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
Check the page with available model architectures.
Check the table with available ported encoders and their corresponding weights.
PyTorch Image Models (timm) encoders are also supported; check them here.
Alternatively, you can use the smp.create_model function to create a model by name:
model = smp.create_model(
    arch="fpn",                     # name of the architecture, e.g. 'Unet' / 'FPN' / etc. (case-insensitive)
    encoder_name="mit_b0",
    encoder_weights="imagenet",
    in_channels=1,
    classes=3,
)
2. Configure data preprocessing
All encoders have pretrained weights. Preparing your data the same way as during weight pre-training may give you better results (a higher metric score and faster convergence). However, this is relevant only for images with 1, 2, or 3 channels, and it is not necessary if you train the whole model rather than only the decoder.
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
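To make concrete what "preparing your data the same way" usually means for ImageNet-pretrained encoders, here is a minimal NumPy sketch of per-channel standardization. It is not the library's implementation: `normalize_like_imagenet` is a hypothetical helper, and the mean and standard deviation are the commonly used ImageNet statistics, assumed here to correspond to what `preprocess_input` applies for `imagenet` weights.

```python
import numpy as np

# Commonly used ImageNet channel statistics (assumption: these match the
# values applied for `imagenet` pretrained weights).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_like_imagenet(image):
    """Hypothetical helper: scale an HWC uint8 RGB image to [0, 1],
    then standardize each channel with the ImageNet mean and std."""
    x = image.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD


# A dummy 4x4 all-white RGB image in height-width-channel layout.
image = np.full((4, 4, 3), 255, dtype=np.uint8)
normalized = normalize_like_imagenet(image)

# PyTorch models expect channel-first layout, so move channels to the front.
chw = np.transpose(normalized, (2, 0, 1))
```

In practice you would call the returned `preprocess_input` on each image in your dataset or transform pipeline instead of writing the normalization by hand; the sketch only shows the kind of arithmetic involved.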
3. Congratulations!
You are done! Now you can train your model with your favorite framework, or with a loop as simple as:
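for images, gt_masks in dataloader:
    optimizer.zero_grad()                       # clear gradients left over from the previous step
    predicted_mask = model(images)              # forward pass
    loss = loss_fn(predicted_mask, gt_masks)
    loss.backward()                             # backpropagate
    optimizer.step()                            # update weights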
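The loop above leaves `dataloader`, `loss_fn`, and `optimizer` undefined. The sketch below fills them in with illustrative choices so the whole thing can be run end to end. Everything here is an assumption made for the example: a single `nn.Conv2d` stands in for the segmentation model (it has the same input/output layout as a model built with `in_channels=1, classes=3`), the data is random, and `CrossEntropyLoss` with Adam is just one reasonable pairing for multi-class masks.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the segmentation model (assumption): maps (N, 1, H, W) -> (N, 3, H, W).
model = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1)

# Synthetic data in place of a real dataloader (assumption).
images = torch.randn(4, 1, 32, 32)              # gray-scale batch
gt_masks = torch.randint(0, 3, (4, 32, 32))     # one class index per pixel
dataloader = [(images, gt_masks)] * 5

# Illustrative choices, not prescribed by the library.
loss_fn = nn.CrossEntropyLoss()                 # logits (N, C, H, W) vs. targets (N, H, W)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
model.train()
for images, gt_masks in dataloader:
    optimizer.zero_grad()                       # reset accumulated gradients
    predicted_mask = model(images)              # raw per-class logits
    loss = loss_fn(predicted_mask, gt_masks)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

With a real model, replace the stand-in with the object created in step 1 and the synthetic list with your own `DataLoader`; the loop body stays the same.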
Check the following examples:
Finetuning notebook on Oxford Pet dataset with PyTorch Lightning
Finetuning script for cloth segmentation with PyTorch Lightning