Image Segmentation
Transformers
Safetensors
actu
feature-extraction
climate
geospatial
remote-sensing
spatiotemporal
multi-modal
earth-observation
time-series
hydrology
custom_code
Instructions to use DarthReca/actu-direction-classification with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DarthReca/actu-direction-classification with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="DarthReca/actu-direction-classification", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("DarthReca/actu-direction-classification", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from dataclasses import dataclass | |
| import numpy as np | |
| import timm | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from segmentation_models_pytorch.base import SegmentationHead | |
| from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder | |
| from timm.layers.create_act import create_act_layer | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| from transformers.modeling_outputs import SemanticSegmenterOutput | |
| from .convlstm import ConvLSTM | |
class ACTUConfig(PretrainedConfig):
    """Configuration for the ACTU segmentation model.

    Holds the image-encoder / ConvLSTM parameters, the switches selecting the
    model variant (extra DEM input channel, climate branch) and the climate
    branch LSTM hyper-parameters.

    When ``use_dem_input`` is True, ``in_channels`` is increased by one so the
    encoder accepts the DEM band that the model concatenates to every frame.
    The attribute ``dem_channel_added`` records that this increment has been
    applied. It is serialized with the config, so a config written by
    ``save_pretrained`` and read back by ``from_pretrained`` (which re-runs
    ``__init__`` with the stored, already-incremented ``in_channels``) does not
    add the DEM channel a second time.
    """

    model_type = "actu"

    def __init__(
        self,
        # Base ACTU parameters
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        # Variant control parameters
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        # Climate branch parameters
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        # Taken out of kwargs before the base constructor so it is handled
        # (and re-set) explicitly below rather than as a generic extra attribute.
        dem_channel_added = kwargs.pop("dem_channel_added", False)
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.backbone = backbone
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes
        # Parameters to control variants
        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        # Add the DEM band to the encoder input exactly once. A config that was
        # serialized after the increment carries dem_channel_added=True and is
        # left alone; configs without the flag behave as before.
        if self.use_dem_input and not dem_channel_added:
            self.in_channels += 1
        self.dem_channel_added = bool(self.use_dem_input)
class ACTUForImageSegmentation(PreTrainedModel):
    """ACTU model: per-frame CNN encoder, per-scale ConvLSTM, U-Net decoder.

    Every frame of the input sequence is encoded by a timm ``features_only``
    backbone. When ``config.use_climate_branch`` is set, each encoder scale is
    fused with climate features through a ``GatedFusion`` module. One ConvLSTM
    per scale then processes the time axis and its last time step is kept; a
    U-Net decoder plus segmentation head (ending in ``config.act_layer``)
    produce the output map, which is resized to the input resolution.
    """

    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config
        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )
        # Probe the encoder once to learn the shape of every feature scale.
        # Eval mode keeps the random probe from updating BatchNorm running
        # statistics; the previous mode is restored afterwards.
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
        self.encoder.train(was_training)
        self.embs_shape = [e.shape for e in embs]
        self.encoder_channels = [e[1] for e in self.embs_shape]

        # One ConvLSTM per encoder scale, keeping the channel count unchanged.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )
        if self.config.use_climate_branch:
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )
        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )
        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor = None,
        dem: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        """Segment a sequence of frames.

        Args:
            pixel_values: frame sequence laid out as ``(b, t, c, h, w)``.
            climate: input for the climate branch; required when
                ``config.use_climate_branch`` is True.
            dem: one raster per sample, ``(b, c, h, w)``, repeated over time and
                concatenated on the channel axis; required when
                ``config.use_dem_input`` is True.
            labels: optional targets, see ``_compute_loss``.
            **kwargs: accepted and ignored.

        Returns:
            SemanticSegmenterOutput whose ``logits`` are the segmentation-head
            output after ``config.act_layer`` (not raw logits), resized to the
            input ``(h, w)``, and whose ``loss`` is set when labels are given.

        Raises:
            ValueError: if a tensor required by the configured variant is None.
        """
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]
        # Handle DEM input
        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)
        # 1. Encode images per time step
        encoded_sequence = self._encode_images(pixel_values)
        # 2. Handle Climate Branch Fusion
        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )
            climate_features = self.climate_branch(climate)
            # Flatten time into batch so the 2D fusers see one map per frame.
            encoded_sequence_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
            ]
            climate_features_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
            ]
            fused_features = [
                fuser(img, clim)
                for fuser, img, clim in zip(
                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
                )
            ]
            # Reshape back to sequence
            encoded_sequence = [
                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
            ]
        # 3. Process sequence with ConvLSTM
        temporal_features = self._encode_timeseries(encoded_sequence)
        # 4. Decode to get the segmentation map
        logits = self._decode(temporal_features, size=original_size)
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)
        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Loss between the (post-activation) model output and ``labels``.

        * ``n_classes == 1``: binary cross-entropy. This assumes ``act_layer``
          maps outputs into [0, 1] (true for the default ``"sigmoid"``).
          Labels may be ``(b, h, w)`` or ``(b, 1, h, w)``. The loss is computed
          by hand in float32 with probabilities clamped to ``[eps, 1 - eps]``,
          because ``F.binary_cross_entropy`` refuses to run under autocast and
          the clamp keeps log/gradients finite for saturated outputs.
        * ``n_classes > 1``: cross-entropy. ``(b, h, w)`` or ``(b, 1, h, w)``
          labels are treated as class indices; ``(b, n_classes, h, w)`` labels
          as per-class probabilities.
        """
        if self.config.n_classes == 1:
            target = labels.float()
            if target.ndim == logits.ndim - 1:
                target = target.unsqueeze(1)
            eps = 1e-7
            probs = logits.float().clamp(eps, 1.0 - eps)
            return -(
                target * torch.log(probs) + (1.0 - target) * torch.log1p(-probs)
            ).mean()
        loss_fct = nn.CrossEntropyLoss()
        if labels.ndim == logits.ndim and labels.shape[1] == 1:
            labels = labels.squeeze(1)
        if labels.ndim == logits.ndim - 1:
            return loss_fct(logits, labels.long())
        return loss_fct(logits, labels.float())

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode every frame; returns one ``(b, t, c, h, w)`` tensor per scale."""
        B = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        """Run each scale through its ConvLSTM and keep the last time step.

        The result is ordered deepest scale first. Indexing ``[:, -1]`` assumes
        the ConvLSTM output is batch-first (``config.batch_first``) — TODO
        confirm for ``batch_first=False``.
        """
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        """Decode deepest-first features and resize the head output to ``size``."""
        # The decoder is given shallow-first features preceded by a None
        # placeholder for the input-resolution slot.
        trend_map = self.decoder(*[None] + x[::-1])
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map
class ClimateBranchLSTM(nn.Module):
    """Encode climate time series into one feature map per encoder scale.

    Input ``climate_data`` has shape ``(B_img, B_cli, T, C)``. Each of the
    ``B_img * B_cli`` rows is an independent sequence of length ``T`` with ``C``
    variables (``C`` must equal ``climate_input_dim``). In the ACTU model the
    second axis is aligned with the image time axis.

    Each sequence goes through an LSTM. The final hidden state of the last
    layer is projected to ``proj_dim`` (128) features and expanded spatially by
    one upsampler per ``output_shapes`` entry ``(C_k, H_k, W_k)``.

    Output: a list with one tensor per ``output_shapes`` entry, shaped
    ``(B_img, B_cli, C_k, H', W')``, where the spatial size is produced by
    ``_build_upsampler`` from ``H_k``.
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        self.proj_dim = 128
        self.output_shapes = output_shapes
        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,  # input is (batch, seq_len, features)
            # nn.LSTM only applies dropout between stacked layers.
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )
        # Project the LSTM summary to the fixed width the upsamplers expect.
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        # Unpacking also rejects inputs that are not rank 4.
        B_img, B_cli, _, _ = climate_data.shape
        # Treat every (sample, sequence) pair as an independent LSTM sequence.
        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")
        # Call the module (not .forward) so registered hooks run.
        _, (hidden, _) = self.lstm(lstm_input)
        # hidden: (num_layers * num_directions, N, lstm_hidden_dim), layer-major.
        # For a bidirectional LSTM the last layer's forward/backward states are
        # the final two entries; average them so fc still sees lstm_hidden_dim.
        if self.lstm.bidirectional:
            last_hidden = hidden[-2:].mean(dim=0)
        else:
            last_hidden = hidden[-1]
        # Project, treat as a 1x1 map, and upsample once per encoder scale.
        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(
                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
            )
            for u in self.upsamples
        ]
        return climate_features
class GatedFusion(nn.Module):
    """Blend image and climate feature maps with a learned per-pixel gate.

    A small conv stack predicts a gate ``g`` in (0, 1) from the channel-wise
    concatenation of both inputs. The output is
    ``g * img_feat + (1 - g) * clim_feat``. The gate has ``img_channels``
    channels, so in practice ``clim_channels`` must equal ``img_channels``, and
    both inputs must share batch and spatial size.
    """

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        gate_layers = [
            nn.Conv2d(
                img_channels + clim_channels, img_channels, kernel_size=3, padding=1
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),  # squashes gate values into (0, 1)
        ]
        # The outer Sequential wrapper is deliberate: parameter names stay
        # ``gate.0.<i>.*`` so existing checkpoints keep loading.
        self.gate = nn.Sequential(nn.Sequential(*gate_layers))

    def forward(self, img_feat, clim_feat):
        stacked = torch.cat([img_feat, clim_feat], dim=1)
        weight = self.gate(stacked)
        return weight * img_feat + (1 - weight) * clim_feat
| def _build_upsampler( | |
| in_channels: int, target_channels: int, target_h: int | |
| ) -> nn.Sequential: | |
| layers = [] | |
| current_h = 1 | |
| # Expand to target channels early (e.g., 1x1 → 1x1 with target_channels) | |
| layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()] | |
| # Upsample spatially to target_h | |
| while current_h < target_h: | |
| next_h = min(current_h * 2, target_h) | |
| layers += [ | |
| nn.Upsample(scale_factor=2, mode="nearest"), | |
| nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1), | |
| nn.GELU(), | |
| ] | |
| current_h = next_h | |
| return nn.Sequential(*layers) | |