Instructions to use emarro/pcad2-200M-cnet-ar with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use emarro/pcad2-200M-cnet-ar with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="emarro/pcad2-200M-cnet-ar", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("emarro/pcad2-200M-cnet-ar", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import repeat, rearrange | |
| from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined | |
| from .modules_utils import get_seq_idx | |
@dataclass
class RoutingModuleOutput:
    """
    Result of a routing pass.

    Declared as a dataclass so it can be built with keyword arguments, which
    is how the routing module constructs it.
    """

    # Two-way distribution per position: index 0 is "no boundary", index 1 is
    # "boundary". Shape is hidden_states.shape[:-1] + (2,).
    boundary_prob: torch.Tensor
    # Bool tensor, True where a boundary was selected. Shape hidden_states.shape[:-1].
    boundary_mask: torch.Tensor
    # Probability of the selected option. Shape hidden_states.shape[:-1] + (1,).
    selected_probs: torch.Tensor
@dataclass
class RoutingModuleState:
    """
    The state of the routing module.
    Contains
    - [has_seen_tokens] (batch_size,) bool tensor. Whether that batch element has processed any tokens yet.
    - [last_hidden_state] (batch_size, d_model) tensor. The last hidden state of the batch element (used for boundary prediction).

    Declared as a dataclass so it can be built with keyword arguments; both
    tensors are updated in place (via copy_) by the routing module.
    """

    has_seen_tokens: torch.Tensor  # (batch_size,)
    last_hidden_state: torch.Tensor  # (batch_size, d_model)
@dataclass
class DeChunkState:
    """
    The state of the dechunk.
    Contains
    - [last_value] (batch_size, d_model) tensor. The last value of the batch element (used for the EMA).

    Declared as a dataclass so it can be built with keyword arguments; the
    tensor is updated in place (via copy_) by the dechunk layer.
    """

    last_value: torch.Tensor  # (batch_size, d_model)
class RoutingModule(nn.Module):
    """
    Scores every position of a sequence with a boundary probability.

    Two scoring mechanisms are supported, chosen by ``selection``:

    - ``"cos"``: the probability at position t is ``(1 - cos(q(h[t-1]), k(h[t]))) / 2``,
      where q and k are bias-free linear projections initialised to the identity.
    - ``"mlp"``: the probability at position t is ``sigmoid(w . h[t])``.

    In both cases the first position of every sequence is forced to
    probability 1.0.
    """

    def __init__(self, d_model, selection="cos", device=None, dtype=None):
        """
        Args:
            d_model: size of the last (feature) axis of the hidden states.
            selection: ``"cos"`` or ``"mlp"``.
            device, dtype: forwarded to the linear layers that are created.

        Raises:
            ValueError: if ``selection`` is not a recognised mechanism.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.d_model = d_model
        self.selection = selection
        if selection == "mlp":
            self.mlp = nn.Sequential(
                # factory_kwargs keeps this branch consistent with "cos".
                nn.Linear(self.d_model, 1, bias=False, **factory_kwargs),
                nn.Sigmoid(),
            )
        elif selection == "cos":
            self.q_proj_layer = nn.Linear(
                d_model, d_model, bias=False, **factory_kwargs
            )
            self.k_proj_layer = nn.Linear(
                d_model, d_model, bias=False, **factory_kwargs
            )
            # Start from the identity so the score is initially the plain
            # cosine similarity of adjacent hidden states.
            with torch.no_grad():
                self.q_proj_layer.weight.copy_(torch.eye(d_model))
                self.k_proj_layer.weight.copy_(torch.eye(d_model))
            # Marker attribute; presumably read by an external weight-init
            # routine so the identity init is not overwritten (TODO confirm).
            self.q_proj_layer.weight._no_reinit = True
            self.k_proj_layer.weight._no_reinit = True
        else:
            raise ValueError(f"Unrecognized selection mechanism {selection}")

    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
        """
        Create a fresh decoding state: no tokens seen, zero last hidden state.

        ``max_seqlen`` is accepted but not used here.
        """
        return RoutingModuleState(
            has_seen_tokens=torch.zeros(batch_size, device=device, dtype=torch.bool),
            last_hidden_state=torch.zeros(
                batch_size, self.d_model, device=device, dtype=dtype
            ),
        )

    def forward(self, hidden_states, cu_seqlens=None, mask=None, inference_params=None):
        """
        Score a whole sequence (training or prefill).

        Args:
            hidden_states: (B, L, D) in padded mode, or (T, D) in packed mode
                (when ``cu_seqlens`` is given).
            cu_seqlens: cumulative sequence lengths for packed mode; the
                positions ``cu_seqlens[:-1]`` are treated as sequence starts.
            mask: bool tensor of valid positions; required when
                ``inference_params`` is given.
            inference_params: optional RoutingModuleState, updated in place
                with the last valid hidden state of each batch element.

        Returns:
            RoutingModuleOutput with ``boundary_prob`` (..., 2),
            ``boundary_mask`` (...) and ``selected_probs`` (..., 1), where
            ``...`` is ``hidden_states.shape[:-1]``.
        """
        assert (mask is not None) or (cu_seqlens is not None), (
            "Either mask or cu_seqlens must be provided"
        )
        if inference_params is not None:
            assert mask is not None, (
                "Mask must be provided if inference_params is provided"
            )
            assert (~inference_params.has_seen_tokens).all(), (
                "Cannot have seen tokens when running prefill with inference_params"
            )
        if cu_seqlens is not None:
            # We are in packed mode, so hidden_states is (T, D). Make it (B, T, D)
            hidden_states = hidden_states.unsqueeze(0)

        if self.selection == "cos":
            cos_sim = torch.einsum(
                "b l d, b l d -> b l",
                F.normalize(self.q_proj_layer(hidden_states[:, :-1]), dim=-1),
                F.normalize(self.k_proj_layer(hidden_states[:, 1:]), dim=-1),
            )
            # this clamp should no-op as long as no precision issues are encountered
            boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
        elif self.selection == "mlp":
            boundary_prob = self.mlp(hidden_states)[:, 1:, 0]  # (B, L - 1)

        # Force boundary probability of the first element to 1.0
        PAD_PROB = 1.0
        boundary_prob = F.pad(boundary_prob, (1, 0), "constant", PAD_PROB)
        if cu_seqlens is not None:
            boundary_prob = boundary_prob.squeeze(0)
            # Every packed sequence start is a boundary as well.
            boundary_prob[cu_seqlens[:-1]] = PAD_PROB

        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
        selected_idx = torch.argmax(boundary_prob, dim=-1)
        boundary_mask = selected_idx == 1  # (shape hidden_states.shape[:-1])
        if mask is not None:
            # No invalid tokens can be selected
            boundary_mask = boundary_mask & mask

        if inference_params is not None:
            has_mask = mask.any(dim=-1)  # (B,)
            inference_params.has_seen_tokens.copy_(
                has_mask | inference_params.has_seen_tokens
            )
            # Index of the last valid position, assuming valid positions are
            # left-aligned in the mask.
            last_mask = torch.clamp(mask.sum(dim=-1) - 1, min=0)
            batch_idx = torch.arange(
                hidden_states.shape[0], device=hidden_states.device
            )
            inference_params.last_hidden_state.copy_(
                torch.where(
                    # (B, 1) so the per-row condition broadcasts over the
                    # feature axis instead of being aligned with it.
                    has_mask.unsqueeze(-1),
                    hidden_states[batch_idx, last_mask],
                    inference_params.last_hidden_state,
                )
            )

        selected_probs = boundary_prob.gather(
            dim=-1, index=selected_idx.unsqueeze(-1)
        )  # (shape hidden_states.shape[:-1], 1)
        return RoutingModuleOutput(
            boundary_prob=boundary_prob,  # (shape hidden_states.shape[:-1], 2)
            boundary_mask=boundary_mask,  # (shape hidden_states.shape[:-1])
            selected_probs=selected_probs,  # (shape hidden_states.shape[:-1], 1)
        )

    def step(self, hidden_states, inference_params):
        """
        Score a single new token per batch element (incremental decoding).

        Args:
            hidden_states: (B, 1, D).
            inference_params: RoutingModuleState; its last hidden state is
                replaced by the current one and ``has_seen_tokens`` is set
                to True for every element.

        Returns:
            RoutingModuleOutput with ``boundary_prob`` (B, 2),
            ``boundary_mask`` (B,) and ``selected_probs`` (B, 1).
        """
        # hidden_states is (B, 1, D)
        hidden_states = hidden_states.squeeze(1)
        if self.selection == "cos":
            cos_sim = torch.einsum(
                "b d, b d -> b",
                F.normalize(
                    self.q_proj_layer(inference_params.last_hidden_state), dim=-1
                ),
                F.normalize(self.k_proj_layer(hidden_states), dim=-1),
            )
            boundary_prob = torch.clamp(((1 - cos_sim) / 2), min=0.0, max=1.0)
        else:
            # "mlp": same per-token score as in forward().
            boundary_prob = self.mlp(hidden_states).squeeze(-1)
        inference_params.last_hidden_state.copy_(hidden_states)

        # An element that has not seen any token yet is at its first
        # position, which is always a boundary.
        boundary_prob = torch.where(
            inference_params.has_seen_tokens,
            boundary_prob,
            torch.ones_like(boundary_prob),
        )
        boundary_prob = torch.stack(((1 - boundary_prob), boundary_prob), dim=-1)
        inference_params.has_seen_tokens.copy_(
            torch.ones_like(inference_params.has_seen_tokens)
        )
        return RoutingModuleOutput(
            boundary_prob=boundary_prob,  # (B, 2)
            boundary_mask=boundary_prob[..., 1] > 0.5,  # (B,)
            selected_probs=boundary_prob.max(dim=-1).values.unsqueeze(-1),  # (B, 1)
        )
class ChunkLayer(nn.Module):
    """Compacts a sequence down to the positions flagged as boundaries."""

    def forward(self, hidden_states, boundary_mask, cu_seqlens=None, mask=None):
        """
        Select the boundary positions of ``hidden_states``.

        Packed mode (``cu_seqlens`` given): ``hidden_states`` and
        ``boundary_mask`` are indexed along dim 0. Returns the selected rows,
        the cumulative lengths of the compacted sequences, the longest
        compacted length, and ``None`` for the mask.

        Padded mode (only ``mask`` given): ``hidden_states`` is (B, L, D) and
        ``boundary_mask`` is (B, L). Selected positions are moved to the front
        of each row in their original order, and the result is truncated to
        the largest per-row count. Returns the compacted tensor, ``None``,
        ``None``, and a bool mask marking the filled slots of each row.
        """
        assert (mask is not None) or (cu_seqlens is not None), (
            "Either mask or cu_seqlens must be provided"
        )

        if cu_seqlens is not None:
            selected = hidden_states[boundary_mask]
            # Running count of kept tokens, read at the last token of each
            # packed sequence, gives the new cumulative lengths.
            kept_so_far = boundary_mask.cumsum(dim=0)
            chunk_cu_seqlens = F.pad(kept_so_far[cu_seqlens[1:] - 1], (1, 0))
            chunk_lengths = chunk_cu_seqlens[1:] - chunk_cu_seqlens[:-1]
            return selected, chunk_cu_seqlens, int(chunk_lengths.max()), None

        kept_per_row = boundary_mask.sum(dim=-1)
        widest = int(kept_per_row.max())
        dev = hidden_states.device
        seq_len = hidden_states.shape[1]

        # Non-boundary positions get their index shifted by seq_len, so the
        # sort places every boundary position first, in original order.
        positions = torch.arange(seq_len, device=dev)[None, :]
        sort_key = positions + (~boundary_mask).long() * seq_len
        order = torch.argsort(sort_key, dim=1)

        gather_index = order[:, :widest, None].expand(
            -1, -1, hidden_states.shape[-1]
        )
        compacted = torch.gather(hidden_states, dim=1, index=gather_index)

        slots = torch.arange(widest, device=dev)[None, :]
        valid = slots < kept_per_row[:, None]
        # The max length is only reported in packed mode.
        return compacted, None, None, valid

    def step(self, hidden_states, boundary_mask):
        """Keep the batch elements whose ``boundary_mask`` entry is True."""
        return hidden_states[boundary_mask]
class DeChunkLayer(nn.Module):
    """
    Expands chunk-level hidden states back to the original positions.

    Chunk values are smoothed with an exponential moving average weighted by
    the boundary probability (``out = p * x + (1 - p) * previous``, see
    ``step``). Every original position then receives the value of the most
    recent chunk at or before it.
    """

    def __init__(
        self,
        d_model,
        dtype=torch.bfloat16,
        block_size=256,
        headdim=32,
    ):
        """
        Args:
            d_model: feature size; must be divisible by ``headdim``.
            dtype: dtype the scan inputs are cast to in ``forward``.
            block_size: passed to the scan kernel as ``chunk_size``.
            headdim: per-head feature size used to reshape for the kernel.
        """
        super().__init__()
        self.d_model = d_model
        # Just for Mamba2 kernel.
        self.dtype = dtype
        self.block_size = block_size
        self.headdim = headdim
        assert d_model % self.headdim == 0
        self.nheads = d_model // self.headdim

    def allocate_inference_cache(self, batch_size, max_seqlen, device, dtype=None):
        """
        Create a fresh decoding state with a zero last value.

        ``max_seqlen`` is accepted but not used here.
        """
        return DeChunkState(
            last_value=torch.zeros(
                batch_size, self.d_model, device=device, dtype=dtype
            ),
        )

    def forward(
        self,
        hidden_states,
        boundary_mask,
        boundary_prob,
        cu_seqlens=None,
        inference_params=None,
        mask=None,
    ):
        """
        De-chunk a whole sequence (training or prefill).

        Args:
            hidden_states: chunk-level states; (B, M, D) in padded mode.
                Presumably (M, D) in packed mode -- TODO confirm with caller.
            boundary_mask: bool tensor over the original positions, (B, L) in
                padded mode or indexed along dim 0 in packed mode.
            boundary_prob: (..., 2) tensor; the last channel is the boundary
                probability.
            cu_seqlens: if given, packed mode is used and this is forwarded
                to ``get_seq_idx``.
            inference_params: optional DeChunkState; receives the output at
                the last position.
            mask: required when ``inference_params`` is given.

        Returns:
            Tensor over the original positions, in ``hidden_states``' dtype.
        """
        if inference_params is not None:
            assert mask is not None, (
                "Mask must be provided if inference_params is provided"
            )
            assert boundary_mask[:, 0].all(), (
                "First token must be a boundary if running prefill"
            )

        # Keep p strictly inside (0, 1) so log(1 / (1 - p)) below is finite
        # and non-zero.
        p = torch.clamp(boundary_prob[..., -1].float(), min=1e-4, max=1 - (1e-4))

        if cu_seqlens is not None:
            p = p[boundary_mask].unsqueeze(0)
            seq_idx = get_seq_idx(cu_seqlens, device=hidden_states.device)
        else:
            _, L = boundary_mask.shape
            seq_idx = None
            # Same ordering trick as the chunking step: boundary positions
            # first, in original order.
            token_idx = (
                torch.arange(L, device=hidden_states.device)[None, :]
                + (~boundary_mask).long() * L
            )
            seq_sorted_indices = torch.argsort(token_idx, dim=1)
            p = torch.gather(
                p, dim=1, index=seq_sorted_indices[:, : hidden_states.shape[1]]
            )  # (B, M)

        original_dtype = hidden_states.dtype

        # Reuse Mamba2 kernel for EMA Deaggregator.
        # With A = -1 and dt = log(1 / (1 - p)) we have exp(A * dt) = 1 - p,
        # and x is pre-divided by dt. Assuming the kernel implements the
        # usual SSD recurrence, this yields out = (1 - p) * prev + p * h.
        dt = torch.log(1 / (1 - p)).to(self.dtype)
        x = (hidden_states / dt[..., None]).to(self.dtype)
        A = -torch.ones(
            (self.nheads,), device=hidden_states.device, dtype=torch.float32
        )
        b = p.to(self.dtype)
        c = torch.ones_like(b)

        out = mamba_chunk_scan_combined(
            rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
            repeat(dt, "b l -> b l h", h=self.nheads),
            A,
            rearrange(b, "b l -> b l 1 1"),
            rearrange(c, "b l -> b l 1 1"),
            chunk_size=self.block_size,
            seq_idx=seq_idx,
        )
        out = rearrange(out, "b l h p -> b l (h p)")

        # cumsum(boundary_mask) - 1 is, for every original position, the
        # index of the latest chunk that started at or before it.
        if cu_seqlens is not None:
            out = out.squeeze(0)
            plug_back_idx = boundary_mask.cumsum(dim=0) - 1
            out = torch.gather(
                out, dim=0, index=plug_back_idx.unsqueeze(-1).expand(-1, self.d_model)
            )
        else:
            plug_back_idx = torch.cumsum(boundary_mask, dim=1) - 1  # (B, L)
            out = torch.gather(
                out,
                dim=1,
                index=plug_back_idx.unsqueeze(-1).expand(-1, -1, self.d_model),
            )

        if inference_params is not None:
            inference_params.last_value.copy_(out[:, -1])

        return out.to(original_dtype)

    def step(self, hidden_states, boundary_mask, boundary_prob, inference_params):
        """
        De-chunk a single decoding step.

        Args:
            hidden_states: (B', 1, D), where B' = boundary_mask.sum().
            boundary_mask: (B,) bool tensor.
            boundary_prob: (B, 2) tensor; the last channel is the boundary
                probability.
            inference_params: DeChunkState; ``last_value`` is overwritten
                with the returned values.

        Returns:
            (B, 1, D) tensor ``p * x + (1 - p) * last_value``. Elements that
            are not boundaries use p = 0 and therefore repeat ``last_value``.
        """
        # hidden_states is (B', 1, D), where B' = boundary_mask.sum()
        # boundary_mask is (B,) and boundary_prob is (B, 2)
        B = boundary_mask.shape[0]
        D = hidden_states.shape[-1]

        p = torch.zeros(B, device=hidden_states.device, dtype=hidden_states.dtype)
        p[boundary_mask] = boundary_prob[boundary_mask, -1].clamp(
            min=1e-4, max=1 - (1e-4)
        )
        # (B, 1): one weight per batch element, broadcast over features.
        # A (B,) tensor would be aligned with the feature axis instead.
        p = p.unsqueeze(-1)

        current_hidden_states = torch.zeros(
            B, D, device=hidden_states.device, dtype=hidden_states.dtype
        )
        current_hidden_states[boundary_mask] = hidden_states.squeeze(1)

        result = p * current_hidden_states + (1 - p) * inference_params.last_value
        inference_params.last_value.copy_(result)
        return result.unsqueeze(1)