How to use ajay-drew/Mistral-7B-Indian-Law with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
model = PeftModel.from_pretrained(base_model, "ajay-drew/Mistral-7B-Indian-Law")
A fine-tuned version of the Mistral 7B model, optimized for understanding and generating responses related to Indian law. It was trained with Parameter-Efficient Fine-Tuning (PEFT) using the LoRA and QLoRA techniques.
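To give a rough sense of why LoRA counts as parameter-efficient, the sketch below compares the number of weights in one full projection matrix with the number in a low-rank adapter for that matrix. The helper name `lora_param_counts` and the rank of 16 are illustrative assumptions; this card does not state the rank used in training. The 4096 dimension is the hidden size of Mistral 7B.

```python
def lora_param_counts(d_in, d_out, rank):
    """Return (full, adapter) weight counts for a single linear layer.

    LoRA freezes the d_out x d_in weight matrix and trains two small
    matrices instead: A (rank x d_in) and B (d_out x rank).
    """
    full = d_in * d_out
    adapter = rank * (d_in + d_out)
    return full, adapter


# Hypothetical setting: a 4096 x 4096 projection with rank 16 (assumed).
full, adapter = lora_param_counts(4096, 4096, 16)
print(f"full: {full:,}  adapter: {adapter:,}  share: {adapter / full:.2%}")
```

QLoRA applies the same adapter idea while keeping the frozen base weights in 4-bit precision, which lowers memory use further.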
Use the code below to get started with the model.
Install the dependencies with pip install transformers peft torch, and use a CUDA-enabled build of torch.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
model_name = "ajay-drew/Mistral-7B-Indian-Law"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
# Load fine-tuned weights with PEFT
model = PeftModel.from_pretrained(base_model, model_name)
text = "What is the penalty for using a forged document?"  # Ask your own questions on Indian law
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs, max_length=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
To check the model's perplexity, run pip install transformers datasets torch and then use the code below. Use a CUDA-enabled build of torch to keep the check fast. The 4-bit loading in this script also needs the bitsandbytes and accelerate packages.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import torch
dataset = load_dataset("kshitij230/Indian-Law", split="train")
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
model_name = "ajay-drew/Mistral-7B-Indian-Law"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
total_loss = 0
total_tokens = 0
test_texts = dataset['Instruction'][:500]
with torch.no_grad():
    for text in test_texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        if loss is not None:  # Ensure loss is valid
            total_loss += loss.item() * inputs["input_ids"].size(1)
            total_tokens += inputs["input_ids"].size(1)
if total_tokens > 0:
    perplexity = torch.exp(torch.tensor(total_loss / total_tokens)).item()
    print(f"Perplexity: {perplexity}")
    print(f"Total tokens: {total_tokens}")
    print(f"Total loss: {total_loss}")
else:
    print("Error: No tokens processed. Check dataset or tokenization.")
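The script above weights each text's mean loss by its token count, divides by the total token count, and exponentiates the result. A minimal sketch of that arithmetic on its own, with no model needed, might look like this. The function name `token_weighted_perplexity` and the loss values are made up for illustration and are not results from this model.

```python
import math


def token_weighted_perplexity(mean_losses, token_counts):
    """Combine per-text mean losses into one corpus-level perplexity.

    Each mean loss is weighted by the number of tokens in its text,
    mirroring how the evaluation script accumulates total_loss.
    """
    total_loss = sum(loss * n for loss, n in zip(mean_losses, token_counts))
    total_tokens = sum(token_counts)
    if total_tokens == 0:
        raise ValueError("No tokens processed.")
    return math.exp(total_loss / total_tokens)


# Hypothetical inputs: two texts with mean losses 2.0 and 1.0,
# containing 10 and 30 tokens respectively.
print(token_weighted_perplexity([2.0, 1.0], [10, 30]))
```

Weighting by token count keeps long texts from being under-represented, which a plain average of per-text losses would not do.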
Mistral 7B (a transformer-based language model with 7 billion parameters)
Base model
mistralai/Mistral-7B-v0.1