import base64
import os
import pathlib
import random
import re
import time
from html import escape
from io import BytesIO

from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import gradio as gr
import imgkit
from PIL import Image
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, pipeline
|
|
|
|
# Optional Hugging Face access token; None when the AUTH_TOKEN env var is unset.
AUTH_TOKEN = os.environ.get('AUTH_TOKEN')

# The tokenizer is loaded from BASE_MODEL and the language-model weights from
# MERGED_MODEL (presumably a GPT-2 fine-tuned on card text).
BASE_MODEL = "gpt2"
MERGED_MODEL = "gpt2-magic-card"

# Set to True to load the fp16 Stable Diffusion weights and run them on CUDA.
gpu = False
|
|
# Stable Diffusion v1.5 produces the card illustration. On GPU the fp16 weights
# revision is loaded in half precision and the pipeline is moved to CUDA.
_sd_options = {"use_auth_token": AUTH_TOKEN}
if gpu:
    _sd_options.update(torch_dtype=torch.float16, revision="fp16")
image_pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_sd_options)

# Swap in the Euler-ancestral scheduler, built from the default scheduler's config.
scheduler = EulerAncestralDiscreteScheduler.from_config(image_pipeline.scheduler.config)
image_pipeline.scheduler = scheduler
if gpu:
    image_pipeline.to("cuda")
|
|
| |
| |
|
|
# GPT-2 card-text generator: tokenizer from BASE_MODEL, LM weights from MERGED_MODEL.
tokenizer = GPT2TokenizerFast.from_pretrained(BASE_MODEL)
model = GPT2LMHeadModel.from_pretrained(MERGED_MODEL)

# Marker that ends one generated card; its encoding (a list of token ids) is
# passed to generation as the end-of-sequence id.
END_TOKEN = '###'
eos_id = tokenizer.encode(END_TOKEN)

text_pipeline = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
|
|
|
def gen_card_text(name):
    """Generate card text with the GPT-2 text pipeline.

    Args:
        name: Card name entered by the user. When empty, generation is seeded
            with ``"Name: "`` plus a random capital letter and the model
            completes the name itself.

    Returns:
        A ``(name, text)`` tuple. ``name`` is the input name, or the generated
        one when the input was empty. ``text`` is the generated card text,
        truncated at the first END_TOKEN.
    """
    if name == '':
        # No trailing newline, so the model continues the name after the seed letter.
        prompt = f"Name: {random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}"
    else:
        prompt = f"Name: {name}\n"
    print(f'GENERATING CARD TEXT with prompt: {prompt}')
    output = text_pipeline(prompt, max_length=512, num_return_sequences=1, num_beams=5, temperature=1.5,
                           do_sample=True, repetition_penalty=1.2, eos_token_id=eos_id)
    # Keep only the first card, then remove literal "\r\n"/"\r" escape sequences
    # as well as real carriage returns.
    result = (output[0]['generated_text'].split(END_TOKEN)[0]
              .replace(r'\r\n', '\n').replace('\r', '').replace(r'\r', ''))
    print('GENERATING CARD COMPLETE')
    print(result)
    if name == '':
        match = re.search('Name: (.*)', result)
        # Fall back to the seed letter rather than crash when no "Name:" line was generated.
        name = match.group(1) if match else prompt[len('Name: '):]
    return name, result
|
|
|
|
# Create every artifact directory up front so later writes into them succeed.
for _output_dir in ('card_data', 'card_images', 'card_html', 'rendered_cards'):
    pathlib.Path(_output_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
def run(name):
    """Generate a complete card for ``name`` and return every artifact.

    Generates the card text and an illustration, builds the card HTML, and
    renders it to PNG. The text, each illustration and the HTML are written
    under card_data/, card_images/ and card_html/ respectively.

    Returns:
        ``(rendered, text, card_image, html)``: the rendered card image, the
        card text, the last illustration produced, and the card HTML.
    """
    started_at = time.time()
    print(f'BEGINNING RUN FOR {name}')

    name, text = gen_card_text(name)
    text_file = get_savename('card_data', name, 'txt')
    pathlib.Path(f'card_data/{text_file}').write_text(text, encoding='utf-8')

    # The card's "Type:" line feeds the illustration prompt.
    card_type = re.compile('Type: (.*)').findall(text)[0]
    prompt_template = f"fantasy illustration of a {card_type} {name}, by Greg Rutkowski"
    print(f"GENERATING IMAGE FOR {prompt_template}")

    generated = image_pipeline(prompt_template, width=512, height=368, num_inference_steps=20).images
    card_image = None
    for card_image in generated:
        image_file = get_savename('card_images', name, 'png')
        card_image.save(f"card_images/{image_file}")

    html = format_html(text, pil_to_base64(card_image))
    html_file = get_savename('card_html', name, 'html')
    pathlib.Path(f'card_html/{html_file}').write_text(html, encoding='utf-8')
    rendered = html_to_png(name, html)

    print(f'RUN COMPLETED IN {int(time.time() - started_at)} seconds')
    return rendered, text, card_image, html
|
|
|
|
def pil_to_base64(image):
    """Encode ``image`` as PNG and return the base64 payload as ``bytes``."""
    print('CONVERTING PIL IMAGE TO BASE64 STRING')
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    print('CONVERTING PIL IMAGE TO BASE64 STRING COMPLETE')
    return encoded
|
|
|
|
# Inline style of symbols embedded in rules text; replaced by the smaller
# variant when the text box switches to the 16px font.
_TEXT_SYMBOL_STYLE = 'style="top:0px;float:none;height: 18px;width: 18px;font-size: 13px;"'
_SMALL_TEXT_SYMBOL_STYLE = 'style="top:0px;float:none;height: 16px;width: 16px;font-size: 11px;"'

# (marker searched for in the card text, background colour), checked in order;
# the first marker found wins.
_CARD_COLORS = (
    ("['U']", '#5a73ab'),
    ("['W']", '#f0e3d0'),
    ("['G']", '#325433'),
    ("['B']", '#1a1b1e'),
    ("['R']", '#c2401c'),
    ("Type: Land", '#aa8c71'),
    ("Type: Artifact", '#9ba7bc'),
)
_DEFAULT_CARD_COLOR = '#edd99d'


def _field(text, label, default=''):
    """Return the value of the first ``<label>: value`` line in ``text``, or ``default``."""
    match = re.search(f'{label}: (.*)', text)
    return match.group(1) if match else default


def _symbol_class(symbol):
    """Map a symbol body such as ``W``, ``10`` or ``W/U`` to a class suffix (``w``, ``10``, ``wu``)."""
    return symbol.lower().replace('/', '')


def _text_symbol(css_class):
    """Return the inline ``<i>`` icon used for a symbol inside rules text."""
    return f'<i class="ms ms-{css_class} ms-cost" {_TEXT_SYMBOL_STYLE}></i>'


def _format_mana_cost(mana_cost):
    """Return the icon markup for a mana cost such as ``{2}{W}{W}``."""
    if mana_cost == "None":
        # Placeholder icon, hidden via visibility.
        return '<i class="ms ms-cost" style="visibility: hidden"></i>'
    # Take whole {...} groups so multi-character symbols such as {10} or {W/U}
    # stay intact; fall back to single characters when the cost has no braces.
    symbols = [s for s in re.findall(r'\{([^}]*)\}', mana_cost) if s] or [
        c for c in mana_cost if c not in '{}']
    icons = [f'<i class="ms ms-{_symbol_class(s)} ms-cost ms-shadow"></i>' for s in symbols]
    # Emitted in reverse order (presumably because the template floats them right).
    return "\n".join(reversed(icons))


def _format_text_line(line):
    """Render one line of rules text as an HTML paragraph with symbol icons."""
    # Escape first: the text is model output seeded by user input and the HTML
    # is rendered by wkhtmltoimage/html2image.
    line = escape(line)
    line = line.replace('{T}', _text_symbol('tap'))
    line = line.replace('{UT}', _text_symbol('untap'))
    line = line.replace('{E}', _text_symbol('instant'))
    # Remaining {X} symbols: lower-case the class and drop hybrid slashes
    # ({W/U} -> ms-wu) in a single pass, so no tag is nested inside another.
    line = re.sub(r'\{(.*?)\}', lambda m: _text_symbol(_symbol_class(m.group(1))), line)
    # Italicise parenthesised text.
    line = line.replace('(', '(<i>').replace(')', '</i>)')
    return f"<p>{line}</p>"


def _format_power_toughness(text):
    """Return the power/toughness box, or ``""`` when the card has no power value."""
    power = _field(text, 'Power').strip()
    # "None" mirrors the ManaCost convention; assumed to mark non-creatures (TODO confirm).
    if not power or power == 'None':
        return ""
    toughness = _field(text, 'Toughness').strip()
    return ('<header class="powerToughness"><div><h2 style="font-family: \'Beleren\';'
            f'font-size: 19px;">{escape(power)}/{escape(toughness)}</h2></div></header>')


def format_html(text, image_data):
    """Fill the card HTML template with the fields parsed from ``text``.

    Args:
        text: Generated card text made of ``Label: value`` lines (Name,
            ManaCost, Type, Rarity, Text, FlavorText, Power, Toughness).
            Missing fields render as empty rather than raising.
        image_data: Base64-encoded PNG as ``bytes``/``bytearray``, which is
            wrapped in a ``data:image/png;base64,`` URI. Any other value is
            inserted verbatim.

    Returns:
        The filled-in HTML. As a side effect it is also written to ``test.html``.
    """
    template = pathlib.Path("colab-data-test/card_template.html").read_text(encoding='utf-8')

    color = next((c for marker, c in _CARD_COLORS if marker in text), _DEFAULT_CARD_COLOR)
    template = template.replace("{card_color}", f'style="background-color:{color}"')

    template = template.replace("{name}", escape(_field(text, 'Name')))
    template = template.replace("{mana_cost}", _format_mana_cost(_field(text, 'ManaCost', 'None')))

    if isinstance(image_data, (bytes, bytearray)):
        image_src = f'data:image/png;base64,{image_data.decode("utf-8")}'
    else:
        image_src = f'{image_data}'
    template = template.replace('{image_data}', image_src)

    card_type = _field(text, 'Type')
    template = template.replace("{card_type}", escape(card_type))
    # Long type lines get a smaller font.
    template = template.replace("{type_size}", "16" if len(card_type) > 30 else "18")

    template = template.replace("{rarity}", f"ss-{escape(_field(text, 'Rarity'))}")

    text_match = re.search('Text: (.*)\nFlavorText', text, re.MULTILINE | re.DOTALL)
    card_text = text_match.group(1) if text_match else ''
    text_lines = [_format_text_line(line) for line in card_text.splitlines()]
    template = template.replace("{card_text}", "\n".join(text_lines))

    flavor_match = re.search('FlavorText: (.*)\nPower', text, re.MULTILINE | re.DOTALL)
    flavor_text = flavor_match.group(1) if flavor_match else ''
    if flavor_match:
        flavor_lines = [f"<p>{escape(line)}</p>" for line in flavor_text.splitlines()]
        template = template.replace("{flavor_text}", "<blockquote>" + "\n".join(flavor_lines) + "</blockquote>")
    else:
        template = template.replace("{flavor_text}", "")

    # Long or many-line text: shrink both the text font and the inline symbols.
    if len(card_text) + len(flavor_text) > 170 or len(text_lines) > 3:
        template = template.replace("{text_size}", '16')
        template = template.replace(f'ms-cost" {_TEXT_SYMBOL_STYLE}></i>',
                                    f'ms-cost" {_SMALL_TEXT_SYMBOL_STYLE}></i>')
    else:
        template = template.replace("{text_size}", '18')

    template = template.replace("{power_toughness}", _format_power_toughness(text))
    pathlib.Path("test.html").write_text(template, encoding='utf-8')
    return template
|
|
|
|
def get_savename(directory, name, extension):
    """Return a file name for ``name`` that does not exist yet in ``directory``.

    Tries ``"<name>.<extension>"``, then ``"<name>-1.<extension>"``,
    ``"<name>-2.<extension>"`` and so on. Path separators (and NUL) in ``name``
    are replaced with ``_`` so the result always stays inside ``directory``;
    ``name`` can be the raw name a user typed into the app.

    Args:
        directory: Directory the file will be written to.
        name: Base name, e.g. the card name.
        extension: File extension without the leading dot.

    Returns:
        The bare file name (not joined with ``directory``).
    """
    # Prevent names like "../x" or "Fire/Ice" from escaping or breaking the path.
    safe_name = re.sub(r'[\\/\x00]', '_', name)
    save_name = f"{safe_name}.{extension}"
    suffix = 1
    while os.path.exists(os.path.join(directory, save_name)):
        # Always rebuild from the base name; splitting the previous candidate on
        # '-' truncated names that contain a hyphen ("Fire-Bolt" -> "Fire-1").
        save_name = f"{safe_name}-{suffix}.{extension}"
        suffix += 1
    return save_name
|
|
|
|
def html_to_png(card_name, html):
    """Render card ``html`` to a PNG in ``rendered_cards/`` and return it cropped.

    Tries imgkit first and falls back to html2image. The saved file is then
    overwritten with the image cropped to the (0, 50, 400, 600) box.

    Args:
        card_name: Used to pick a free file name in ``rendered_cards/``.
        html: Complete card HTML.

    Returns:
        The cropped card converted to RGB.

    Raises:
        RuntimeError: If both renderers failed and no image file exists.
    """
    save_name = get_savename('rendered_cards', card_name, 'png')
    print('CONVERTING HTML CARD TO PNG IMAGE')
    path = os.path.join('rendered_cards', save_name)
    try:
        css = ['./colab-data-test/css/mana.css', './colab-data-test/css/keyrune.css',
               './colab-data-test/css/mtg_custom.css']
        imgkit.from_string(html, path, {"xvfb": ""}, css=css)
    except Exception as imgkit_error:
        print(f'imgkit rendering failed ({imgkit_error!r}); falling back to html2image')
        try:
            from html2image import Html2Image  # optional fallback renderer
            hti = Html2Image(output_path='rendered_cards')
            paths = hti.screenshot(html_str=html,
                                   css_file=['./colab-data-test/css/mtg_custom.css',
                                             './colab-data-test/css/mana.css',
                                             './colab-data-test/css/keyrune.css'],
                                   save_as=save_name, size=(450, 600))
            print(paths)
            path = paths[0]
        except Exception as fallback_error:
            print(f'html2image rendering failed ({fallback_error!r})')
            # imgkit may have raised even though it wrote an image, so only give
            # up when no file was produced.
            if not os.path.exists(path):
                raise RuntimeError(f'could not render card HTML to {path}') from fallback_error
    print('OPENING IMAGE FROM FILE')
    # Close the source file before overwriting the same path below.
    with Image.open(path) as img:
        print('CROPPING BACKGROUND')
        # crop() loads the pixels and always returns a 400x550 image for this
        # box, so the former (discarded) resize((400, 550)) is unnecessary.
        cropped_img = img.crop((0, 50, 400, 600))
    cropped_img.save(path)
    print('CONVERTING HTML CARD TO PNG IMAGE COMPLETE')
    return cropped_img.convert('RGB')
|
|
|
|
# Markdown shown above the Gradio form.
app_description = (
    """
# Create your own Magic: The Gathering cards!
Enter a name, click Submit, it may take up to 10 minutes to run on the free CPU only instance.
""").strip()

# Components; the outputs start out showing the example card stored in examples/.
input_box = gr.Textbox(label="Enter a card name", placeholder="Firebolt")
rendered_card = gr.Image(label="Card", type='pil', value="examples/card.png")
output_text_box = gr.Textbox(label="Card Text", value=pathlib.Path("examples/text.txt").read_text('utf-8'))
output_card_image = gr.Image(label="Card Image", type='pil', value="examples/image.png")
# Hidden component that receives the generated card HTML.
output_card_html = gr.HTML(label="Card HTML", visible=False, show_label=False)

# The output order must match run()'s return tuple: (rendered, text, card_image, html).
iface = gr.Interface(title="MagicGen", theme="default", description=app_description, fn=run,
                     inputs=[input_box],
                     outputs=[rendered_card, output_text_box, output_card_image, output_card_html])

iface.launch()
|
|