| import torch |
| import torch.nn as nn |
| import numpy as np |
| from pathlib import Path |
| import pytorch_lightning as pl |
| import segmentation_models_pytorch as smp |
| from tqdm import tqdm |
|
|
|
|
class MSSSegmentationModel(pl.LightningModule):
    """U-Net for cloud segmentation of MSS imagery.

    Wraps ``smp.Unet`` with a randomly initialised encoder (no pretrained
    weights), depth-5 encoder and scSE decoder attention. ``forward`` returns
    raw logits (no activation).
    """

    def __init__(
        self,
        in_channels: int = 4,
        num_classes: int = 4,
        encoder: str = "efficientnet-b3",
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        # Record all constructor arguments in self.hparams so the checkpoint
        # can rebuild the model; lr / weight_decay are stored there even though
        # nothing in this class references them.
        self.save_hyperparameters()

        unet_config = dict(
            encoder_name=encoder,
            encoder_weights=None,
            in_channels=in_channels,
            classes=num_classes,
            encoder_depth=5,
            activation=None,
            decoder_attention_type="scse",
        )
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Return the U-Net logits for the input batch ``x``."""
        return self.model(x)
|
|
|
|
def get_spline_window(size: int, power: int = 2) -> np.ndarray:
    """Return a (size, size) float32 window for blending overlapping tiles.

    The window is the outer product of a 1-D Hann window with itself, raised
    element-wise to ``power``. For ``size > 1`` its border rows and columns
    are exactly zero (np.hanning's endpoints are 0); higher ``power`` values
    suppress the tile edges more strongly.
    """
    hann_1d = np.hanning(size)
    hann_2d = hann_1d[:, np.newaxis] * hann_1d[np.newaxis, :]
    return (hann_2d ** power).astype(np.float32)
|
|
|
|
def apply_physical_rules(
    pred: np.ndarray,
    image: np.ndarray,
    merge_clouds: bool = False,
    saturation_threshold: float = 0.4,
    band1_ratio: float = 0.80,
) -> np.ndarray:
    """Override predicted labels with simple brightness rules.

    A pixel is treated as saturated cloud when band 0 exceeds
    ``saturation_threshold`` AND band 1 exceeds
    ``saturation_threshold * band1_ratio``. Such pixels are set to class 1
    when ``merge_clouds`` is True (3-class scheme: clear, cloud, shadow) and
    to class 2 otherwise (thick cloud in the 4-class scheme). Pixels whose
    value is 0 in every band are then set to class 0.

    Args:
        pred: Label map of shape (H, W).
        image: Band stack indexed as ``image[band]``, spatially aligned with
            ``pred``.
        merge_clouds: Selects the label written for saturated pixels (see above).
        saturation_threshold: Band-0 brightness threshold (default 0.4).
        band1_ratio: Fraction of ``saturation_threshold`` that band 1 must
            exceed (default 0.80).

    Returns:
        A new label array; ``pred`` itself is not modified.
    """
    pred = pred.copy()

    # All-zero pixels across every band are treated as no-data.
    nodata_mask = np.all(image == 0, axis=0)

    saturated_mask = (image[0] > saturation_threshold) & (
        image[1] > saturation_threshold * band1_ratio
    )
    pred[saturated_mask] = 1 if merge_clouds else 2

    pred[nodata_mask] = 0
    return pred
|
|
|
|
def compiled_model(
    model_dir: Path,
    stac_item=None,
    device: str = "cpu",
    merge_clouds: bool = False,
    **kwargs
) -> nn.Module:
    """Load the segmentation checkpoint from ``model_dir`` for inference.

    Args:
        model_dir: Directory containing the ``.ckpt`` file (``str`` or
            ``Path``). If several checkpoints are present, the one whose name
            sorts first is used, so the choice does not depend on the
            filesystem's listing order.
        stac_item: STAC item metadata; accepted but not used here.
        device: Torch device string, e.g. 'cpu' or 'cuda'.
        merge_clouds: Stored on the returned model as ``model.merge_clouds``.
            True means 3 classes (clear, cloud, shadow), False means 4
            (clear, thin, thick, shadow). ``predict_large`` takes its own
            ``merge_clouds`` argument and does not read this attribute.
        **kwargs: Ignored.

    Returns:
        The loaded ``MSSSegmentationModel`` in eval mode on ``device`` with
        every parameter frozen (``requires_grad=False``).

    Raises:
        FileNotFoundError: If ``model_dir`` contains no ``.ckpt`` file.
    """
    # Accept plain strings as well as Path objects.
    model_dir = Path(model_dir)

    # Sort so the selected checkpoint is deterministic across platforms.
    ckpt_files = sorted(model_dir.glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in {model_dir}")

    ckpt_path = ckpt_files[0]
    if len(ckpt_files) > 1:
        print(f"⚠️ {len(ckpt_files)} .ckpt files found in {model_dir}, using {ckpt_path.name}")

    model = MSSSegmentationModel.load_from_checkpoint(
        ckpt_path,
        map_location=device
    )
    model.eval()
    model.to(device)

    # Inference only: freeze all weights.
    for param in model.parameters():
        param.requires_grad = False

    model.merge_clouds = merge_clouds

    print(f"✅ Model loaded from {ckpt_path.name}")
    print(f" Device: {device}")
    print(f" Classes: {'3 (merged)' if merge_clouds else '4 (original)'}")

    return model
|
|
def _tile_padding(size: int, chunk_size: int, step: int) -> tuple:
    """Return (before, after) padding so tiles of ``chunk_size`` placed every
    ``step`` pixels starting at 0 exactly cover an axis of ``size`` pixels."""
    if size <= chunk_size:
        # An axis shorter than one tile is padded up to a full tile;
        # otherwise no tile would fit on it at all.
        padded = chunk_size
    else:
        n_steps = -(-(size - chunk_size) // step)  # ceiling division
        padded = chunk_size + n_steps * step
    total = padded - size
    return total // 2, total - total // 2


def _probs_to_labels(probs: np.ndarray, is_3class_model: bool, merge_clouds: bool) -> np.ndarray:
    """Argmax a (num_classes, H, W) probability map into uint8 labels.

    For 4-class models with ``merge_clouds``, thin (1) and thick (2) cloud
    probabilities are summed before the argmax, giving labels
    0=clear, 1=cloud, 2=shadow.
    """
    if merge_clouds and not is_3class_model:
        probs = np.stack([probs[0], probs[1] + probs[2], probs[3]])
    return np.argmax(probs, axis=0).astype(np.uint8)


def _predict_direct(image: np.ndarray, model: nn.Module, device: str) -> np.ndarray:
    """Run a single forward pass over the whole image; return (num_classes, H, W) probs."""
    _, H, W = image.shape
    # smp.Unet with encoder_depth=5 only accepts sides divisible by 2**5, so
    # reflect-pad up to that multiple and crop the output back afterwards.
    multiple = 32
    pad_h, pad_w = -H % multiple, -W % multiple
    if pad_h or pad_w:
        image = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)), mode="reflect")
    with torch.no_grad():
        img_tensor = torch.from_numpy(image).unsqueeze(0).float().to(device)
        # Index the batch axis explicitly (instead of squeeze) so a
        # singleton H or W is preserved.
        probs = torch.softmax(model(img_tensor), dim=1)[0, :, :H, :W]
    return probs.cpu().numpy()


def _predict_tiled(
    image: np.ndarray,
    model: nn.Module,
    num_classes: int,
    cloud_class_indices: list,
    chunk_size: int,
    overlap: int,
    batch_size: int,
    device: str,
    cloud_max_alpha: float,
    window_power: int,
) -> np.ndarray:
    """Blend overlapping tile predictions; return (num_classes, H, W) probs."""
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )
    step = chunk_size - overlap
    _, H, W = image.shape

    pad_top, pad_bottom = _tile_padding(H, chunk_size, step)
    pad_left, pad_right = _tile_padding(W, chunk_size, step)
    image_padded = np.pad(
        image,
        ((0, 0), (pad_top, pad_bottom), (pad_left, pad_right)),
        mode="reflect"
    )
    _, H_pad, W_pad = image_padded.shape

    probs_sum = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
    probs_max = np.zeros((num_classes, H_pad, W_pad), dtype=np.float32)
    weight_sum = np.zeros((H_pad, W_pad), dtype=np.float32)

    # The Hann window is exactly 0 on its border and window**power underflows
    # far below the 1e-8 clamp used below. Pixels near the outer border of the
    # padded image are covered only by such tile edges, so without a floor
    # their weighted average collapses to ~0 while the max-blended cloud
    # classes keep alpha * max, biasing image edges towards "cloud". The floor
    # keeps those averages well defined and barely affects interior blending.
    window = np.maximum(get_spline_window(chunk_size, power=window_power), 1e-3)

    coords = [
        (r, c)
        for r in range(0, H_pad - chunk_size + 1, step)
        for c in range(0, W_pad - chunk_size + 1, step)
    ]

    with torch.no_grad():
        for start in range(0, len(coords), batch_size):
            batch_coords = coords[start:start + batch_size]
            tiles = np.stack([
                image_padded[:, r:r + chunk_size, c:c + chunk_size]
                for r, c in batch_coords
            ])
            tiles_tensor = torch.from_numpy(tiles).float().to(device)
            tile_probs = torch.softmax(model(tiles_tensor), dim=1).cpu().numpy()

            for probs, (r, c) in zip(tile_probs, batch_coords):
                sl = np.s_[:, r:r + chunk_size, c:c + chunk_size]
                probs_sum[sl] += probs * window
                weight_sum[r:r + chunk_size, c:c + chunk_size] += window
                probs_max[sl] = np.maximum(probs_max[sl], probs)

    probs_avg = probs_sum / np.maximum(weight_sum, 1e-8)

    # Cloud classes blend the per-pixel max with the weighted average so a
    # cloud seen confidently by any single tile is not averaged away.
    for ci in cloud_class_indices:
        probs_avg[ci] = (
            cloud_max_alpha * probs_max[ci]
            + (1.0 - cloud_max_alpha) * probs_avg[ci]
        )

    return probs_avg[:, pad_top:pad_top + H, pad_left:pad_left + W]


def predict_large(
    image: np.ndarray,
    model: nn.Module,
    chunk_size: int = 512,
    overlap: int = None,
    batch_size: int = 1,
    device: str = "cpu",
    merge_clouds: bool = False,
    apply_rules: bool = False,
    max_direct_size: int = 1024,
    cloud_max_alpha: float = 0.6,
    window_power: int = 3,
    **kwargs
) -> np.ndarray:
    """Predict a per-pixel class map for an image of any size.

    Images whose largest side is at most ``max_direct_size`` go through the
    model in one forward pass (reflect-padded to a multiple of 32). Larger
    images are split into overlapping ``chunk_size`` tiles whose softmax
    outputs are blended with a Hann window from ``get_spline_window``.

    Args:
        image: Array of shape (C, H, W).
        model: Network returning logits shaped (N, num_classes, h, w). The
            class count is read from ``model.hparams['num_classes']`` and
            defaults to 4 when unavailable.
        chunk_size: Tile side length in pixels (tiled path only).
        overlap: Overlap in pixels between neighbouring tiles. Defaults to
            ``chunk_size * 3 // 4``. Must satisfy ``0 <= overlap < chunk_size``.
        batch_size: Number of tiles per forward pass.
        device: Torch device string.
        merge_clouds: For 4-class models, sum thin + thick cloud probabilities
            and output labels 0=clear, 1=cloud, 2=shadow.
        apply_rules: If True, post-process with ``apply_physical_rules``.
        max_direct_size: Largest side handled without tiling.
        cloud_max_alpha: Tiled path only. Blend between the per-pixel max
            (1.0) and the weighted average (0.0) of tile probabilities for
            the cloud classes.
        window_power: Exponent of the Hann window; larger values down-weight
            tile borders more strongly.
        **kwargs: Accepted for API compatibility and ignored (for example
            ``tta``: test-time augmentation is not implemented).

    Returns:
        uint8 label map of shape (H, W).

    Raises:
        ValueError: If ``image`` is not 3-D, or (tiled path) if ``overlap``
            is outside ``[0, chunk_size)``.
    """
    model.eval()
    model.to(device)

    # Plain nn.Modules have no Lightning hparams; fall back to 4 classes.
    hparams = getattr(model, "hparams", None) or {}
    num_classes = hparams.get("num_classes", 4)
    is_3class_model = num_classes == 3
    cloud_class_indices = [
        ci for ci in ([1] if is_3class_model else [1, 2]) if ci < num_classes
    ]

    _, H, W = image.shape

    if overlap is None:
        overlap = chunk_size * 3 // 4

    if max(H, W) <= max_direct_size:
        probs = _predict_direct(image, model, device)
    else:
        probs = _predict_tiled(
            image,
            model,
            num_classes=num_classes,
            cloud_class_indices=cloud_class_indices,
            chunk_size=chunk_size,
            overlap=overlap,
            batch_size=batch_size,
            device=device,
            cloud_max_alpha=cloud_max_alpha,
            window_power=window_power,
        )

    pred = _probs_to_labels(probs, is_3class_model, merge_clouds)

    if apply_rules:
        pred = apply_physical_rules(pred, image, merge_clouds=is_3class_model or merge_clouds)

    return pred
|
|
|
|
def example_data(model_dir: Path, **kwargs):
    """Return the bundled example image, or a synthetic one if it is missing.

    Loads ``model_dir / "example_mss.npy"`` when that file exists. Otherwise
    prints a warning and returns float32 uniform noise of shape
    (4, 512, 512) scaled into [0, 0.5].
    """
    npy_file = model_dir / "example_mss.npy"
    if npy_file.exists():
        return np.load(npy_file)

    print("⚠️ No example data found, generating synthetic")
    synthetic = np.random.rand(4, 512, 512).astype(np.float32)
    return synthetic * 0.5
|
|
|
|
def display_results(
    model_dir: Path,
    image: np.ndarray,
    prediction: np.ndarray,
    stac_item=None,
    merge_clouds: bool = None,
    **kwargs
):
    """Show a band composite next to the predicted class map.

    Args:
        model_dir: Not used here; kept for API compatibility.
        image: Band stack indexed as ``image[band]``; bands 0-2 are used for
            the composite.
        prediction: 2-D label map, e.g. as returned by ``predict_large``.
        stac_item: Not used here.
        merge_clouds: True for 3-class maps (clear, cloud, shadow), False for
            4-class maps (clear, thin, thick, shadow). If None, it is guessed
            from the labels present (max label <= 2 -> 3-class). That guess is
            wrong for a 4-class map without shadow pixels, where thick cloud
            would be drawn as "Shadow", so pass it explicitly when known.
        **kwargs: Ignored.

    Returns None. Visualisation is skipped with a message if matplotlib is
    not installed.
    """
    try:
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
    except ImportError:
        print("⚠️ matplotlib not installed, skipping visualization")
        return

    if merge_clouds is None:
        # Heuristic fallback only; see the docstring for its failure case.
        merge_clouds = prediction.max() <= 2

    if merge_clouds:
        colors = ['#2E7D32', '#FFFFFF', '#424242']
        labels = ['Clear', 'Cloud', 'Shadow']
    else:
        colors = ['#2E7D32', '#B3E5FC', '#FFFFFF', '#424242']
        labels = ['Clear', 'Thin Cloud', 'Thick Cloud', 'Shadow']

    cmap = ListedColormap(colors)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Bands (1, 0, 2) are shown as (R, G, B) with a 3x display gain.
    # NOTE(review): depending on the MSS band order this may be a
    # false-colour composite rather than true RGB; confirm.
    rgb = np.stack([image[1], image[0], image[2]], axis=-1)
    rgb = np.clip(rgb * 3, 0, 1)
    axes[0].imshow(rgb)
    axes[0].set_title("MSS RGB Composite")
    axes[0].axis('off')

    # vmin/vmax pin each integer label to its own colour.
    im = axes[1].imshow(prediction, cmap=cmap, vmin=0, vmax=len(labels) - 1)
    axes[1].set_title("Cloud Detection")
    axes[1].axis('off')

    cbar = plt.colorbar(im, ax=axes[1], ticks=range(len(labels)))
    cbar.ax.set_yticklabels(labels)

    plt.tight_layout()
    plt.show()