DeepConf / example_simple_generations.py
kashif's picture
kashif HF Staff
example scripts
c1cd11a
Raw
History Blame Contribute Delete
5.52 kB
"""
Simple examples showing DeepConf sample generations
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# In-process cache of loaded (model, tokenizer) pairs keyed by checkpoint name,
# so repeated calls in one run do not reload the weights from disk every time.
_MODEL_CACHE = {}


def _load_model_and_tokenizer(model_name: str):
    """Return the (model, tokenizer) pair for ``model_name``, loading it at most once per process."""
    if model_name not in _MODEL_CACHE:
        # local_files_only=True: only files already present locally are used.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto", local_files_only=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
        _MODEL_CACHE[model_name] = (model, tokenizer)
    return _MODEL_CACHE[model_name]


def generate_with_deepconf(
    question: str,
    enable_early_stopping: bool = True,
    threshold: float = 10.0,
    window_size: int = 10,
    max_tokens: int = 128,
) -> dict:
    """Generate an answer to ``question`` with the DeepConf custom generate loop.

    The question is wrapped as a single user turn in the tokenizer's chat
    template, sampled with temperature 0.7 / top-p 0.95, and generated through
    ``custom_generate="kashif/DeepConf"``.

    Args:
        question: User message to send to the model.
        enable_early_stopping: Forwarded to the generation config as
            ``enable_early_stopping``.
        threshold: Forwarded to the generation config as ``threshold``.
        window_size: Forwarded to the generation config as ``window_size``.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).

    Note:
        The exact meaning of ``enable_early_stopping``, ``threshold`` and
        ``window_size`` is defined by the DeepConf generate implementation,
        not by this function -- see its documentation.

    Returns:
        A dict with the keys:
            ``text``: decoded first returned sequence with special tokens
                skipped (the whole sequence is decoded, not only the new part).
            ``tokens``: returned sequence length minus the prompt length.
            ``min_conf`` / ``max_conf`` / ``mean_conf``: min, max and mean of
                ``outputs.confidences`` as Python floats, or ``None`` for all
                three when the output carries no confidences.
    """
    # Load model (kept in memory after the first call)
    model, tokenizer = _load_model_and_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")

    # Prepare prompt
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Configure generation
    gen_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=max_tokens,
        enable_conf=True,
        enable_early_stopping=enable_early_stopping,
        threshold=threshold,
        window_size=window_size,
        output_confidences=True,
        return_dict_in_generate=True,
        # Reuse the EOS id as the padding id.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Generate
    outputs = model.generate(
        **inputs,
        generation_config=gen_config,
        custom_generate="kashif/DeepConf",
        trust_remote_code=True,
    )

    # Extract results
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    tokens_generated = outputs.sequences.shape[1] - inputs.input_ids.shape[1]

    # The attribute may be absent or None; report "no stats" in either case.
    confidences = getattr(outputs, "confidences", None)
    if confidences is not None:
        min_conf = confidences.min().item()
        max_conf = confidences.max().item()
        mean_conf = confidences.mean().item()
    else:
        min_conf = max_conf = mean_conf = None

    return {
        "text": generated_text,
        "tokens": tokens_generated,
        "min_conf": min_conf,
        "max_conf": max_conf,
        "mean_conf": mean_conf,
    }
def print_result(title: str, question: str, result: dict):
    """Print one generation result as a framed, human-readable report.

    Args:
        title: Heading shown between two ``=`` rules.
        question: The prompt that was asked.
        result: Mapping with the keys ``text``, ``tokens``, ``min_conf``,
            ``max_conf`` and ``mean_conf``. The confidence section is printed
            only when ``min_conf`` is not ``None``.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Header
    print("\n" + heavy_rule)
    print(title)
    print(heavy_rule)
    print("Question: " + f"{question}")

    # Generated text, framed by light rules
    token_count = result["tokens"]
    print(f"\nGenerated ({token_count} tokens):")
    print(light_rule)
    print(result["text"])
    print(light_rule)

    # Confidence stats are optional; nothing more to print without them.
    if result["min_conf"] is None:
        return

    print("\nConfidence stats:")
    lowest = result["min_conf"]
    print(f" Min: {lowest:.3f}")
    highest = result["max_conf"]
    print(f" Max: {highest:.3f}")
    average = result["mean_conf"]
    print(f" Mean: {average:.3f}")
if __name__ == "__main__":
    # U+2588 FULL BLOCK, written as an escape so the banner survives encoding
    # round-trips (the previous literal was the mis-decoded form of this char).
    banner = "\u2588" * 80

    # One entry per example: (title, question, keyword arguments for
    # generate_with_deepconf). Keeping the question in one place guarantees the
    # printed question is the one that was actually asked.
    examples = [
        (
            "Example 1: Math (Aggressive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 2: Math (Permissive Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": True, "threshold": 15.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 3: Math (No Early Stopping)",
            "What is 25 * 4?",
            {"enable_early_stopping": False, "max_tokens": 64},
        ),
        (
            "Example 4: Word Problem",
            "If 5 apples cost $10, how much do 3 apples cost?",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 5: Factual Question",
            "Who wrote Romeo and Juliet?",
            {"enable_early_stopping": True, "threshold": 6.0, "window_size": 5, "max_tokens": 64},
        ),
        (
            "Example 6: Calculation",
            "Calculate: (15 + 8) × 2",
            {"enable_early_stopping": True, "threshold": 7.0, "window_size": 5, "max_tokens": 96},
        ),
        (
            "Example 7: Definition",
            "Define photosynthesis in simple terms.",
            {"enable_early_stopping": True, "threshold": 10.0, "window_size": 10, "max_tokens": 128},
        ),
        (
            "Example 8: Step-by-step Solution",
            "Solve: x + 5 = 12. Show your steps.",
            {"enable_early_stopping": True, "threshold": 8.0, "window_size": 5, "max_tokens": 96},
        ),
    ]

    print("\n" + banner)
    print("DEEPCONF SAMPLE GENERATIONS")
    print(banner)

    # Run the examples in order, printing each report as soon as it is ready.
    for title, question, options in examples:
        result = generate_with_deepconf(question, **options)
        print_result(title, question, result)

    print(f"\n{banner}")
    print("ALL EXAMPLES COMPLETE")
    print(banner)

    # NOTE(review): the observations below are printed verbatim from the
    # original script; confirm the stated threshold direction against the
    # DeepConf documentation.
    print("\nKey observations:")
    print("- Lower threshold → Earlier stopping (fewer tokens)")
    print("- Higher threshold → Later stopping (more tokens)")
    print("- No early stopping → Always generates max_tokens")
    print("- Confidence varies based on model certainty")