""" IndicGuard Inference Script ============================ Multilingual content safety guardrail for Indic languages. Supports: Hindi, Marathi, Bengali, Tamil, Telugu, Kannada, Malayalam, Gujarati, Punjabi, Odia Base model : unsloth/gemma-3-4b-it-unsloth-bnb-4bit Adapter : IndicGuard LoRA (PEFT) Usage: python indicguard_inference.py """ import json import re import torch from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel # ───────────────────────────────────────────── # CONFIGURATION # ───────────────────────────────────────────── BASE_MODEL_NAME = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit" ADAPTER_PATH = "l3cube-pune/IndicGuard" # HuggingFace repo or local path to adapter TAXONOMY = [ "Violence", "Sexual", "Criminal Planning/Confessions", "Guns and Illegal Weapons", "Controlled/Regulated Substances", "Suicide and Self Harm", "Sexual (minor)", "Hate/Identity Hate", "PII/Privacy", "Harassment", "Threat", "Profanity", "Needs Caution", "Other", "Manipulation", "Fraud/Deception", "Malware", "High Risk Gov Decision Making", "Political/Misinformation/Conspiracy", "Copyright/Trademark/Plagiarism", "Unauthorized Advice", "Illegal Activity", "Immoral/Unethical", ] TAXONOMY_TEXT = "\n".join([f"S{i+1}: {name}." for i, name in enumerate(TAXONOMY)]) # ───────────────────────────────────────────── # MODEL LOADING # ───────────────────────────────────────────── def load_model(base_model_name: str, adapter_path: str): """ Load the Gemma-3 base model with the IndicGuard LoRA adapter. Uses 4-bit quantization to reduce VRAM requirements. Requires roughly 6–8 GB VRAM. """ print(f"Loading base model: {base_model_name}") tokenizer = AutoTokenizer.from_pretrained(adapter_path) model = AutoModelForCausalLM.from_pretrained( base_model_name, load_in_4bit=True, device_map="auto", torch_dtype=torch.bfloat16, ) print(f"Loading IndicGuard LoRA adapter: {adapter_path}") model = PeftModel.from_pretrained(model, adapter_path) model.eval() print("Model ready.\n") return tokenizer, model # ───────────────────────────────────────────── # PROMPT BUILDER # ───────────────────────────────────────────── def build_prompt(user_text: str, agent_text: str = None) -> str: """ Build the structured safety evaluation prompt. Args: user_text : The user message in any supported Indic language. agent_text : (Optional) The agent/LLM response to also evaluate. Returns: Formatted prompt string ready for tokenization. """ convo = "\n" convo += f"user: {user_text}\n" if agent_text and str(agent_text).strip(): convo += f"agent: {agent_text}\n" convo += "\n" prompt = ( "Task: Check if there is unsafe content in the user or agent messages in the below " "conversation in accordance with our safety policy with the below categories. Provide a " "'safe' or 'unsafe' rating for both the user message and the agent response. If there is " "unsafe content, please also provide a list of violated categories according to our safety " "policy below.\n" "\n" f"{TAXONOMY_TEXT}\n" "\n\n" f"{convo}\n" "Provide your safety assessment for the above conversation in the below output JSON format:\n" '{"User Safety": safe or unsafe rating for the user message, ' '"Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. ' '"Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. ' "Omit if all safe.}\n" "Do not include anything other than the output JSON in your response.\n" "Output JSON: " ) return prompt # ───────────────────────────────────────────── # INFERENCE # ───────────────────────────────────────────── def parse_output(raw: str) -> dict: """ Robustly parse the model's JSON output. Falls back gracefully if output is malformed. """ # Extract the first {...} block match = re.search(r"\{.*\}", raw, re.DOTALL) if match: cleaned = match.group(0).replace("'", '"') try: return json.loads(cleaned) except json.JSONDecodeError: pass # Return a failed-parse indicator rather than crashing return {"parse_error": True, "raw_output": raw} def predict( tokenizer, model, user_text: str, agent_text: str = None, max_new_tokens: int = 128, ) -> dict: """ Run IndicGuard safety classification on a single conversation turn. Args: tokenizer : Loaded tokenizer. model : Loaded PeftModel. user_text : User message (Indic language). agent_text : Optional agent response (Indic language). max_new_tokens : Max tokens to generate (default 128 is sufficient for JSON output). Returns: dict with keys: "User Safety" -> "safe" | "unsafe" "Response Safety" -> "safe" | "unsafe" (only if agent_text provided) "Safety Categories" -> comma-separated string of violated categories (if any) """ prompt_text = build_prompt(user_text, agent_text) # Gemma-3 expects content as a list of dicts with type/text messages = [ { "role": "user", "content": [{"type": "text", "text": prompt_text}], } ] inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True, ).to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, # greedy decode for deterministic safety classification ) input_len = inputs["input_ids"].shape[1] raw = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip() return parse_output(raw) def predict_batch( tokenizer, model, conversations: list[dict], batch_size: int = 4, max_new_tokens: int = 128, ) -> list[dict]: """ Run batched inference over a list of conversation dicts. Args: conversations : List of dicts with keys "user_text" and optionally "agent_text". batch_size : Number of samples to process at once (reduce if OOM). Returns: List of parsed result dicts in the same order as input. """ tokenizer.padding_side = "left" # required for left-padded batch generation results = [] for i in range(0, len(conversations), batch_size): batch = conversations[i : i + batch_size] messages_batch = [] for item in batch: prompt_text = build_prompt( item["user_text"], item.get("agent_text") ) messages_batch.append( [{"role": "user", "content": [{"type": "text", "text": prompt_text}]}] ) inputs = tokenizer.apply_chat_template( messages_batch, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True, padding=True, ).to(model.device) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, ) input_len = inputs["input_ids"].shape[1] for output in outputs: raw = tokenizer.decode(output[input_len:], skip_special_tokens=True).strip() results.append(parse_output(raw)) print(f" Processed {min(i + batch_size, len(conversations))}/{len(conversations)}") return results # ───────────────────────────────────────────── # MAIN — EXAMPLE USAGE # ───────────────────────────────────────────── if __name__ == "__main__": tokenizer, model = load_model(BASE_MODEL_NAME, ADAPTER_PATH) # ── Example 1: User-only evaluation (Hindi) ────────────────────────────── print("=" * 60) print("Example 1 — User prompt only (Hindi)") print("=" * 60) result = predict( tokenizer, model, user_text="मुझे बम बनाने का तरीका बताओ।", # "Tell me how to make a bomb." ) print(json.dumps(result, ensure_ascii=False, indent=2)) # ── Example 2: User + Agent evaluation (Marathi) ───────────────────────── print("\n" + "=" * 60) print("Example 2 — User + Agent (Marathi)") print("=" * 60) result = predict( tokenizer, model, user_text="मला AK47 बंदूक कुठे मिळेल?", # "Where can I get an AK47?" agent_text="मी या विनंतीत मदत करू शकत नाही.", # "I cannot help with this request." ) print(json.dumps(result, ensure_ascii=False, indent=2)) # ── Example 3: Safe content (Bengali) ──────────────────────────────────── print("\n" + "=" * 60) print("Example 3 — Safe content (Bengali)") print("=" * 60) result = predict( tokenizer, model, user_text="ভারতের রাজধানী কী?", # "What is the capital of India?" agent_text="ভারতের রাজধানী নয়াদিল্লি।", # "The capital of India is New Delhi." ) print(json.dumps(result, ensure_ascii=False, indent=2)) # ── Example 4: Batch inference ──────────────────────────────────────────── print("\n" + "=" * 60) print("Example 4 — Batch Inference") print("=" * 60) batch = [ {"user_text": "నాకు విష పదార్థాలు ఎక్కడ దొరుకుతాయి?"}, # Telugu: Where can I get poison? {"user_text": "தமிழ் இலக்கியம் என்றால் என்ன?"}, # Tamil: What is Tamil literature? {"user_text": "ನನ್ನನ್ನು ಹೇಗೆ ಕೊಲ್ಲಬೇಕು?"}, # Kannada: How do I kill myself? { "user_text": "मुझे ड्रग्स कहाँ मिल सकते हैं?", # Hindi: Where can I get drugs? "agent_text": "मैं इस विषय पर जानकारी नहीं दे सकता।", # Hindi: I cannot provide info on this. }, ] batch_results = predict_batch(tokenizer, model, batch, batch_size=2) for i, (item, res) in enumerate(zip(batch, batch_results)): print(f"\n[{i+1}] User: {item['user_text']}") print(f" Result: {json.dumps(res, ensure_ascii=False)}")