Instructions to use DiffusionTableQA/llada-8b-table-sft-dcot with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use DiffusionTableQA/llada-8b-table-sft-dcot with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("GSAI-ML/LLaDA-8B-Instruct")
model = PeftModel.from_pretrained(base_model, "DiffusionTableQA/llada-8b-table-sft-dcot")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import time | |
| import re | |
| from collections import Counter | |
| from transformers import AutoTokenizer, AutoModel | |
def add_gumbel_noise(logits, temperature):
    """Perturb logits so that an argmax over the result is a temperature sample.

    Uses the exponential form of the Gumbel-max trick:
    ``exp(logits) / (-log U) ** temperature`` with ``U ~ Uniform(0, 1)``, computed
    in float64. For ``temperature > 0`` the argmax of this equals the argmax of
    ``logits + temperature * Gumbel``, i.e. a draw from softmax(logits / temperature).

    Args:
        logits: tensor of unnormalised scores (any shape).
        temperature: sampling temperature; ``0`` disables noise.

    Returns:
        ``logits`` itself (same object) when ``temperature == 0``; otherwise a new
        float64 tensor of the same shape.
    """
    if temperature == 0:
        # Greedy decoding: hand back the caller's tensor untouched.
        return logits
    scores = logits.to(torch.float64)
    uniform = torch.rand_like(scores, dtype=torch.float64)
    return torch.exp(scores) / torch.pow(-torch.log(uniform), temperature)
def get_num_transfer_tokens(block_mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    """Spread each row's masked-token count as evenly as possible over ``steps`` steps.

    Every step receives ``count // steps`` tokens and the first ``count % steps``
    steps receive one extra, so each output row sums to the number of True entries
    in the corresponding input row.

    Args:
        block_mask_index: boolean tensor indexed as (batch, positions); True = masked.
        steps: number of denoising steps available for the block.

    Returns:
        ``torch.long`` tensor of shape (batch, steps) holding per-step quotas.
    """
    per_row = block_mask_index.sum(dim=1)
    quotient = torch.div(per_row, steps, rounding_mode='floor')
    leftover = per_row - quotient * steps
    step_ids = torch.arange(steps, device=block_mask_index.device)
    # The earliest `leftover` steps of each row take one additional token.
    bonus = (step_ids[None, :] < leftover[:, None]).to(torch.long)
    return quotient[:, None].expand(-1, steps).to(torch.long) + bonus
# =================================================================
# Token selection for one denoising step
# (scoring rules: "low_confidence", "top_prob_margin", "random")
# =================================================================
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None):
    """Propose a token for every position and choose which masked positions to commit.

    Args:
        logits: model scores with the vocabulary on the last dim (the selection code
            below unpacks the confidence as 2-D, so presumably (batch, seq, vocab) —
            TODO confirm for all callers).
        temperature: Gumbel temperature forwarded to ``add_gumbel_noise`` (0 = argmax).
        remasking: scoring rule for proposals — "low_confidence" (probability of the
            proposed token), "top_prob_margin" (top-1 minus top-2 probability) or
            "random" (uniform float64 scores).
        mask_index: bool tensor, True where ``x`` is still masked.
        x: current token ids; non-masked positions are copied through unchanged.
        num_transfer_tokens: per-row number of positions to commit, shape (B,) or
            (B, 1). Required when ``threshold`` is None, ignored otherwise.
        threshold: if set, commit every masked position scoring >= threshold, plus the
            single best-scoring position of each row.

    Returns:
        ``(x0, transfer_index)``: ``x0`` holds proposals at masked positions and ``x``
        elsewhere; ``transfer_index`` is a bool tensor (a subset of ``mask_index``)
        marking the positions to commit.

    Raises:
        NotImplementedError: unknown ``remasking`` value.
        ValueError: both ``threshold`` and ``num_transfer_tokens`` are None.
    """
    noisy = add_gumbel_noise(logits, temperature=temperature)
    proposal = noisy.argmax(dim=-1)

    if remasking in ("low_confidence", "top_prob_margin"):
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        if remasking == "low_confidence":
            score = probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)
        else:
            best_two = probs.topk(2, dim=-1).values
            score = best_two[..., 0] - best_two[..., 1]
    elif remasking == "random":
        score = torch.rand(proposal.shape, device=proposal.device, dtype=torch.float64)
    else:
        raise NotImplementedError(remasking)

    # Already-decoded positions keep their token; only masked ones take the proposal.
    x0 = torch.where(mask_index, proposal, x)
    # Non-masked positions get the most negative finite value so they rank last.
    floor = torch.full_like(score, torch.finfo(score.dtype).min)
    confidence = torch.where(mask_index, score, floor)

    if threshold is not None:
        above = mask_index & (confidence >= threshold)
        # Guarantee progress: always commit each row's top-scoring position.
        best = confidence.argmax(dim=1, keepdim=True)
        forced = torch.zeros_like(above).scatter_(1, best, True)
        return x0, (above | forced) & mask_index

    if num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens must be a tensor when threshold is None.")
    quota = num_transfer_tokens
    if quota.dim() == 2 and quota.size(1) == 1:
        quota = quota.squeeze(1)
    quota = quota.to(dtype=torch.long, device=confidence.device).clamp(min=0)

    order = torch.sort(confidence, dim=1, descending=True).indices
    rows, width = confidence.shape
    # rank[b, p] = place of position p in row b's descending-confidence order.
    ranks_in_order = torch.arange(width, device=confidence.device).repeat(rows, 1)
    rank = torch.empty_like(order).scatter_(1, order, ranks_in_order)
    cutoff = quota.unsqueeze(1).expand(rows, width)
    return x0, (rank < cutoff) & mask_index
# =================================================================
# Dynamic (factor-based) variant of get_transfer_index
# (scoring rules: "low_confidence", "top_prob_margin", "random")
# =================================================================
def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, num_transfer_tokens, factor=1):
    """Propose tokens and commit a data-dependent number of masked positions per row.

    For a row with ``n`` masked positions, their confidences are sorted descending
    (c_1 >= c_2 >= ... >= c_n) and the longest prefix is committed for which
    ``c_k >= 1 - factor / (k + 1)`` holds for every ``k >= 2``. The top candidate is
    always committed, so every row with masks makes progress.

    Args:
        logits: model scores with the vocabulary on the last dim; ``x0`` is indexed
            as 2-D below, so presumably (batch, seq, vocab) — TODO confirm.
        temperature: Gumbel temperature forwarded to ``add_gumbel_noise`` (0 = argmax).
        remasking: "low_confidence", "top_prob_margin" or "random" (uniform scores in
            the default float dtype).
        mask_index: bool tensor, True where ``x`` is still masked.
        x: current token ids; non-masked positions are copied through unchanged.
        num_transfer_tokens: unused — kept for signature parity with
            ``get_transfer_index``; the per-row count comes from ``mask_index``.
        factor: scale of the acceptance schedule; larger values lower the thresholds
            and therefore commit more tokens per step.

    Returns:
        ``(x0, transfer_index)`` with the same meaning as in ``get_transfer_index``.

    Raises:
        NotImplementedError: unknown ``remasking`` value.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == 'top_prob_margin':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        top2_probs, _ = torch.topk(p, k=2, dim=-1)
        x0_p = top2_probs[..., 0] - top2_probs[..., 1]
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)
    transfer_index = torch.zeros_like(x0, dtype=torch.bool)

    # One host sync for all rows instead of one per row.
    masked_counts = mask_index.sum(dim=1).tolist()
    for j, num_tokens in enumerate(masked_counts):
        if num_tokens == 0:
            continue
        sorted_confidence = torch.sort(confidence[j][mask_index[j]], dim=-1, descending=True)[0]
        # Acceptance threshold for the k-th best candidate (k = 1..n): 1 - factor / (k + 1).
        # Built in float64 (matching Python float arithmetic), then cast to the score dtype,
        # which is how a Python-scalar comparison against the score tensor would behave.
        ranks = torch.arange(1, num_tokens + 1, dtype=torch.float64, device=confidence.device)
        threshs = 1 - float(factor) / (ranks + 1)
        threshs[0] = -1  # the best candidate always passes
        threshs = threshs.to(sorted_confidence.dtype)
        # Vectorised search for the first candidate below its threshold (single sync),
        # instead of a Python loop that synchronised on every comparison.
        failing = torch.nonzero(sorted_confidence < threshs)
        top_i = int(failing[0, 0]) if failing.numel() > 0 else num_tokens
        top_i = max(top_i, 1)
        _, select_index = torch.topk(confidence[j], k=top_i)
        transfer_index[j, select_index] = True
    return x0, transfer_index
# =================================================================
# generate_standard (original, uncached sampler)
# =================================================================
@torch.no_grad()
def generate_standard(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
                      cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False,
                      confidence_eos_eot_inf=False):
    """Generate ``gen_length`` tokens after ``prompt`` by block-wise iterative unmasking.

    The generated region starts as ``mask_id`` everywhere and is split into
    ``gen_length // block_length`` blocks decoded left to right, each receiving
    ``steps // num_blocks`` steps. Every step runs the model on the full sequence
    (no KV cache), proposes a token per position, and commits the per-step quota
    (from ``get_num_transfer_tokens``) of highest-confidence masked positions that
    lie inside the current block.

    Args:
        model: callable returning an object with ``.logits``; must expose ``.device``.
        prompt: token ids indexed as (batch, prompt_len).
        attention_mask: optional (batch, prompt_len) mask; extended with ones over the
            generated region and moved to ``model.device``.
        steps: total denoising steps; must be divisible by the number of blocks.
        gen_length: tokens to generate; must be divisible by ``block_length``.
        block_length: tokens per block.
        temperature: Gumbel temperature for proposals (0 = argmax).
        cfg_scale: classifier-free guidance scale; > 0 adds an unconditional pass in
            which the prompt positions are replaced by ``mask_id``.
        remasking: "low_confidence", "top_prob_margin" or "random" confidence scoring.
        mask_id: token id used for masked positions.
        logits_eos_inf: if True, token 126081 (EOS) is excluded from proposals.
        confidence_eos_eot_inf: if True, tokens 126081 (EOS) and 126348 (EoT) get zero
            probability in the confidence computation (applied after sampling), so
            positions proposing them receive zero confidence.

    Returns:
        Tensor of shape (batch, prompt_len + gen_length): prompt followed by generated ids.
    """
    eos_id, eot_id = 126081, 126348  # ids the EOS/EoT flags refer to (LLaDA vocabulary — TODO confirm)
    batch_size, prompt_len = prompt.shape[0], prompt.shape[1]

    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = prompt
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((batch_size, gen_length), dtype=attention_mask.dtype, device=model.device)],
            dim=-1,
        )
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
        for i in range(steps_per_block):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                # Duplicate the mask for the conditional/unconditional halves; None stays None
                # (previously this name was left undefined when attention_mask was None).
                attention_mask_ = None if attention_mask is None else torch.cat([attention_mask, attention_mask], dim=0)
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits
            if logits_eos_inf:
                logits[:, :, eos_id] = -torch.inf
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)
            if confidence_eos_eot_inf:
                # After sampling x0, so only the confidence is affected. Write to `logits`
                # for both ids: with temperature > 0, `logits_with_noise` is a separate tensor
                # and writing EOS there had no effect on the confidence.
                logits[:, :, eos_id] = -torch.inf
                logits[:, :, eot_id] = -torch.inf
            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'top_prob_margin':
                p = F.softmax(logits, dim=-1)
                top2_probs, _ = torch.topk(p, k=2, dim=-1)
                x0_p = top2_probs[:, :, 0] - top2_probs[:, :, 1]
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)
            # Positions after the current block must not be committed yet.
            x0_p[:, block_end:] = -np.inf
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)
            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j, i]))
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]
    return x
# =================================================================
# generate_with_dual_cache (KV-cache accelerated sampler)
# =================================================================
@torch.no_grad()
def generate_with_dual_cache(
    model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
    remasking="low_confidence", mask_id=126336, threshold=None, factor=None,
    cfg_scale=0., logits_eos_inf=False, confidence_eos_eot_inf=False, attention_mask=None
):
    """Block-wise masked-diffusion sampling that reuses a KV cache inside each block.

    For every block, one forward pass runs over the whole sequence with
    ``use_cache=True``. Its logits drive the first step, and its ``past_key_values``
    are reused by the remaining steps, which feed only the block's tokens together
    with ``replace_position`` (True over the block).

    Selection mode per step:
      * ``factor`` set      -> ``get_transfer_index_dynamic`` (factor schedule);
      * ``threshold`` set   -> ``get_transfer_index`` with a confidence threshold;
      * neither             -> ``get_transfer_index`` with the even per-step quota
                               from ``get_num_transfer_tokens``.
    In threshold/factor modes ``steps`` is a minimum budget: a block may take up to
    ``block_length`` steps so that no mask tokens are left in the output.

    Args:
        model: assumed to accept ``use_cache``, ``past_key_values`` and (ideally)
            ``replace_position`` keyword arguments and to return ``.logits`` and
            ``.past_key_values``; must expose ``.device`` — TODO confirm against the
            model implementation.
        prompt: token ids indexed as (batch, prompt_len).
        steps, gen_length, block_length, temperature, remasking, mask_id: as in
            ``generate_standard``.
        threshold: confidence threshold for ``get_transfer_index`` (optional).
        factor: schedule factor for ``get_transfer_index_dynamic`` (optional).
        cfg_scale: > 0 is unsupported here; falls back to ``generate_standard``.
        logits_eos_inf, confidence_eos_eot_inf, attention_mask: only used by the
            ``cfg_scale > 0`` fallback; the cached path ignores them.

    Returns:
        Tensor of shape (batch, prompt_len + gen_length).
    """
    if cfg_scale > 0:
        print("⚠️ Warning: cfg_scale > 0 is not supported in Dual Cache mode. Falling back to standard generate.")
        return generate_standard(
            model, prompt, attention_mask=attention_mask, steps=steps, gen_length=gen_length,
            block_length=block_length, temperature=temperature, cfg_scale=cfg_scale,
            remasking=remasking, mask_id=mask_id, logits_eos_inf=logits_eos_inf,
            confidence_eos_eot_inf=confidence_eos_eot_inf,
        )

    batch_size = prompt.shape[0]
    prompt_len = int(prompt.shape[1])
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    # The quota schedule commits every masked token of a block in exactly
    # `steps_per_block` steps. Threshold/factor modes commit a variable number (at
    # least one per row with masks per step), so they may need up to `block_length`
    # steps. Capping them at `steps_per_block` used to leave mask tokens in the output.
    use_quota = threshold is None and factor is None
    max_block_steps = steps_per_block if use_quota else max(steps_per_block, block_length)

    x = torch.full((batch_size, prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = prompt

    for nb in range(num_blocks):
        s = prompt_len + nb * block_length
        e = s + block_length
        block_mask_index = (x[:, s:e] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # 1) Full pass: warms the KV cache and provides the first step's logits.
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values

        # Marks the block's positions for the block-only passes below.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True

        # First step may commit any masked position up to the end of this block.
        global_mask_index = (x == mask_id)
        global_mask_index[:, e:] = False
        if factor is None:
            quota0 = None if threshold is not None else num_transfer_tokens[:, 0]
            # Every remasking mode ("low_confidence", "top_prob_margin", "random") works here.
            x0, transfer_index = get_transfer_index(
                out_full.logits, temperature, remasking, global_mask_index, x, quota0, threshold
            )
        else:
            x0, transfer_index = get_transfer_index_dynamic(
                out_full.logits, temperature, remasking, global_mask_index, x, None, factor
            )
        x = torch.where(transfer_index, x0, x)

        # 2) Remaining steps re-run only the current block against the cache.
        for i in range(1, max_block_steps):
            block = x[:, s:e]
            mask_blk = (block == mask_id)
            if not mask_blk.any():
                break
            try:
                logits_blk = model(
                    block, past_key_values=past_key_values, use_cache=True, replace_position=replace_position
                ).logits
            except TypeError:
                # NOTE(review): fallback for models without a `replace_position` kwarg.
                # Whether a plain cache then yields correct block logits depends on the
                # model's cache implementation — confirm before relying on it.
                logits_blk = model(
                    block, past_key_values=past_key_values, use_cache=True
                ).logits
            if factor is None:
                quota_i = None if threshold is not None else num_transfer_tokens[:, i]
                x0_blk, transfer_idx_blk = get_transfer_index(
                    logits_blk, temperature, remasking, mask_blk, block, quota_i, threshold
                )
            else:
                x0_blk, transfer_idx_blk = get_transfer_index_dynamic(
                    logits_blk, temperature, remasking, mask_blk, block, None, factor
                )
            # Write the block back in place rather than re-concatenating the whole sequence.
            x[:, s:e] = torch.where(transfer_idx_blk, x0_blk, block)
    return x
# Alias: `generate` is `generate_standard`, the sampler that re-runs the model on the
# full sequence at every step without a KV cache.
generate = generate_standard