"""
Script for preparing SFT data for fine-tuning the AMD-OLMo model.
Modified from https://github.com/allenai/OLMo/blob/main/scripts/prepare_tulu_data.py
"""
|
|
import logging
import random
from argparse import ArgumentParser
from functools import partial
from pathlib import Path

import datasets as ds
import numpy as np
from rich.progress import track
from tqdm import tqdm

from olmo.tokenizer import Tokenizer
from olmo.util import prepare_cli_environment
|
|
log = logging.getLogger(__name__)
|
|
|
|
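# Each converter below maps a source dataset into the Tulu-style format used for
# tokenization: a dict with "dataset", "id", and "messages", where "messages" is
# a list of {"role": ..., "content": ...} turns.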
def convert_code_feedback_to_tulu_format(dataset, mix=False):
    log.info("Converting code_feedback ...")
    y_all = []
    for i, sample in enumerate(dataset):
        y = {
            "dataset": "code_feedback",
            "id": "code_feedback_{}".format(i),
            "messages": sample["messages"],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset
|
|
|
|
def convert_OpenHermes_to_tulu_format(dataset, mix=False):
    log.info("Converting OpenHermes ...")
    role_map = {"human": "user", "gpt": "assistant", "system": "system"}
    y_all = []
    for i, sample in enumerate(dataset):
        y = {
            "dataset": "OpenHermes",
            "id": "OpenHermes_{}".format(i),
            "messages": [{"role": role_map[mssg["from"]], "content": mssg["value"]} for mssg in sample["conversations"]],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset
|
|
|
|
def convert_WebInstructSub_to_tulu_format(dataset, mix=False):
    log.info("Converting WebInstructSub ...")
    y_all = []
    for i, sample in tqdm(enumerate(dataset)):
        y = {
            "dataset": "WebInstructSub",
            "id": "WebInstructSub_{}".format(i),
            "messages": [
                {"role": "user", "content": sample["question"]},
                {"role": "assistant", "content": sample["answer"]},
            ],
        }
        y_all.append(y)

    log.info(f"In total {len(y_all)} samples")
    if mix:
        return y_all
    else:
        new_dataset = ds.Dataset.from_list(y_all)
        return new_dataset


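# End-to-end flow: load the raw dataset(s), convert to the Tulu message format,
# tokenize with label masking, drop examples with no labeled tokens, and write
# the packed token/mask arrays to disk as flat memmaps.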
def main(opts) -> None:
    tokenizer: Tokenizer
    if Path(opts.tokenizer).is_file():
        tokenizer = Tokenizer.from_file(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)
    else:
        tokenizer = Tokenizer.from_pretrained(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad)

    if opts.dataset == "tulu":
        dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")
    elif opts.dataset == "2nd-phase":
        dataset_names = ["code-feedback", "OpenHermes", "WebInstructSub"]
        combined_datasets = []
        for dataset_name in dataset_names:
            if dataset_name == "code-feedback":
                dataset = ds.load_dataset("m-a-p/Code-Feedback", split="train")
                dataset = convert_code_feedback_to_tulu_format(dataset, mix=True)
            elif dataset_name == "OpenHermes":
                dataset = ds.load_dataset("teknium/OpenHermes-2.5", split="train")
                dataset = convert_OpenHermes_to_tulu_format(dataset, mix=True)
            elif dataset_name == "WebInstructSub":
                dataset = ds.load_dataset("TIGER-Lab/WebInstructSub", split="train")
                dataset = convert_WebInstructSub_to_tulu_format(dataset, mix=True)

            combined_datasets += dataset

        random.seed(42)
        random.shuffle(combined_datasets)
        log.info(f"In total {len(combined_datasets)} samples")
        dataset = ds.Dataset.from_list(combined_datasets)
    else:
        raise ValueError(f"Unknown dataset: {opts.dataset}. Expected 'tulu' or '2nd-phase'.")
|
|
| log.info("Tokenizing dataset...") |
| dataset = dataset.map( |
| partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len), |
| batched=False, |
| remove_columns=["dataset", "id", "messages"], |
| num_proc=opts.num_proc, |
| ) |
|
|
| log.info("Filtering dataset...") |
| n = len(dataset) |
| dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc) |
| log.info(f"Filtered out {n - len(dataset):,d} examples") |
|
|
| log.info("Counting tokens...") |
| total_tokens = 0 |
| for ex in track(dataset): |
| assert len(ex["input_ids"]) == opts.seq_len |
| total_tokens += len(ex["input_ids"]) |
| log.info(f"Total tokens: {total_tokens:,d}") |
|
|
| log.info(f"Saving results to '{opts.output_dir}'...") |
| output_dir = Path(opts.output_dir) |
| output_dir.mkdir(exist_ok=True, parents=True) |
|
|
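    # Write all tokenized examples back-to-back into two flat memmapped arrays:
    # uint16 token IDs and a boolean label mask, one entry per token.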
    input_ids_file = np.memmap(
        str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,)
    )
    label_mask_file = np.memmap(
        str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
    )
    offset = 0
    for ex in track(dataset):
        ex_len = len(ex["input_ids"])
        input_ids_file[offset : offset + ex_len] = ex["input_ids"]
        label_mask_file[offset : offset + ex_len] = ex["label_mask"]
        offset += ex_len
    input_ids_file.flush()
    label_mask_file.flush()

    log.info("Done!")
|
|
|
|
def filter(example):
    return example["n_labels"] > 0
|
|
|
|
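# Tokenize one conversation into a fixed-length sequence. Only assistant tokens
# (minus the trailing newline after EOS) are marked True in `label_mask`, so the
# loss is computed only on assistant responses; role headers, user/system turns,
# and padding are masked out.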
def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
    input_ids = [tokenizer.eos_token_id]
    label_mask = [False]

    for msg in example["messages"]:
        role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
        label_mask += [False] * len(role_tokens)
        input_ids += role_tokens

        if msg["role"] == "assistant":
            content_tokens = tokenizer.encode(
                msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
            )
            label_mask += [True] * len(content_tokens)
            # Mask out the trailing '\n' that follows the EOS token.
            assert content_tokens[-2] == tokenizer.eos_token_id
            label_mask[-1] = False
        else:
            content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
            label_mask += [False] * len(content_tokens)
        input_ids += content_tokens

    # Truncate to the maximum sequence length, then pad to exactly max_seq_len.
    input_ids = input_ids[:max_seq_len]
    label_mask = label_mask[:max_seq_len]

    if len(input_ids) < max_seq_len:
        pad_len = max_seq_len - len(input_ids)
        input_ids += [tokenizer.pad_token_id] * pad_len
        label_mask += [False] * pad_len

    assert len(input_ids) == len(label_mask)
    n_labels = sum(label_mask)

    return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}
|
|
|
|
def get_parser() -> ArgumentParser:
    parser = ArgumentParser(description="Prepare SFT dataset for fine-tuning AMD-OLMo")
    parser.add_argument("--output_dir", type=str, help="""Directory to save the results to.""")
    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        help="""Tokenizer path or identifier.""",
        default=Path(__file__).parent / "tokenizers" / "allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
    )
    parser.add_argument("-ds", "--dataset", type=str, help="""Dataset to process: 'tulu' or '2nd-phase'.""", default="tulu")
    parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048)
    parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=50279)
    parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1)
    parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8)
    return parser
|
|
|
|
if __name__ == "__main__":
    prepare_cli_environment()
    opts = get_parser().parse_args()
    main(opts)
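
# Optional read-back sanity check (illustrative; adjust the path to match --output_dir):
#   import numpy as np
#   ids = np.memmap("path/to/output_dir/input_ids.npy", dtype=np.uint16, mode="r")
#   mask = np.memmap("path/to/output_dir/label_mask.npy", dtype=np.bool_, mode="r")
#   assert ids.shape == mask.shape and ids.shape[0] % 2048 == 0  # assumes the default --seq-len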