Segmentación Semántica de Imágenes con Redes U-Net en PyTorch: Un Enfoque Paso a Paso
Este tutorial te guiará a través de la implementación de una red U-Net, una arquitectura popular para la segmentación semántica de imágenes. Aprenderás desde los fundamentos teóricos hasta la construcción del modelo en PyTorch y su entrenamiento con un conjunto de datos real. Descubre cómo delimitar píxel a píxel los objetos en una imagen.
La segmentación semántica es una tarea fundamental en visión artificial que consiste en clasificar cada píxel de una imagen en una categoría predefinida. A diferencia de la detección de objetos que dibuja cajas delimitadoras, la segmentación semántica proporciona una comprensión a nivel de píxel del contenido de la imagen, lo que es crucial en aplicaciones como la conducción autónoma, la medicina y la robótica.
En este tutorial, exploraremos una de las arquitecturas más influyentes para la segmentación semántica: la red U-Net. Desarrollada inicialmente para la segmentación de imágenes biomédicas, la U-Net ha demostrado ser increíblemente efectiva y adaptable a una amplia gama de problemas.
💡 ¿Qué es la Segmentación Semántica? Píxeles que Cobran Vida
Imagina una foto donde quieres identificar exactamente dónde está el coche, dónde está la carretera y dónde está el cielo. La segmentación semántica hace precisamente eso: asigna una etiqueta de clase a cada píxel de la imagen. El resultado es una "máscara de segmentación" donde cada color representa una clase diferente de objeto o región.
🎯 Aplicaciones Clave de la Segmentación Semántica
Las aplicaciones de esta tecnología son vastas y crecen rápidamente:
- Medicina: Detección de tumores, segmentación de órganos en imágenes de MRI o CT. 🩺
- Conducción Autónoma: Entender el entorno para identificar carreteras, peatones, señales de tráfico y otros vehículos. 🚗
- Robótica: Permitir a los robots interactuar con el entorno reconociendo objetos y superficies. 🤖
- Visión Remota y Agricultura: Monitorización de cultivos, detección de enfermedades en plantas, mapeo de uso del suelo. 🌾
- Edición de Imágenes: Eliminación de fondo, edición selectiva de regiones. 🖼️
🛠️ Entendiendo la Arquitectura U-Net: Codificador-Decodificador
La U-Net es una red neuronal convolucional (CNN) que se destaca por su arquitectura en forma de 'U', que combina un camino de contracción (codificador) y un camino de expansión (decodificador) con conexiones de salto (skip connections) entre ellos.
El Camino de Contracción (Codificador) 📉
Este camino es similar a una red neuronal convolucional típica (como VGG). Consiste en bloques repetidos de operaciones:
- Doble convolución: Dos capas convolucionales de 3x3 (seguidas de ReLU y Batch Normalization). Su propósito es extraer características del input.
- Pooling máximo: Una capa de pooling máximo de 2x2 con paso de 2. Esto reduce la resolución espacial de la imagen (downsampling) y aumenta el número de canales (características), capturando información contextual.
Cada paso de pooling reduce el tamaño espacial de la imagen a la mitad, mientras que duplica la cantidad de mapas de características. Esto permite a la red aprender características más abstractas y de mayor nivel.
El Camino de Expansión (Decodificador) 📈
Este camino toma las características de baja resolución del codificador y las reconstruye a una resolución más alta para generar la máscara de segmentación. Consiste en:
- Up-sampling: Una capa de convolución transpuesta (o deconvolución) de 2x2 o una interpolación bilineal seguida de una convolución normal. Esto duplica la resolución espacial de los mapas de características.
- Concatenación con Skip Connections: Aquí es donde la U-Net brilla. Los mapas de características del up-sampling se concatenan con los mapas de características correspondientes y de la misma resolución del camino de contracción. Estas "conexiones de salto" proporcionan al decodificador información de alta resolución que se perdió durante el downsampling, lo que es crucial para reconstruir detalles finos en la segmentación.
- Doble convolución: Al igual que en el codificador, dos capas convolucionales de 3x3 (ReLU, Batch Normalization) procesan las características combinadas.
Capa Final
Después del último bloque del decodificador, una capa convolucional de 1x1 reduce el número de canales al número de clases que queremos segmentar. Luego, se aplica una función de activación (como softmax si hay múltiples clases o sigmoid para dos clases) para obtener las probabilidades de clase para cada píxel.
⚙️ Preparando el Entorno: PyTorch y Dependencias
Para este tutorial, usaremos PyTorch, una librería de aprendizaje profundo muy flexible y potente. Necesitarás Python 3.7+ y los siguientes paquetes. Puedes instalarlos fácilmente con pip:
pip install torch torchvision opencv-python numpy matplotlib scikit-image
Verifica tu instalación de PyTorch:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
Si CUDA available devuelve True, ¡enhorabuena! Podrás aprovechar la GPU para un entrenamiento más rápido. Si es False, el entrenamiento se realizará en la CPU, lo cual es más lento pero funcional.
📖 El Conjunto de Datos: Preparando tus Imágenes
Para la segmentación semántica, necesitamos un dataset que contenga no solo imágenes de entrada, sino también sus correspondientes máscaras de segmentación (también conocidas como ground truth). Cada píxel en la máscara de segmentación debe tener un valor numérico que represente su clase.
Para este tutorial, simularemos un conjunto de datos simple o utilizaremos un subconjunto de un dataset público como Pascal VOC o Cityscapes. Para simplificar, asumiremos que tenemos un directorio data/images con las imágenes de entrada y un directorio data/masks con las máscaras correspondientes. Las máscaras serán imágenes en escala de grises donde cada valor de píxel representa una clase.
Por ejemplo, para 2 clases (fondo y objeto): 0 para fondo, 1 para objeto.
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CustomSegmentationDataset(Dataset):
def __init__(self, img_dir, mask_dir, transform=None):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.transform = transform
self.image_filenames = [f for f in os.listdir(img_dir) if f.endswith('.png') or f.endswith('.jpg')]
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
img_name = self.image_filenames[idx]
img_path = os.path.join(self.img_dir, img_name)
mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', '.png')) # Asume máscaras PNG con mismo nombre
image = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path).convert("L") # "L" para escala de grises de 0 a 255
# Convertir máscara a numpy array para asegurar valores discretos de clase
mask = np.array(mask)
mask[mask > 0] = 1 # Suponiendo que 0 es fondo y >0 es objeto (para 2 clases)
mask = Image.fromarray(mask.astype(np.uint8))
if self.transform:
image = self.transform(image)
mask = self.transform(mask) # Las transformaciones deben aplicarse a ambas
# La máscara debe ser un tensor de tipo Long para la función de pérdida (CrossEntropyLoss)
return image, mask.squeeze().long()
# Definir transformaciones
IMG_HEIGHT = 256
IMG_WIDTH = 256
transform = transforms.Compose([
transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalización ImageNet
])
mask_transform = transforms.Compose([
transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
transforms.ToTensor() # No normalizamos las máscaras
])
# Simulación de datos (crea directorios y archivos dummy si no los tienes)
# if not os.path.exists('data/images'): os.makedirs('data/images')
# if not os.path.exists('data/masks'): os.makedirs('data/masks')
# # Crea algunos archivos dummy para probar la carga
# for i in range(5):
# Image.new('RGB', (512, 512), color = (i*50, i*100, i*150)).save(f'data/images/img_{i}.png')
# Image.new('L', (512, 512), color = i % 2 * 255).save(f'data/masks/img_{i}.png') # Máscaras con 0 o 255
# Ejemplo de uso del DataLoader
dataset = CustomSegmentationDataset(
img_dir='data/images',
mask_dir='data/masks',
transform=transform # Aplicamos el mismo transform para imagen y máscara
)
# Ajustamos el dataset para aplicar el transform de máscara de forma separada si es necesario
# o aseguramos que el transform de imagen no afecte las etiquetas
# Una mejor práctica sería tener transforms separados para la imagen y la máscara
class SegmentationTransforms:
def __init__(self, img_size):
self.img_transform = transforms.Compose([
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.mask_transform = transforms.Compose([
transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST),
transforms.ToTensor()
])
def __call__(self, image, mask):
return self.img_transform(image), self.mask_transform(mask)
# Actualizar la clase del dataset para usar el transform combinado
class CustomSegmentationDatasetV2(Dataset):
def __init__(self, img_dir, mask_dir, transforms=None):
self.img_dir = img_dir
self.mask_dir = mask_dir
self.transforms = transforms
self.image_filenames = [f for f in os.listdir(img_dir) if f.endswith('.png') or f.endswith('.jpg')]
def __len__(self):
return len(self.image_filenames)
def __getitem__(self, idx):
img_name = self.image_filenames[idx]
img_path = os.path.join(self.img_dir, img_name)
mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', '.png'))
image = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path).convert("L")
# Convertir máscara a numpy array para asegurar valores discretos de clase
mask = np.array(mask)
# Asegúrate de que los valores de píxel de la máscara corresponden a tus IDs de clase
# Por ejemplo, si tienes 2 clases: 0 para fondo, 1 para objeto. Ajusta esto según tu dataset.
mask[mask > 0] = 1 # O mapea 255 a 1 si tus máscaras son binarias 0/255
mask = Image.fromarray(mask.astype(np.uint8))
if self.transforms:
image, mask = self.transforms(image, mask)
return image, mask.squeeze().long()
# Crear instancia de transforms
segmentation_transforms = SegmentationTransforms(img_size=(IMG_HEIGHT, IMG_WIDTH))
dataset = CustomSegmentationDatasetV2(
img_dir='data/images',
mask_dir='data/masks',
transforms=segmentation_transforms
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# Obtener un batch para verificar
images, masks = next(iter(dataloader))
print(f"Shape de imágenes: {images.shape}") # Esperado: [batch_size, channels, H, W]
print(f"Shape de máscaras: {masks.shape}") # Esperado: [batch_size, H, W] (Long tensor)
✨ Implementando la U-Net en PyTorch
Ahora, implementaremos la arquitectura U-Net usando torch.nn.
Bloques Convolucionales Básicos
Primero, definimos un bloque de doble convolución, que será reutilizado tanto en el codificador como en el decodificador.
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""Bloque de doble convolución con Batch Normalization y ReLU"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
El Camino de Contracción (Downsampling)
Cada paso del codificador incluirá un MaxPool2d seguido de un DoubleConv.
class Down(nn.Module):
"""Downsampling con MaxPool2d y luego DoubleConv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
El Camino de Expansión (Upsampling)
Cada paso del decodificador incluirá una convolución transpuesta para upsampling, seguida de la concatenación con las características del codificador y un DoubleConv.
class Up(nn.Module):
"""Upsampling con ConvTranspose2d, concatenación y luego DoubleConv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# Asegurar que los tamaños son compatibles para la concatenación
# Esto es importante si el upsampling no coincide exactamente debido a padding/stride
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1) # Concatenar a lo largo de la dimensión de canales
return self.conv(x)
La Arquitectura U-Net Completa
Finalmente, ensamblamos todos los bloques para formar la U-Net.
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024)
self.up1 = Up(1024, 512)
self.up2 = Up(512, 256)
self.up3 = Up(256, 128)
self.up4 = Up(128, 64)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
🏋️ Entrenamiento del Modelo
Ahora que tenemos el dataset y el modelo, es hora de entrenar la U-Net. Definiremos la función de pérdida, el optimizador y el bucle de entrenamiento.
import torch.optim as optim
# Hyperparámetros
num_epochs = 10
batch_size = 4
learning_rate = 0.001
num_classes = 2 # Fondo + Objeto
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {device}")
model = UNet(n_channels=3, n_classes=num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# CrossEntropyLoss es ideal para problemas de segmentación semántica (clasificación por píxel)
# Las etiquetas (máscaras) deben ser de tipo Long y no tener dimensión de canal (N, H, W)
criterion = nn.CrossEntropyLoss()
# Bucle de entrenamiento
for epoch in range(num_epochs):
model.train() # Poner el modelo en modo entrenamiento
running_loss = 0.0
for batch_idx, (images, masks) in enumerate(dataloader):
images = images.to(device)
masks = masks.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, masks)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if (batch_idx + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {running_loss/10:.4f}")
running_loss = 0.0
print(f"Epoch {epoch+1} completada. Pérdida promedio: {running_loss/len(dataloader):.4f}")
print("Entrenamiento completado.")
# Opcional: Guardar el modelo entrenado
torch.save(model.state_dict(), 'unet_segmentation_model.pth')
✅ Evaluación y Predicción
Después de entrenar el modelo, querrás evaluar su rendimiento y usarlo para hacer predicciones en nuevas imágenes.
Métricas de Evaluación
Para la segmentación, la métrica más común es el IoU (Intersection over Union) o Coeficiente de Jaccard, y la precisión por píxel.
# Función para calcular IoU
def iou_score(outputs, masks, num_classes):
iou = []
outputs = torch.argmax(outputs, dim=1) # Obtener la clase predicha para cada píxel
for cls in range(num_classes):
pred_mask = (outputs == cls).float()
true_mask = (masks == cls).float()
intersection = (pred_mask * true_mask).sum()
union = (pred_mask + true_mask).sum() - intersection
if union == 0: # Evitar división por cero si no hay píxeles de esta clase
iou.append(torch.tensor(1.0)) if intersection == 0 else iou.append(torch.tensor(0.0))
else:
iou.append(intersection / union)
return torch.mean(torch.stack(iou))
# Evaluación simple en el dataset de entrenamiento (idealmente usar un conjunto de validación/prueba)
model.eval() # Poner el modelo en modo evaluación
total_iou = 0
with torch.no_grad():
for images, masks in dataloader:
images = images.to(device)
masks = masks.to(device)
outputs = model(images)
total_iou += iou_score(outputs, masks, num_classes)
mean_iou = total_iou / len(dataloader)
print(f"IoU promedio en el dataset: {mean_iou:.4f}")
Haciendo Predicciones en Nuevas Imágenes
Para predecir en una imagen individual, necesitarás aplicar las mismas transformaciones que usaste durante el entrenamiento.
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
def visualize_segmentation(image_tensor, true_mask_tensor, pred_mask_tensor, num_classes):
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
# Convertir tensores a numpy para visualización
# Des-normalizar la imagen si se aplicó normalización
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
std=[1/0.229, 1/0.224, 1/0.225])
image_display = inv_normalize(image_tensor).permute(1, 2, 0).cpu().numpy()
# Asegurar que los valores estén en el rango [0, 1]
image_display = np.clip(image_display, 0, 1)
# Las máscaras ya están en el rango de clases (0, 1, ...)
true_mask_display = true_mask_tensor.cpu().numpy()
pred_mask_display = pred_mask_tensor.cpu().numpy()
# Mostrar imagen original
ax[0].imshow(image_display)
ax[0].set_title('Imagen Original')
ax[0].axis('off')
# Mostrar máscara real
ax[1].imshow(true_mask_display, cmap='viridis', vmin=0, vmax=num_classes-1) # Usa un colormap para clases
ax[1].set_title('Máscara Real')
ax[1].axis('off')
# Mostrar máscara predicha
ax[2].imshow(pred_mask_display, cmap='viridis', vmin=0, vmax=num_classes-1)
ax[2].set_title('Máscara Predicha')
ax[2].axis('off')
plt.show()
# Cargar el modelo entrenado (si guardaste el estado)
# model.load_state_dict(torch.load('unet_segmentation_model.pth'))
# model.to(device)
model.eval()
# Tomar un ejemplo del dataloader para predecir
images_batch, masks_batch = next(iter(dataloader))
with torch.no_grad():
images_batch = images_batch.to(device)
outputs_batch = model(images_batch)
predictions_batch = torch.argmax(outputs_batch, dim=1)
# Visualizar el primer ejemplo del batch
visualize_segmentation(
images_batch[0].cpu(),
masks_batch[0].cpu(),
predictions_batch[0].cpu(),
num_classes
)
# Si quieres predecir en una nueva imagen desde un archivo:
# def predict_single_image(image_path, model, transforms, device, num_classes):
# image = Image.open(image_path).convert("RGB")
# input_tensor, _ = transforms(image, Image.new('L', image.size)) # La máscara dummy no se usa
# input_batch = input_tensor.unsqueeze(0).to(device) # Añadir dimensión de batch
#
# with torch.no_grad():
# output = model(input_batch)
# prediction = torch.argmax(output, dim=1).squeeze(0) # Eliminar dimensión de batch
#
# return prediction.cpu()
# predict_single_image('path/to/your/new_image.jpg', model, segmentation_transforms, device, num_classes)
🚀 Optimizaciones y Consideraciones Avanzadas
Una vez que tengas una U-Net funcional, hay varias áreas donde puedes optimizar y mejorar el rendimiento:
- Aumento de Datos (Data Augmentation): Aplicar transformaciones como rotaciones, volteos, zoom, cambios de brillo/contraste a tus imágenes de entrenamiento ayuda a que el modelo sea más robusto y generalice mejor. Es crucial aplicar las mismas transformaciones tanto a la imagen como a su máscara correspondiente.
- Funciones de Pérdida Personalizadas: Además de
CrossEntropyLoss, puedes explorar otras funciones de pérdida específicas para segmentación, comoDice Loss,Focal Losso combinaciones de ellas. Estas pueden ser especialmente útiles en datasets con un gran desequilibrio de clases (por ejemplo, cuando los objetos de interés ocupan muy pocos píxeles). - Métricas Avanzadas: Para una evaluación más exhaustiva, considera otras métricas como la precisión media por clase, recall, F1-score, o curvas PR (Precision-Recall) para cada clase.
- Optimización del Optimizador: Experimenta con diferentes optimizadores (
SGDcon momentum,RMSprop) y programas de tasa de aprendizaje (learning rate schedulers) que ajustan la tasa de aprendizaje durante el entrenamiento. - Variantes de U-Net: Existen muchas variantes de la arquitectura U-Net, como
Attention U-Net,R2U-Net,3D U-Net(para volúmenes de datos),V-Net, etc. Cada una introduce modificaciones para abordar desafíos específicos. - Pre-entrenamiento: Utilizar un codificador pre-entrenado en un gran dataset de clasificación (como ImageNet) puede acelerar la convergencia y mejorar el rendimiento, especialmente con datasets pequeños. Sin embargo, en PyTorch, el codificador de la U-Net suele construirse desde cero, a diferencia de otras arquitecturas de segmentación que adaptan modelos pre-entrenados para la parte de codificación (ej.
DeepLabV3). - Entrenamiento Distribuido: Para datasets muy grandes y modelos complejos, considera el entrenamiento distribuido en múltiples GPUs o máquinas.
¿Por qué U-Net y no otras arquitecturas como FCN o DeepLab?
Las Fully Convolutional Networks (FCNs) fueron pioneras en la segmentación end-to-end. U-Net se basa en FCNs pero mejora la captura de detalles finos gracias a sus conexiones de salto que fusionan características de alta resolución del codificador con características de baja resolución del decodificador. DeepLabv3+ es otra arquitectura potente que utiliza convoluciones atrous y un decodificador mejorado, logrando resultados de vanguardia, especialmente en datasets complejos como Cityscapes. La elección depende del problema, los recursos y la necesidad de precisión versus complejidad computacional. U-Net es a menudo un excelente punto de partida por su simplicidad relativa y alto rendimiento.Conclusión ✨
Has llegado al final de este tutorial sobre la segmentación semántica con U-Net en PyTorch. Hemos cubierto los fundamentos teóricos, la implementación del modelo, la preparación del dataset y el bucle de entrenamiento básico. La U-Net sigue siendo una herramienta increíblemente poderosa y versátil en el campo de la visión artificial, especialmente cuando la precisión a nivel de píxel es crucial.
Esperamos que este tutorial te sirva como una base sólida para explorar y aplicar la segmentación semántica en tus propios proyectos. ¡El mundo de la visión artificial es vasto y emocionante, y la U-Net es solo una de las muchas maravillas que te esperan! ¡Feliz segmentación! 🚀
Tutoriales relacionados
Comentarios (0)
Aún no hay comentarios. ¡Sé el primero!