MTP7 / app.py
teszenofficial's picture
Create app.py
df4b494 verified
import os
import sys
import pickle
import torch
import gradio as gr
from huggingface_hub import snapshot_download
# ======================
# CONFIGURACIÓN REPO HF
# ======================
REPO_ID = "teszenofficial/MTP7"
MODEL_FILE = "mtp_mini.pkl" # Asegúrate de que se llame así en tu repo
TOKENIZER_FILE = "mtp_tokenizer.model" # Asegúrate de que se llame así en tu repo
LOCAL_DIR = "mtptz_repo" # Nombre de la carpeta local donde se descarga
# ======================
# DESCARGA Y CARGA DEL MODELO
# ======================
def load_resources():
print(f"📦 Descargando modelo desde {REPO_ID}...")
# 1. Descargar el repositorio a una carpeta local
repo_path = snapshot_download(
repo_id=REPO_ID,
local_dir=LOCAL_DIR
)
print(f"✅ Modelo descargado en: {repo_path}")
# 2. Añadir la ruta al sys.path para poder importar model.py y tokenizer.py desde el repo
sys.path.insert(0, repo_path)
try:
# Intentamos importar las clases desde los archivos descargados en el repo
from model import MTPMiniModel
from tokenizer import MTPTokenizer
except ImportError as e:
print(f"❌ ERROR: No se pudieron importar 'model' o 'tokenizer'.")
print(f" Asegúrate de que subiste 'model.py' y 'tokenizer.py' al repo '{REPO_ID}'.")
raise e
# 3. Definir rutas completas
model_path = os.path.join(repo_path, MODEL_FILE)
tokenizer_path = os.path.join(repo_path, TOKENIZER_FILE)
# Verificar si existen
if not os.path.exists(model_path):
raise FileNotFoundError(f"No se encontró {MODEL_FILE} en el repo.")
if not os.path.exists(tokenizer_path):
raise FileNotFoundError(f"No se encontró {TOKENIZER_FILE} en el repo.")
# 4. Cargar Tokenizer
tokenizer = MTPTokenizer(tokenizer_path)
print(f"✅ Tokenizer cargado. Vocab size: {tokenizer.vocab_size()}")
# 5. Cargar Modelo
print(f"🧠 Cargando tensores...")
with open(model_path, 'rb') as f:
model_data = pickle.load(f)
config = model_data['config']
state_dict = model_data['model_state_dict']
vocab_size = model_data['vocab_size']
# Reconstruir el Modelo
use_swiglu = config['model'].get('use_swiglu', False)
model = MTPMiniModel(
vocab_size=vocab_size,
d_model=config['model']['d_model'],
n_layers=config['model']['n_layers'],
n_heads=config['model']['n_heads'],
d_ff=config['model']['d_ff'],
max_seq_len=config['model']['max_seq_len'],
dropout=0.0,
use_swiglu=use_swiglu
)
model.load_state_dict(state_dict)
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(DEVICE)
print(f"✅ Modelo cargado en {DEVICE}")
return model, tokenizer, DEVICE
# Cargar al inicio
model, tokenizer, DEVICE = load_resources()
# ======================
# FUNCIÓN DE GENERACIÓN
# ======================
def generate_response(message, history, temperature, max_tokens, top_p):
# Construir el prompt
# Formato: ### Instrucción:\n{input}\n\n### Respuesta:\n
prompt = f"### Instrucción:\n{message}\n\n### Respuesta:\n"
# Tokenizar
tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
input_ids = torch.tensor([tokens], device=DEVICE)
# Generar usando el método del modelo
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=int(max_tokens),
temperature=float(temperature),
top_k=40,
top_p=float(top_p),
repetition_penalty=1.15,
min_length=10,
eos_token_id=tokenizer.eos_id()
)
# Decodificar
gen_tokens = output_ids[0, len(tokens):].tolist()
safe_tokens = []
for t in gen_tokens:
if 0 <= t < tokenizer.vocab_size() and t != tokenizer.eos_id():
safe_tokens.append(t)
elif t == tokenizer.eos_id():
break
response = tokenizer.decode(safe_tokens).strip()
# Limpieza básica
if "### Instrucción:" in response:
response = response.split("### Instrucción:")[0].strip()
return response
# ======================
# INTERFAZ GRADIO
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤖 MTP-7 Chat (Demo)")
gr.Markdown(f"Modelo cargado desde `teszenofficial/MTP7` en **{DEVICE}**.")
chat_interface = gr.ChatInterface(
fn=generate_response,
additional_inputs=[
gr.Slider(0.1, 2.0, value=0.7, label="Temperatura (Creatividad)"),
gr.Slider(50, 300, value=150, label="Máximos Tokens"),
gr.Slider(0.1, 1.0, value=0.92, label="Top-p (Nucleus)"),
],
examples=[
["¿Cuál es la capital de Francia?", 0.7, 150, 0.92],
["Explica qué es la relatividad.", 0.7, 150, 0.92]
],
cache_examples=False
)
if __name__ == "__main__":
demo.launch()