| | import ast |
| | import collections |
| | import collections.abc |
| | import enum |
| | import itertools |
| | import json |
| | import os |
| | import operator |
| | import re |
| | import copy |
| | import random |
| |
|
| | import asdl |
| | import attr |
| | import pyrsistent |
| | import entmax |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from seq2struct import ast_util |
| | from seq2struct import grammars |
| | from seq2struct.models import abstract_preproc |
| | from seq2struct.models import attention |
| | from seq2struct.models import variational_lstm |
| | from seq2struct.utils import registry |
| | from seq2struct.utils import vocab |
| | from seq2struct.utils import serialization |
| | from seq2struct.models.nl2code.tree_traversal import TreeTraversal |
| | from seq2struct.models.nl2code.train_tree_traversal import TrainTreeTraversal |
| | from seq2struct.models.nl2code.infer_tree_traversal import InferenceTreeTraversal |
| |
|
| |
|
def lstm_init(device, num_layers, hidden_size, *batch_sizes):
    """Create an all-zero initial ``(hidden, cell)`` state pair.

    The tensor shape is ``(*batch_sizes, hidden_size)``, prefixed with
    ``num_layers`` when that argument is not ``None``.

    Args:
        device: Device on which the zero tensor is allocated.
        num_layers: Optional leading dimension; skipped when ``None``.
        hidden_size: Size of the trailing dimension.
        *batch_sizes: Sizes of the dimensions placed before ``hidden_size``.

    Returns:
        A 2-tuple whose two entries are the *same* zero tensor object.
    """
    shape = (*batch_sizes, hidden_size)
    if num_layers is not None:
        shape = (num_layers, *shape)
    zeros = torch.zeros(*shape, device=device)
    # Both slots alias one tensor; callers must not mutate it in place.
    return zeros, zeros
| |
|
| |
|
def maybe_stack(items, dim=None):
    """Stack the non-``None`` tensors of ``items`` along ``dim``.

    Args:
        items: Iterable of tensors and/or ``None`` placeholders.
        dim: Dimension to stack along (or to insert, for a single tensor).

    Returns:
        ``None`` when every entry is ``None``; the lone tensor with a new
        axis at ``dim`` when exactly one entry is present; otherwise the
        result of ``torch.stack`` over the present entries.
    """
    present = [tensor for tensor in items if tensor is not None]
    count = len(present)
    if count == 0:
        return None
    if count == 1:
        # A single survivor still gets the extra axis so the result has the
        # same rank as a genuine stack.
        return present[0].unsqueeze(dim)
    return torch.stack(present, dim)
| |
|
| |
|
def accumulate_logprobs(d, keys_and_logprobs):
    """Merge ``(key, logprob)`` pairs into ``d``, summing in probability space.

    A key seen for the first time is stored as-is. A key that already has a
    (non-``None``) entry is replaced by the log-sum-exp of the old and new
    values, i.e. the log of the summed probabilities.

    Args:
        d: Mutable mapping updated in place.
        keys_and_logprobs: Iterable of ``(key, logprob_tensor)`` pairs.
    """
    for key, logprob in keys_and_logprobs:
        previous = d.get(key)
        if previous is not None:
            logprob = torch.logsumexp(
                torch.stack((logprob, previous), dim=0), dim=0)
        d[key] = logprob
| |
|
def get_field_presence_info(ast_wrapper, node, field_infos):
    """Summarise, per field, whether and how it is populated in ``node``.

    For each entry of ``field_infos`` the result holds:

    * optional/sequence field of a primitive type: the Python type name of
      the value when present, otherwise ``False`` (or the falsy value of the
      presence check);
    * optional/sequence field of a non-primitive type: a presence boolean;
    * required field of a primitive type: the Python type name of the value;
    * required field of a non-primitive type: ``True`` (its presence is
      asserted).

    A field counts as present when its value is neither ``None`` nor ``[]``.

    Args:
        ast_wrapper: Object exposing ``primitive_types`` (membership-tested).
        node: Mapping from field name to value.
        field_infos: Iterable of objects with ``name``, ``type``, ``opt`` and
            ``seq`` attributes.

    Returns:
        A tuple with one entry per field, in ``field_infos`` order.
    """
    def describe(field_info):
        value = node.get(field_info.name)
        is_present = value is not None and value != []
        may_be_missing = field_info.opt or field_info.seq
        is_primitive = field_info.type in ast_wrapper.primitive_types

        if is_primitive:
            type_name = type(value).__name__
            if may_be_missing:
                return is_present and type_name
            return type_name
        if may_be_missing:
            return is_present
        # Required, non-primitive fields must always be populated.
        assert is_present
        return True

    return tuple(describe(field_info) for field_info in field_infos)
| |
|
@attr.s
class NL2CodeDecoderPreprocItem:
    """One preprocessed decoder example.

    Instances are serialised with ``attr.asdict`` and rebuilt from the
    resulting JSON by the preprocessor in this file.
    """

    # Parsed tree for the example (the value produced by ``grammar.parse``).
    tree = attr.ib()
    # The original, unparsed code the tree was produced from.
    orig_code = attr.ib()
| |
|
| |
|
| | class NL2CodeDecoderPreproc(abstract_preproc.AbstractPreproc): |
def __init__(
        self,
        grammar,
        save_path,
        min_freq=3,
        max_count=5000,
        use_seq_elem_rules=False):
    """Set up the grammar, output locations and empty statistics tables.

    Args:
        grammar: Configuration passed to ``registry.construct('grammar', ...)``.
        save_path: Directory under which all preprocessing outputs live.
        min_freq: Forwarded to ``vocab.VocabBuilder``.
        max_count: Forwarded to ``vocab.VocabBuilder``.
        use_seq_elem_rules: When true, sum-type constructors observed inside
            sequences are recorded under a separate ``*_seq_elem`` key.
    """
    self.grammar = registry.construct('grammar', grammar)
    self.ast_wrapper = self.grammar.ast_wrapper

    # Every artefact is written below ``save_path``.
    def in_save_dir(name):
        return os.path.join(save_path, name)

    self.vocab_path = in_save_dir('dec_vocab.json')
    self.observed_productions_path = in_save_dir('observed_productions.json')
    self.grammar_rules_path = in_save_dir('grammar_rules.json')
    self.data_dir = in_save_dir('dec')

    self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
    self.use_seq_elem_rules = use_seq_elem_rules

    # Per-section examples and the productions observed in training trees.
    self.items = collections.defaultdict(list)
    self.sum_type_constructors = collections.defaultdict(set)
    self.field_presence_infos = collections.defaultdict(set)
    self.seq_lengths = collections.defaultdict(set)
    self.primitive_types = set()

    # Filled in by ``save()`` or ``load()``.
    self.vocab = None
    self.all_rules = None
    self.rules_mask = None
| |
|
| |
|
def validate_item(self, item, section):
    """Parse ``item.code`` and report whether the item should be kept.

    Args:
        item: Object exposing a ``code`` attribute.
        section: Dataset section name (e.g. ``'train'``).

    Returns:
        ``(True, parsed)`` when parsing yields a truthy tree (which is also
        passed through ``ast_wrapper.verify_ast``). Otherwise
        ``(section != 'train', None)``: unparseable items are rejected only
        for the training section.
    """
    parsed = self.grammar.parse(item.code, section)
    if not parsed:
        return section != 'train', None
    self.ast_wrapper.verify_ast(parsed)
    return True, parsed
| |
|
def add_item(self, item, section, validation_info):
    """Register a validated example under ``section``.

    For the ``'train'`` section the tree's tokens are also fed to the
    vocabulary builder and its grammar productions are recorded.

    Args:
        item: Object exposing a ``code`` attribute (stored as ``orig_code``).
        section: Dataset section name.
        validation_info: The parsed tree returned by ``validate_item``.
    """
    tree = validation_info
    if section == 'train':
        for token in self._all_tokens(tree):
            self.vocab_builder.add_word(token)
        self._record_productions(tree)

    preproc_item = NL2CodeDecoderPreprocItem(tree=tree, orig_code=item.code)
    self.items[section].append(preproc_item)
| | |
def clear_items(self):
    """Drop all collected examples, leaving a fresh per-section table."""
    self.items = collections.defaultdict(list)
| |
|
def save(self):
    """Write the vocabulary, examples, observed productions and rules to disk.

    Side effects on ``self``: ``vocab`` is finalised, the production tables
    are replaced by the output of
    ``serialization.to_dict_with_sorted_values`` (presumably plain dicts with
    sorted list values -- confirm against that helper), ``primitive_types``
    becomes a sorted list, and ``all_rules`` / ``rules_mask`` are computed.
    """
    os.makedirs(self.data_dir, exist_ok=True)
    self.vocab = self.vocab_builder.finish()
    self.vocab.save(self.vocab_path)

    # One JSON-lines file per dataset section.
    for section, items in self.items.items():
        section_path = os.path.join(self.data_dir, section + '.jsonl')
        with open(section_path, 'w', encoding='utf8') as f:
            f.writelines(
                json.dumps(attr.asdict(item), ensure_ascii=False) + '\n'
                for item in items)

    # Observed productions; the tables must be converted before
    # ``_calculate_rules`` runs, because it rejects set-valued entries.
    to_sorted = serialization.to_dict_with_sorted_values
    self.sum_type_constructors = to_sorted(self.sum_type_constructors)
    self.field_presence_infos = to_sorted(self.field_presence_infos, key=str)
    self.seq_lengths = to_sorted(self.seq_lengths)
    self.primitive_types = sorted(self.primitive_types)
    observed_productions = {
        'sum_type_constructors': self.sum_type_constructors,
        'field_presence_infos': self.field_presence_infos,
        'seq_lengths': self.seq_lengths,
        'primitive_types': self.primitive_types,
    }
    with open(self.observed_productions_path, 'w', encoding='utf8') as f:
        json.dump(
            observed_productions, f,
            indent=2, sort_keys=True, ensure_ascii=False)

    # Grammar rules derived from the observed productions.
    self.all_rules, self.rules_mask = self._calculate_rules()
    grammar_rules = {
        'all_rules': self.all_rules,
        'rules_mask': self.rules_mask,
    }
    with open(self.grammar_rules_path, 'w', encoding='utf8') as f:
        json.dump(
            grammar_rules, f,
            indent=2, sort_keys=True, ensure_ascii=False)
| |
|
def load(self):
    """Restore the vocabulary, observed productions and rules from disk.

    Reads the files written by ``save()``. ``all_rules`` is passed through
    ``serialization.tuplify`` after JSON decoding (presumably to turn the
    decoded lists back into tuples -- confirm against that helper).

    Both JSON files are opened with ``with`` so the handles are closed
    deterministically instead of being left to the garbage collector.
    """
    self.vocab = vocab.Vocab.load(self.vocab_path)

    with open(self.observed_productions_path, encoding='utf8') as f:
        observed_productions = json.load(f)
    self.sum_type_constructors = observed_productions['sum_type_constructors']
    self.field_presence_infos = observed_productions['field_presence_infos']
    self.seq_lengths = observed_productions['seq_lengths']
    self.primitive_types = observed_productions['primitive_types']

    with open(self.grammar_rules_path, encoding='utf8') as f:
        grammar = json.load(f)
    self.all_rules = serialization.tuplify(grammar['all_rules'])
    self.rules_mask = grammar['rules_mask']
| |
|
def dataset(self, section):
    """Load every preprocessed example of ``section``.

    Args:
        section: Dataset section name; the file read is
            ``<data_dir>/<section>.jsonl``.

    Returns:
        A list of ``NL2CodeDecoderPreprocItem``, one per line of the file.
    """
    path = os.path.join(self.data_dir, section + '.jsonl')
    # ``with`` guarantees the file is closed once the list has been built.
    with open(path, encoding='utf8') as f:
        return [NL2CodeDecoderPreprocItem(**json.loads(line)) for line in f]
| |
|
def _record_productions(self, tree):
    """Walk ``tree`` and record every production it uses.

    Updates, in place:

    * ``sum_type_constructors``: sum type -> constructors seen for it
      (keyed with a ``'_seq_elem'`` suffix for sequence elements when
      ``use_seq_elem_rules`` is set);
    * ``field_presence_infos``: node type -> field-presence tuples seen;
    * ``seq_lengths``: ``'<field type>*'`` -> sequence lengths seen;
    * ``primitive_types``: Python type names of non-node leaf values.

    Args:
        tree: Root node; a mapping carrying a ``'_type'`` key.
    """
    wrapper = self.ast_wrapper
    pending = [(tree, False)]
    while pending:
        node, is_seq_elem = pending.pop()
        node_type = node['_type']

        # Sum-type rules: the node's own type and any '_extra_types' both
        # count as observed constructors of their sum type.
        for type_name in [node_type] + node.get('_extra_types', []):
            if type_name not in wrapper.constructors:
                continue
            sum_type_name = wrapper.constructor_to_sum_type[type_name]
            if is_seq_elem and self.use_seq_elem_rules:
                sum_type_name += '_seq_elem'
            self.sum_type_constructors[sum_type_name].add(type_name)

        # Field-presence rules for this node type.
        assert node_type in wrapper.singular_types
        fields = wrapper.singular_types[node_type].fields
        self.field_presence_infos[node_type].add(
            get_field_presence_info(wrapper, node, fields))

        for field_info in fields:
            if field_info.seq:
                # A missing sequence field is treated as an empty sequence.
                children = node.get(field_info.name, [])
                self.seq_lengths[field_info.type + '*'].add(len(children))
            else:
                children = [node.get(field_info.name, None)]

            for child in children:
                if isinstance(child, collections.abc.Mapping) and '_type' in child:
                    pending.append((child, field_info.seq))
                else:
                    self.primitive_types.add(type(child).__name__)
| |
|
def _calculate_rules(self):
    """Flatten the observed productions into a rule list and index ranges.

    Rules are emitted in three groups, each ordered by sorted key: sum-type
    constructor rules, field-presence rules, then sequence-length rules.
    Within a key, the alternatives keep the order in which they are stored.

    Returns:
        ``(all_rules, rules_mask)`` where ``all_rules`` is a tuple of
        ``(left_hand_side, alternative)`` pairs and ``rules_mask`` maps each
        left-hand side to the half-open ``(start, end)`` range of its rules
        in ``all_rules``.
    """
    all_rules = []
    rules_mask = {}

    tables = (
        self.sum_type_constructors,
        self.field_presence_infos,
        self.seq_lengths,
    )
    for table in tables:
        for lhs, alternatives in sorted(table.items()):
            # Sets have no stable order; ``save()`` must have converted the
            # tables to ordered containers before this runs.
            assert not isinstance(alternatives, set)
            start = len(all_rules)
            all_rules.extend((lhs, alternative) for alternative in alternatives)
            rules_mask[lhs] = (start, len(all_rules))

    return tuple(all_rules), rules_mask
| |
|
| |
|
def _all_tokens(self, root):
    """Yield every token of every primitive-typed field in the tree.

    Fields whose type is a grammar pointer are skipped. Primitive fields are
    tokenised with ``grammar.tokenize_field_value``. Any other field is
    descended into (element by element for lists and tuples).

    Args:
        root: Root node; a mapping carrying a ``'_type'`` key.

    Yields:
        Tokens, in the order the stack-based traversal reaches them.
    """
    stack = [root]
    while stack:
        node = stack.pop()
        type_info = self.ast_wrapper.singular_types[node['_type']]

        for field_info in reversed(type_info.fields):
            field_type = field_info.type
            if field_type in self.grammar.pointers:
                continue

            value = node.get(field_info.name)
            if field_type in self.ast_wrapper.primitive_types:
                yield from self.grammar.tokenize_field_value(value)
            elif isinstance(value, (list, tuple)):
                stack.extend(value)
            elif value is not None:
                stack.append(value)
| |
|
| |
|
@attr.s
class TreeState:
    """Work item used while walking a target tree during loss computation."""

    # The subtree still to be processed: a node mapping, a list/tuple of
    # children, or a leaf value.
    node = attr.ib()
    # Grammar type of the field (or root) under which ``node`` appears.
    parent_field_type = attr.ib()
| |
|
| |
|
| | @registry.register('decoder', 'NL2Code') |
| | class NL2CodeDecoder(torch.nn.Module): |
| |
|
| | Preproc = NL2CodeDecoderPreproc |
| |
|
def __init__(
        self,
        device,
        preproc,
        rule_emb_size=128,
        node_embed_size=64,
        enc_recurrent_size=256,
        recurrent_size=256,
        dropout=0.,
        desc_attn='bahdanau',
        copy_pointer=None,
        multi_loss_type='logsumexp',
        sup_att=None,
        use_align_mat=False,
        use_align_loss=False,
        enumerate_order=False,
        loss_type="softmax"):
    """Build the decoder's embeddings, recurrent cell, attention and heads.

    Args:
        device: Device used for tensors created by this module.
        preproc: Loaded preprocessor providing ``ast_wrapper``, ``vocab``,
            ``all_rules``, ``grammar`` and the observed-production tables.
        rule_emb_size: Embedding size for rules and terminals; also the
            hidden width of the rule/terminal logit MLPs.
        node_embed_size: Embedding size for node types.
        enc_recurrent_size: Size of the encoder states attended over.
        recurrent_size: Hidden size of the decoder LSTM cell.
        dropout: Dropout rate passed to the recurrent cell.
        desc_attn: ``'bahdanau'``, ``'mha'``, ``'mha-1h'`` or ``'sep'``; any
            other value is stored and used directly as the attention module.
        copy_pointer: Pointer module for copying; a ``BahdanauPointer`` is
            created when ``None``.
        multi_loss_type: ``'logsumexp'`` or ``'mean'``.
        sup_att: Stored as ``self.sup_att`` (compared against ``'1h'`` when
            computing the MLE loss).
        use_align_mat: When true, bind the alignment helpers from
            ``spider_dec_func`` onto this instance.
        use_align_loss: When true, ``compute_loss`` adds the alignment loss.
        enumerate_order: When true, training uses the all-orderings loss.
        loss_type: ``'softmax'``, ``'entmax'``, ``'sparsemax'`` or
            ``'label_smooth'``.
    """
    super().__init__()
    self._device = device
    self.preproc = preproc
    self.ast_wrapper = preproc.ast_wrapper
    self.terminal_vocab = preproc.vocab

    self.rule_emb_size = rule_emb_size
    self.node_emb_size = node_embed_size
    self.enc_recurrent_size = enc_recurrent_size
    self.recurrent_size = recurrent_size

    self.rules_index = {v: idx for idx, v in enumerate(self.preproc.all_rules)}
    self.use_align_mat = use_align_mat
    self.use_align_loss = use_align_loss
    self.enumerate_order = enumerate_order
    # ``_update_state`` reads this flag on every step; give it a default so
    # the decoder works without external code having to set it first.
    self.visualize_flag = False

    if use_align_mat:
        from seq2struct.models.spider import spider_dec_func
        self.compute_align_loss = lambda *args: \
            spider_dec_func.compute_align_loss(self, *args)
        self.compute_pointer_with_align = lambda *args: \
            spider_dec_func.compute_pointer_with_align(self, *args)

    # The node-type vocabulary differs only in where the sum-type and
    # singular-type names come from.
    if self.preproc.use_seq_elem_rules:
        sum_type_names = self.preproc.sum_type_constructors.keys()
        singular_type_names = self.preproc.field_presence_infos.keys()
    else:
        sum_type_names = self.ast_wrapper.sum_types.keys()
        singular_type_names = self.ast_wrapper.singular_types.keys()
    self.node_type_vocab = vocab.Vocab(
        sorted(self.preproc.primitive_types) +
        sorted(self.ast_wrapper.custom_primitive_types) +
        sorted(sum_type_names) +
        sorted(singular_type_names) +
        sorted(self.preproc.seq_lengths.keys()),
        special_elems=())

    # Input: previous action, attention context, parent hidden state,
    # parent action and node-type embedding, concatenated.
    self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
        input_size=(
            self.rule_emb_size * 2
            + self.enc_recurrent_size
            + self.recurrent_size
            + self.node_emb_size),
        hidden_size=self.recurrent_size,
        dropout=dropout)

    self.attn_type = desc_attn
    if desc_attn == 'bahdanau':
        self.desc_attn = attention.BahdanauAttention(
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size,
            proj_size=50)
    elif desc_attn == 'mha':
        self.desc_attn = attention.MultiHeadedAttention(
            h=8,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'mha-1h':
        self.desc_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    elif desc_attn == 'sep':
        # Separate attention over the question and over the schema.
        self.question_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
        self.schema_attn = attention.MultiHeadedAttention(
            h=1,
            query_size=self.recurrent_size,
            value_size=self.enc_recurrent_size)
    else:
        # Anything else is taken to be a ready-made attention module.
        self.desc_attn = desc_attn
    self.sup_att = sup_att

    self.rule_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
    self.rule_embedding = torch.nn.Embedding(
        num_embeddings=len(self.rules_index),
        embedding_dim=self.rule_emb_size)

    self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
    self.terminal_logits = torch.nn.Sequential(
        torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
        torch.nn.Tanh(),
        torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
    self.terminal_embedding = torch.nn.Embedding(
        num_embeddings=len(self.terminal_vocab),
        embedding_dim=self.rule_emb_size)
    if copy_pointer is None:
        self.copy_pointer = attention.BahdanauPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size,
            proj_size=50)
    else:
        self.copy_pointer = copy_pointer

    # NOTE(review): any other ``multi_loss_type`` leaves
    # ``multi_loss_reduction`` undefined.
    if multi_loss_type == 'logsumexp':
        self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
    elif multi_loss_type == 'mean':
        self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)

    # One pointer network and one action-embedding projection per pointer
    # type declared by the grammar.
    self.pointers = torch.nn.ModuleDict()
    self.pointer_action_emb_proj = torch.nn.ModuleDict()
    for pointer_type in self.preproc.grammar.pointers:
        self.pointers[pointer_type] = attention.ScaledDotProductPointer(
            query_size=self.recurrent_size,
            key_size=self.enc_recurrent_size)
        self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
            self.enc_recurrent_size, self.rule_emb_size)

    self.node_type_embedding = torch.nn.Embedding(
        num_embeddings=len(self.node_type_vocab),
        embedding_dim=self.node_emb_size)

    # Zero placeholders for the rule embedding and recurrent state.
    self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
    self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)

    # NOTE(review): any other ``loss_type`` leaves ``xent_loss`` undefined.
    if loss_type == "softmax":
        self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
    elif loss_type == "entmax":
        self.xent_loss = entmax.entmax15_loss
    elif loss_type == "sparsemax":
        self.xent_loss = entmax.sparsemax_loss
    elif loss_type == "label_smooth":
        self.xent_loss = self.label_smooth_loss
| | |
def label_smooth_loss(self, X, target, smooth_value=0.1):
    """Label-smoothed loss in training mode, plain cross-entropy otherwise.

    In training mode the target distribution puts ``1 - smooth_value`` on
    the gold class and spreads ``smooth_value`` evenly over the remaining
    classes; the loss is the batch-mean KL divergence between that
    distribution and ``softmax(X)``.

    Args:
        X: Logits of shape ``(batch, num_classes)``.
        target: Gold class indices of shape ``(batch,)``.
        smooth_value: Probability mass moved away from the gold class.

    Returns:
        Training: a tensor of shape ``(1,)`` holding the batch-mean loss.
        Evaluation: per-example cross-entropy of shape ``(batch,)``.
    """
    if not self.training:
        return torch.nn.functional.cross_entropy(X, target, reduction="none")

    log_probs = torch.log_softmax(X, dim=1)
    num_classes = X.size(1)
    smoothed = torch.full(
        X.size(), smooth_value / (num_classes - 1), device=X.device)
    # The scatter index must be (batch, 1) so that row i receives the gold
    # mass at column target[i]. A (1, batch) index would write every target
    # into row 0, which is only correct for a batch of one.
    smoothed.scatter_(1, target.unsqueeze(1), 1 - smooth_value)
    loss = F.kl_div(log_probs, smoothed, reduction="batchmean")
    return loss.unsqueeze(0)
| |
|
@classmethod
def _calculate_rules(cls, preproc):
    """Flatten ``preproc``'s observed productions into rules and ranges.

    Rules are emitted in three groups, each ordered by sorted key: sum-type
    constructor rules, field-presence rules, then sequence-length rules.

    Args:
        preproc: Object exposing ``sum_type_constructors``,
            ``field_presence_infos`` and ``seq_lengths`` mappings.

    Returns:
        ``(all_rules, rules_mask)`` where ``all_rules`` is a list of
        ``(left_hand_side, alternative)`` pairs and ``rules_mask`` maps each
        left-hand side to the half-open ``(start, end)`` range of its rules.
    """
    all_rules = []
    rules_mask = {}

    tables = (
        preproc.sum_type_constructors,
        preproc.field_presence_infos,
        preproc.seq_lengths,
    )
    for table in tables:
        for lhs, alternatives in sorted(table.items()):
            # A left-hand side may own only one contiguous range of rules.
            assert lhs not in rules_mask
            start = len(all_rules)
            all_rules.extend((lhs, alternative) for alternative in alternatives)
            rules_mask[lhs] = (start, len(all_rules))

    return all_rules, rules_mask
| | |
def compute_loss(self, enc_input, example, desc_enc, debug):
    """Compute the training loss for one example.

    The all-orderings loss is used only when ``enumerate_order`` is set and
    the module is in training mode; otherwise the plain MLE loss is used.
    When ``use_align_loss`` is set, the alignment loss is added on top.

    Args:
        enc_input: Encoder input for the example.
        example: Preprocessed decoder example.
        desc_enc: Encoder output for the example.
        debug: Forwarded to the selected loss function.

    Returns:
        The MLE loss, plus the alignment loss when enabled.
    """
    if self.enumerate_order and self.training:
        mle_loss = self.compute_loss_from_all_ordering(
            enc_input, example, desc_enc, debug)
    else:
        mle_loss = self.compute_mle_loss(enc_input, example, desc_enc, debug)

    if not self.use_align_loss:
        return mle_loss
    align_loss = self.compute_align_loss(desc_enc, example)
    return mle_loss + align_loss
| |
|
def compute_loss_from_all_ordering(self, enc_input, example, desc_enc, debug):
    """Combine the MLE loss over every ordering of every list in the tree.

    Each list or tuple found in ``example.tree`` contributes all
    permutations of its elements; the cartesian product of those choices
    gives the set of reordered trees. The MLE loss is computed for each and
    the results are combined with ``torch.logsumexp``.

    ``example.tree`` is temporarily replaced while each loss is computed and
    is always restored afterwards, even if a loss computation raises.

    Args:
        enc_input: Encoder input for the example.
        example: Preprocessed example; its ``tree`` attribute is swapped
            temporarily.
        desc_enc: Encoder output for the example.
        debug: Accepted for interface compatibility. It is not forwarded,
            because the debug variant of ``compute_mle_loss`` returns a
            tuple that could not be stacked.

    Returns:
        ``logsumexp`` over the per-ordering losses.
    """
    def get_permutations(node):
        """Collect, in traversal order, all permutations of every list."""
        permutations = []

        def traverse_tree(node):
            if isinstance(node, (list, tuple)):
                permutations.append(
                    list(itertools.permutations(range(len(node)))))
                for child in node:
                    traverse_tree(child)
            elif isinstance(node, dict):
                for node_name in node:
                    traverse_tree(node[node_name])

        traverse_tree(node)
        return permutations

    def get_perturbed_tree(node, permutation):
        """Return a deep copy of ``node`` with each list reordered.

        ``permutation`` holds one index order per list, in the same
        traversal order that ``get_permutations`` uses.
        """
        def traverse_tree(node, parent_type, parent_node):
            nonlocal permutation
            if isinstance(node, (list, tuple)):
                p_node = [node[i] for i in permutation[0]]
                parent_node[parent_type] = p_node
                permutation = permutation[1:]
                # The children are the same objects that now sit in
                # ``p_node``, so edits below them show up in the result.
                for child in node:
                    traverse_tree(child, None, None)
            elif isinstance(node, dict):
                for node_name in node:
                    traverse_tree(node[node_name], node_name, node)

        node = copy.deepcopy(node)
        traverse_tree(node, None, None)
        return node

    orig_tree = example.tree
    permutations = get_permutations(orig_tree)
    loss_list = []
    try:
        for product in itertools.product(*permutations):
            example.tree = get_perturbed_tree(orig_tree, product)
            loss_list.append(
                self.compute_mle_loss(enc_input, example, desc_enc))
    finally:
        # Never leave the caller's example pointing at a permuted copy.
        example.tree = orig_tree

    loss_v = torch.stack(loss_list, 0)
    return torch.logsumexp(loss_v, 0)
| |
|
def compute_mle_loss(self, enc_input, example, desc_enc, debug=False):
    """Compute the MLE loss of ``example.tree`` by teacher forcing.

    The gold tree is walked depth-first (left to right) with an explicit
    stack. At each step the matching action is fed to a
    ``TrainTreeTraversal``, whose collected ``loss`` entries are summed.

    Args:
        enc_input: Encoder input; its ``'question'`` and ``'columns'``
            entries are read only when ``sup_att == '1h'``.
        example: Preprocessed example providing ``tree``.
        desc_enc: Encoder output providing ``pointer_maps``.
        debug: When true, also return the traversal history.

    Returns:
        The summed loss, or ``(loss, history_dicts)`` when ``debug`` is true.
    """
    traversal = TrainTreeTraversal(self, desc_enc, debug)
    traversal.step(None)

    grammar = self.preproc.grammar
    stack = [
        TreeState(node=example.tree, parent_field_type=grammar.root_type),
    ]
    while stack:
        state = stack.pop()
        node = state.node
        parent_field_type = state.parent_field_type

        # Sequence: choose its length, then expand the elements.
        if isinstance(node, (list, tuple)):
            rule_idx = self.rules_index[parent_field_type + '*', len(node)]
            assert traversal.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
            traversal.step(rule_idx)

            if (self.preproc.use_seq_elem_rules
                    and parent_field_type in self.ast_wrapper.sum_types):
                parent_field_type += '_seq_elem'

            # Pushed in reverse so elements are popped left to right.
            for elem in reversed(node):
                stack.append(
                    TreeState(node=elem, parent_field_type=parent_field_type))
            continue

        # Pointer: the node is an index into the encoder's pointer memory.
        if parent_field_type in grammar.pointers:
            assert isinstance(node, int)
            assert traversal.cur_item.state == TreeTraversal.State.POINTER_APPLY
            pointer_map = desc_enc.pointer_maps.get(parent_field_type)
            if not pointer_map:
                traversal.step(node)
            else:
                values = pointer_map[node]
                if self.sup_att != '1h':
                    traversal.step(values[0], values[1:])
                else:
                    # Position of the gold item inside the attended memory.
                    # A pointer map as long as the column list is treated
                    # as the column pointer; anything else is offset past
                    # the columns. Without separate attention, everything
                    # is additionally offset past the question.
                    attention_offset = node
                    points_at_columns = (
                        len(pointer_map) == len(enc_input['columns']))
                    if self.attn_type != 'sep':
                        attention_offset += len(enc_input['question'])
                    if not points_at_columns:
                        attention_offset += len(enc_input['columns'])
                    traversal.step(values[0], values[1:], attention_offset)
            continue

        # Primitive: generate its tokens one by one, followed by EOS.
        if parent_field_type in self.ast_wrapper.primitive_types:
            tokens = grammar.tokenize_field_value(node) + [vocab.EOS]
            for token in tokens:
                assert traversal.cur_item.state == TreeTraversal.State.GEN_TOKEN
                traversal.step(token)
            continue

        type_info = self.ast_wrapper.singular_types[node['_type']]

        # Sum type: pick the constructor (extra types are alternatives).
        if parent_field_type in self.preproc.sum_type_constructors:
            rule_idx = self.rules_index[parent_field_type, type_info.name]
            assert traversal.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
            extra_rules = [
                self.rules_index[parent_field_type, extra_type]
                for extra_type in node.get('_extra_types', [])]
            traversal.step(rule_idx, extra_rules)

        # Constructor with fields: pick which fields are present.
        if type_info.fields:
            present = get_field_presence_info(
                self.ast_wrapper, node, type_info.fields)
            rule_idx = self.rules_index[node['_type'], tuple(present)]
            assert traversal.cur_item.state == TreeTraversal.State.CHILDREN_APPLY
            traversal.step(rule_idx)

        # Pushed in reverse so fields are visited left to right.
        for field_info in reversed(type_info.fields):
            if field_info.name not in node:
                continue
            stack.append(
                TreeState(
                    node=node[field_info.name],
                    parent_field_type=field_info.type))

    loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)
    if debug:
        return loss, [attr.asdict(entry) for entry in traversal.history]
    return loss
| | |
| |
|
def begin_inference(self, desc_enc, example):
    """Start an inference traversal and take its first step.

    Args:
        desc_enc: Encoder output for the example.
        example: Example being decoded.

    Returns:
        ``(traversal, choices)``: the new ``InferenceTreeTraversal`` and the
        result of its initial ``step(None)``.
    """
    traversal = InferenceTreeTraversal(self, desc_enc, example)
    initial_choices = traversal.step(None)
    return traversal, initial_choices
| |
|
def _desc_attention(self, prev_state, desc_enc):
    """Attend over the encoder memory, using ``prev_state[0]`` as the query.

    With ``attn_type == 'sep'`` the question and schema memories are
    attended separately; the two contexts are summed and only the schema
    attention logits are returned. Otherwise the single ``desc_attn`` module
    is applied to ``desc_enc.memory``.

    Returns:
        ``(context, attention_logits)``.
    """
    query = prev_state[0]
    if self.attn_type == 'sep':
        question_context, _ = self.question_attn(
            query, desc_enc.question_memory)
        schema_context, schema_logits = self.schema_attn(
            query, desc_enc.schema_memory)
        return question_context + schema_context, schema_logits
    return self.desc_attn(query, desc_enc.memory, attn_mask=None)
| | |
def _tensor(self, data, dtype=None):
    """Build a tensor from ``data`` on this module's device."""
    return torch.tensor(data, dtype=dtype, device=self._device)
| | |
def _index(self, vocab, word):
    """Return ``vocab.index(word)`` wrapped in a one-element tensor."""
    word_id = vocab.index(word)
    return self._tensor([word_id])
| | |
def _update_state(
        self,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Advance the decoder's recurrent state by one step.

    The cell input is the concatenation, along the last dimension, of the
    previous action embedding, the attention context, the parent's hidden
    state, the parent's action embedding and the node-type embedding.

    Args:
        node_type: Name looked up in ``node_type_vocab``.
        prev_state: Previous recurrent state; its first element is the
            attention query.
        prev_action_emb: Embedding of the previous action.
        parent_h: Hidden state recorded for the parent node.
        parent_action_emb: Embedding of the parent's action.
        desc_enc: Encoder output attended over.

    Returns:
        ``(new_state, attention_logits)``.
    """
    desc_context, attention_logits = self._desc_attention(prev_state, desc_enc)
    # The flag may never have been set on this instance; treat a missing
    # attribute as "off" instead of raising AttributeError.
    if getattr(self, 'visualize_flag', False):
        attention_weights = F.softmax(attention_logits, dim=-1)
        print(attention_weights)

    node_type_emb = self.node_type_embedding(
        self._index(self.node_type_vocab, node_type))

    state_input = torch.cat(
        (
            prev_action_emb,
            desc_context,
            parent_h,
            parent_action_emb,
            node_type_emb,
        ),
        dim=-1)
    new_state = self.state_update(state_input, prev_state)
    return new_state, attention_logits
| |
|
def apply_rule(
        self,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Step the decoder and score every grammar rule.

    Returns:
        ``(output, new_state, rule_logits)`` where ``output`` is the first
        element of the new recurrent state and ``rule_logits`` is the
        rule-scoring head applied to it.
    """
    new_state, _ = self._update_state(
        node_type, prev_state, prev_action_emb, parent_h, parent_action_emb,
        desc_enc)
    output = new_state[0]
    return output, new_state, self.rule_logits(output)
| |
|
def rule_infer(self, node_type, rule_logits):
    """List the candidate rules for ``node_type`` with their log-probs.

    The log-softmax is taken over *all* rules; only the slice belonging to
    ``node_type`` (from ``preproc.rules_mask``) is returned, without
    renormalising.

    Returns:
        A list of ``(rule_index, logprob)`` pairs.
    """
    rule_logprobs = torch.nn.functional.log_softmax(rule_logits, dim=-1)
    rules_start, rules_end = self.preproc.rules_mask[node_type]
    candidates = rule_logprobs[0, rules_start:rules_end]
    return list(enumerate(candidates, rules_start))
| |
|
def gen_token(
        self,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Step the decoder and compute the generate-vs-copy log-odds.

    Returns:
        ``(new_state, output, gen_logodds)`` where ``output`` is the first
        element of the new recurrent state and ``gen_logodds`` is the
        ``gen_logodds`` head applied to it with dimension 1 squeezed out.
    """
    new_state, _ = self._update_state(
        node_type, prev_state, prev_action_emb, parent_h, parent_action_emb,
        desc_enc)
    output = new_state[0]
    gen_logodds = self.gen_logodds(output).squeeze(1)
    return new_state, output, gen_logodds
| | |
def gen_token_loss(
        self,
        output,
        gen_logodds,
        token,
        desc_enc):
    """Loss for emitting ``token``, marginalised over copying and generating.

    * Copy path (only when the token occurs in the input): log-probability
      of choosing to copy, minus the pointer loss for the token's first
      occurrence.
    * Generate path (when the token is in the terminal vocabulary, or when
      copying is impossible): log-probability of choosing to generate,
      minus the vocabulary loss for the token.

    Returns:
        The negative log-sum-exp of the available path log-probabilities.
    """
    token_idx = self._index(self.terminal_vocab, token)
    # NOTE(review): the embedding lookup result is not used in this method;
    # the call is kept so the step behaves exactly as before.
    self.terminal_embedding(token_idx)

    copy_logprob = None
    desc_locs = desc_enc.find_word_occurrences(token)
    if desc_locs:
        copy_loc_logits = self.copy_pointer(output, desc_enc.memory)
        # log p(copy) + log p(location | copy); only the first occurrence
        # is used as the target location.
        copy_logprob = (
            F.logsigmoid(-gen_logodds)
            - self.xent_loss(copy_loc_logits, self._tensor(desc_locs[0:1])))

    gen_logprob = None
    if token in self.terminal_vocab or copy_logprob is None:
        token_logits = self.terminal_logits(output)
        # log p(generate) + log p(token | generate)
        gen_logprob = (
            F.logsigmoid(gen_logodds)
            - self.xent_loss(token_logits, token_idx))

    path_logprobs = maybe_stack([copy_logprob, gen_logprob], dim=1)
    return -torch.logsumexp(path_logprobs, dim=1)
| | |
def token_infer(self, output, gen_logodds, desc_enc):
    """Score every candidate next token at inference time.

    Copy scores (one per input word) and generation scores (one per
    terminal-vocabulary entry) are merged per word, summing in probability
    space the scores of words reachable in more than one way.

    Returns:
        A list of ``(word, logprob)`` pairs.
    """
    log_prob_by_word = {}

    # Copy path: log p(copy) + log p(location | copy), for every location.
    copy_logprob = F.logsigmoid(-gen_logodds)
    copy_loc_logits = self.copy_pointer(output, desc_enc.memory)
    copy_loc_logprobs = F.log_softmax(copy_loc_logits, dim=-1)
    copy_loc_logprobs += copy_logprob
    accumulate_logprobs(
        log_prob_by_word,
        zip(desc_enc.words, copy_loc_logprobs.squeeze(0)))

    # Generate path: log p(generate) + log p(token | generate).
    gen_logprob = F.logsigmoid(gen_logodds)
    token_logits = self.terminal_logits(output)
    token_logprobs = F.log_softmax(token_logits, dim=-1)
    token_logprobs += gen_logprob
    accumulate_logprobs(
        log_prob_by_word,
        ((self.terminal_vocab[idx], logprob)
         for idx, logprob in enumerate(token_logprobs[0])))

    return list(log_prob_by_word.items())
| | |
def compute_pointer(
        self,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Step the decoder and score the pointer targets for ``node_type``.

    Returns:
        ``(output, new_state, pointer_logits, attention_logits)`` where
        ``pointer_logits`` comes from the pointer module registered for
        ``node_type``, applied to ``desc_enc.pointer_memories[node_type]``.
    """
    new_state, attention_logits = self._update_state(
        node_type, prev_state, prev_action_emb, parent_h, parent_action_emb,
        desc_enc)
    output = new_state[0]

    pointer = self.pointers[node_type]
    pointer_logits = pointer(output, desc_enc.pointer_memories[node_type])
    return output, new_state, pointer_logits, attention_logits
| |
|
| |
|
def pointer_infer(self, node_type, logits):
    """Turn pointer logits into ``(index, logprob)`` candidates.

    Args:
        node_type: Unused here; kept for a uniform inference interface.
        logits: Pointer logits; only the first row is read.

    Returns:
        A list pairing each position with its log-softmax score.
    """
    logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
    return list(enumerate(logprobs[0]))