k-gpt-v1 / handler.py
Univars's picture
Upload handler.py with huggingface_hub
0948eaf verified
"""
Custom handler for HuggingFace Inference Endpoints.
Loads Qwen 2.5 7B base model (4-bit quantized) + K-GPT v1 LoRA adapter.
Uses 4-bit NF4 quantization so the model fits on a T4 GPU (16GB VRAM).
"""
from typing import Dict, List, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
class EndpointHandler:
    """Inference Endpoints handler: 4-bit Qwen 2.5 7B base + K-GPT v1 LoRA adapter.

    The base model is loaded with NF4 4-bit quantization and the LoRA adapter
    stored in the repository directory is applied on top of it.  Instances
    are callable: each call takes a request payload and returns the generated
    continuation.
    """

    # Hub id of the base model the LoRA adapter is applied to.
    BASE_MODEL_ID = "Qwen/Qwen2.5-7B"

    def __init__(self, path: str = ""):
        """Load the tokenizer, the quantized base model and the LoRA adapter.

        Args:
            path: Local directory where the HF repo was cloned.  Both the
                tokenizer and the adapter weights are read from it.
        """
        # Load tokenizer from the adapter repo
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        if self.tokenizer.pad_token is None:
            # Fall back to EOS so a pad token is always defined.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 4-bit NF4 quantization config (matches training QLoRA config)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )

        # Load base model in 4-bit (~4GB VRAM instead of 14GB)
        # NOTE(review): trust_remote_code=True allows code shipped in the
        # model repo to run at load time; confirm it is actually required
        # for this base model before keeping it.
        base_model = AutoModelForCausalLM.from_pretrained(
            self.BASE_MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        # Load LoRA adapter on top of quantized base
        self.model = PeftModel.from_pretrained(base_model, path)
        self.model.eval()

    @staticmethod
    def _param(parameters: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``parameters[key]``, or ``default`` if it is missing or None."""
        value = parameters.get(key)
        return default if value is None else value

    @staticmethod
    def _format_chat(messages: List[Dict[str, Any]]) -> str:
        """Flatten chat messages into ``ROLE: content`` lines plus an assistant cue."""
        turns = [
            # A missing or null role is treated as "user".
            f"{str(msg.get('role') or 'user').upper()}: {msg.get('content', '')}\n"
            for msg in messages
        ]
        return "".join(turns) + "ASSISTANT: "

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: Request payload.  ``data["inputs"]`` is either a prompt
                string or a list of chat messages (dicts with ``role`` and
                ``content``).  ``data["parameters"]`` may carry
                ``max_new_tokens`` (default 256), ``temperature`` (default
                0.7), ``top_p`` (default 0.9) and ``do_sample`` (default:
                sample whenever ``temperature > 0``).  Missing or null
                entries fall back to these defaults.

        Returns:
            A one-element list ``[{"generated_text": text}]`` containing only
            the newly generated text, without the prompt.

        Raises:
            ValueError: If the prompt tokenizes to zero tokens.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = ""
        # An explicit JSON null for "parameters" must not break the lookups below.
        parameters = data.get("parameters") or {}

        # Handle chat-style messages
        if isinstance(inputs, list):
            inputs = self._format_chat(inputs)

        max_new_tokens = self._param(parameters, "max_new_tokens", 256)
        temperature = self._param(parameters, "temperature", 0.7)
        top_p = self._param(parameters, "top_p", 0.9)
        # Sampling needs a strictly positive temperature; a temperature of 0
        # (or below) is treated as a request for greedy decoding even when
        # do_sample was asked for explicitly.
        do_sample = bool(self._param(parameters, "do_sample", True)) and temperature > 0

        # Tokenize
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        prompt_length = encoded["input_ids"].shape[1]
        if prompt_length == 0:
            raise ValueError("'inputs' produced no tokens; provide a non-empty prompt")

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                # Sampling knobs are passed as None when decoding greedily.
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the new tokens
        generated = outputs[0][prompt_length:]
        text = self.tokenizer.decode(generated, skip_special_tokens=True)
        return [{"generated_text": text}]