| from fastai.text.all import * |
| from pathlib import Path |
| import pandas as pd |
| import tiktoken |
|
|
# Shared tiktoken encoding ("o200k_base"), created once at import time and
# used by `tokenizer` below for both encoding and decoding.
enc = tiktoken.get_encoding("o200k_base")
|
|
def tokenizer(s):
    """Break a string into the decoded text of each of its tokens.

    The input is encoded with the module-level ``enc`` encoding, then every
    resulting token id is decoded individually.

    Args:
        s: The text to tokenize.

    Returns:
        A list holding one decoded piece per token id, in encoding order.
    """
    # Each id is decoded by itself (wrapped in a one-element list) so that
    # every token contributes its own entry to the result.
    # NOTE(review): presumably a character whose bytes span several token ids
    # cannot round-trip through per-id decoding -- confirm against tiktoken's
    # decode error handling if the corpus contains such text.
    return [enc.decode([token_id]) for token_id in enc.encode(s)]
|
|
def main():
    """Train a language model on the chat corpus, export it, and print a sample.

    Steps, in order:
      1. Read ``data/chat_data.txt`` as UTF-8 text.
      2. Build language-model data loaders (``is_lm=True``, ``seq_len=256``)
         from a one-row DataFrame holding the whole corpus.
      3. Create a non-pretrained ``AWD_LSTM`` learner that reports accuracy
         and perplexity, switched to fp16.
      4. Call ``fit_one_cycle(5000, 1e-3)``.
      5. Export the learner to ``model.pkl``.
      6. Generate text from the prompt ``"Hi!"`` (``200``, temperature 0.9)
         and print it.
    """
    corpus = Path('data/chat_data.txt').read_text(encoding='utf-8')
    # The entire corpus goes in as a single row of the 'text' column.
    corpus_frame = pd.DataFrame({'text': [corpus]})

    # NOTE(review): `tok_func` is forwarded as-is; confirm the installed
    # fastai version actually honors this keyword for custom tokenizers.
    loaders = TextDataLoaders.from_df(
        corpus_frame,
        text_col='text',
        is_lm=True,
        tok_func=tokenizer,
        seq_len=256,
    )

    model = language_model_learner(
        loaders,
        arch=AWD_LSTM,
        metrics=[accuracy, Perplexity()],
        pretrained=False,
    )
    model = model.to_fp16()

    model.fit_one_cycle(5000, 1e-3)

    # Persist the trained learner before sampling from it.
    model.export('model.pkl')

    prompt = "Hi!"
    sample = model.predict(prompt, 200, temperature=0.9)
    print("\nGenerated text:\n", sample)
|
|
# Run the full train / export / generate pipeline only when this file is
# executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|