aaljabari committed on
Commit
8690086
·
verified ·
1 Parent(s): d53598d

Create transforms.py

Browse files
Files changed (1) hide show
  1. Nested/data/transforms.py +127 -0
Nested/data/transforms.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertTokenizer
3
+ from functools import partial
4
+ import logging
5
+ import re
6
+ import itertools
7
+ import Nested
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class BertSeqTransform:
    """Turn a segment of tokens into BERT subword ids with one flat tag sequence.

    Each token is encoded on its own. The first subword of a token carries the
    id of the token's first gold tag; every further subword of that token is
    tagged "O". The tag vocabulary used is ``vocab.tags[0]``.
    """

    def __init__(self, bert_model, vocab, max_seq_len=512):
        """Load the tokenizer and remember the vocabulary and length limit.

        Args:
            bert_model: Value passed to ``BertTokenizer.from_pretrained``.
            vocab: Object exposing ``tags``, a sequence of vocabularies that
                provide ``get_stoi()`` (presumably torchtext ``Vocab``
                objects; confirm against the caller).
            max_seq_len: Maximum length of the produced subword sequence,
                including the CLS and SEP positions. Also used as
                ``max_length`` when encoding each individual token.
        """
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        """Encode ``segment`` into subword ids, tag ids and aligned tokens.

        Args:
            segment: Iterable of token objects exposing ``text`` and
                ``gold_tag`` (a sequence whose first element is the tag used
                here).

        Returns:
            A 4-tuple ``(subwords, tags, tokens, length)``:
            ``subwords`` is a ``LongTensor`` of ids framed by CLS and SEP,
            ``tags`` is a ``LongTensor`` of tag ids of the same length,
            ``tokens`` is the list of token objects aligned with the subwords
            (placeholder "UNK" tokens fill continuation subwords and the two
            boundary positions), and ``length`` is ``len(tokens)``.
        """
        subwords, tags, tokens = [], [], []
        unk_token = Nested.data.datasets.Token(text="UNK")

        # The string-to-index mapping does not change while a segment is being
        # processed, so fetch it once instead of on every token.
        tag_to_id = self.vocab.tags[0].get_stoi()
        outside_id = tag_to_id["O"]

        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case we use the input_id for [UNK].
            token_subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            continuation = len(token_subwords) - 1

            subwords += token_subwords
            tags += [tag_to_id[token.gold_tag[0]]] + [outside_id] * continuation
            tokens += [token] + [unk_token] * continuation

        # Truncate to max_seq_len, leaving room for CLS and SEP.
        limit = self.max_seq_len - 2
        if len(subwords) > limit:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, limit)
            subwords = subwords[:limit]
            tags = tags[:limit]
            tokens = tokens[:limit]

        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)

        # The CLS and SEP positions are tagged "O".
        tags.insert(0, outside_id)
        tags.append(outside_id)

        # Keep the token list aligned with the subwords at the boundaries.
        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        return torch.LongTensor(subwords), torch.LongTensor(tags), tokens, len(tokens)
53
+
54
+
55
class NestedTagsTransform:
    """Turn a segment of tokens into BERT subword ids with one tag row per type.

    Every vocabulary in ``vocab.tags[1:]`` is treated as one tag type. For each
    type the segment gets its own tag sequence, so the tags of a segment form
    a NUM_TAG_TYPES x SEQ_LEN matrix.
    """

    def __init__(self, bert_model, vocab, max_seq_len=512):
        """Load the tokenizer and remember the vocabulary and length limit.

        Args:
            bert_model: Value passed to ``BertTokenizer.from_pretrained``.
            vocab: Object exposing ``tags``, a sequence of vocabularies that
                provide ``get_stoi()`` and ``get_itos()`` (presumably
                torchtext ``Vocab`` objects; confirm against the caller).
            max_seq_len: Maximum length of the produced subword sequence,
                including the CLS and SEP positions. Also used as
                ``max_length`` when encoding each individual token.
        """
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.encoder = partial(
            self.tokenizer.encode,
            max_length=max_seq_len,
            truncation=True,
        )
        self.max_seq_len = max_seq_len
        self.vocab = vocab

    def __call__(self, segment):
        """Encode ``segment`` into subword ids and a per-type tag matrix.

        As a side effect, each token in ``segment`` gets a ``subwords``
        attribute holding its subword ids.

        Args:
            segment: Re-iterable collection of token objects exposing ``text``
                and ``gold_tag`` (a sequence of tag strings).

        Returns:
            A 5-tuple ``(subwords, tags, tokens, mask, length)``:
            ``subwords`` is a ``LongTensor`` of ids framed by CLS and SEP;
            ``tags`` is a tensor of shape (1, NUM_TAG_TYPES, len(subwords))
            built with ``torch.Tensor``; ``tokens`` is the list of token
            objects aligned with the subwords (placeholder "UNK" tokens fill
            continuation subwords and the two boundary positions); ``mask`` is
            ``torch.ones_like(tags)``; ``length`` is ``len(tokens)``.
        """
        tokens, subwords = [], []
        unk_token = Nested.data.datasets.Token(text="UNK")

        # Encode each token and get its subwords and IDs.
        for token in segment:
            # Sometimes the tokenizer fails to encode the word and returns no
            # input_ids; in that case we use the input_id for [UNK].
            token.subwords = self.encoder(token.text)[1:-1] or self.encoder("[UNK]")[1:-1]
            subwords += token.subwords
            tokens += [token] + [unk_token] * (len(token.subwords) - 1)

        # Construct the labels for each tag type.
        # The sequence will have a list of tags for each type.
        # The final tags for a sequence is a matrix NUM_TAG_TYPES x SEQ_LEN.
        # Example:
        #   [
        #       [O, O, B-PERS, I-PERS, O, O, O]
        #       [B-ORG, I-ORG, O, O, O, O, O]
        #       [O, O, O, O, O, O, B-GPE]
        #   ]
        tags = []
        outside_ids = []
        for vocab in self.vocab.tags[1:]:
            # Fetch the mapping once per tag type rather than once per subword.
            tag_to_id = vocab.get_stoi()
            outside_id = tag_to_id["O"]
            outside_ids.append(outside_id)

            # Tags that belong to this type are the hyphenated ones (B-X, I-X).
            # Exact membership is used so that a vocabulary without any typed
            # tag matches nothing instead of matching every gold tag.
            typed_tags = {t for t in vocab.get_itos() if "-" in t}

            type_tag_ids = []
            for token in segment:
                # A token may carry several tags of this type (e.g. B-ORG and
                # I-ORG); only the first is kept because entities of the same
                # type do not overlap. Tokens without a tag of this type get "O".
                first_tag = next((t for t in token.gold_tag if t in typed_tags), "O")
                type_tag_ids.append(tag_to_id[first_tag])
                type_tag_ids.extend([outside_id] * (len(token.subwords) - 1))
            tags.append(type_tag_ids)

        # Truncate to max_seq_len, leaving room for CLS and SEP.
        limit = self.max_seq_len - 2
        if len(subwords) > limit:
            text = " ".join([t.text for t in tokens if t.text != "UNK"])
            logger.info("Truncating the sequence %s to %d", text, limit)
            subwords = subwords[:limit]
            tags = [t[:limit] for t in tags]
            tokens = tokens[:limit]

        # Add dummy token at the start and end of the sequence.
        tokens.insert(0, unk_token)
        tokens.append(unk_token)

        # Add CLS and SEP at the start and end of the subwords.
        subwords.insert(0, self.tokenizer.cls_token_id)
        subwords.append(self.tokenizer.sep_token_id)
        subwords = torch.LongTensor(subwords)

        # Add "O" tags for the first and last subwords of every tag type.
        # NOTE(review): torch.Tensor yields the default floating dtype, not
        # integer ids; kept unchanged so callers see the same dtype for tags
        # and mask. Confirm whether downstream code casts to long.
        tags = torch.Tensor(tags)
        outside_column = torch.Tensor(outside_ids)
        tags = torch.column_stack((
            outside_column,
            tags,
            outside_column,
        )).unsqueeze(0)

        mask = torch.ones_like(tags)
        return subwords, tags, tokens, mask, len(tokens)