aniketp2009gmail's picture
Upload folder using huggingface_hub
199800a verified
"""Model loader for the Enterprise AI Assistant."""
from typing import Any, Dict
from langchain_groq import ChatGroq
from logger.logging import get_logger
from utils.config_loader import ConfigLoader
# Module-level logger named after this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class ModelLoader:
    """Load and configure the chat language model used by the assistant.

    Settings are read through ``ConfigLoader``. The constructed model is cached
    on ``self.llm`` so repeated ``load_llm()`` calls reuse the same instance.
    Only the ``"groq"`` provider is currently supported.
    """

    def __init__(self, model_provider: str = "groq"):
        """Create a loader for ``model_provider``; no model is built yet.

        Args:
            model_provider: Provider name. It is matched case-insensitively and
                surrounding whitespace is ignored.

        Raises:
            Exception: If configuration loading or argument handling fails.
        """
        try:
            self.config = ConfigLoader()
            # Normalise so that " Groq " and "groq" select the same provider.
            self.model_provider = model_provider.strip().lower()
            self.llm = None  # populated lazily by load_llm()
            logger.info("ModelLoader initialized")
        except Exception as e:
            error_msg = f"Error in ModelLoader Initialization -> {str(e)}"
            raise Exception(error_msg) from e

    def load_llm(self) -> Any:
        """Return the configured language model, building it on first use.

        Returns:
            The cached model if one was already loaded. Otherwise a newly
            constructed model, which is then cached on ``self.llm``.

        Raises:
            Exception: If the provider is unsupported or model construction fails.
                The message is prefixed with ``"Error in load_llm -> "``.
        """
        try:
            if self.llm is not None:
                logger.info(f"Returning cached {self.model_provider} model")
                return self.llm

            if self.model_provider != "groq":
                error_msg = f"Unsupported model provider: {self.model_provider}"
                raise Exception(error_msg)

            self.llm = self._load_groq_model()
            logger.info(f"Successfully loaded {self.model_provider} model")
            return self.llm
        except Exception as e:
            error_msg = f"Error in load_llm -> {str(e)}"
            # Chain explicitly so the original failure is reported as the cause.
            raise Exception(error_msg) from e

    def _read_numeric_env(self, name: str, default, cast, is_valid):
        """Read setting ``name`` via ``config.get_env`` and convert it with ``cast``.

        Falls back to ``default`` and logs a warning when the value cannot be
        converted, or when the converted value fails the ``is_valid`` predicate.

        Args:
            name: Setting name passed to ``config.get_env``.
            default: Value returned on failure. ``str(default)`` is also passed
                to ``get_env`` as its fallback.
            cast: Conversion callable, e.g. ``float`` or ``int``.
            is_valid: Predicate applied to the converted value.
        """
        raw = self.config.get_env(name, str(default))
        try:
            value = cast(raw)
        except (ValueError, TypeError):
            logger.warning(f"Invalid {name}, using default {default}")
            return default
        if not is_valid(value):
            logger.warning(f"Invalid {name}, using default {default}")
            return default
        return value

    def _load_groq_model(self) -> ChatGroq:
        """Build a ``ChatGroq`` client from configuration.

        Reads the API key via ``config.get_api_key("groq")``. Also reads the
        optional ``MODEL_NAME``, ``MODEL_TEMPERATURE`` and ``MODEL_MAX_TOKENS``
        settings via ``config.get_env``. Missing, unparsable or out-of-range
        values fall back to defaults.

        Raises:
            Exception: If the API key is missing or blank, or client construction
                fails. The message is prefixed with ``"Error in _load_groq_model -> "``.
        """
        try:
            api_key = self.config.get_api_key("groq")
            # A whitespace-only key is as useless as a missing one.
            if not api_key or not str(api_key).strip():
                error_msg = "Groq API key not found. Please set GROQ_API_KEY environment variable"
                logger.error(error_msg)
                raise Exception(error_msg)

            # `or` also covers a setting that is present but empty.
            model_name = (
                self.config.get_env("MODEL_NAME", "llama-3.1-8b-instant")
                or "llama-3.1-8b-instant"
            )
            # NaN compares False against everything, so the chained comparison
            # rejects NaN, +/-inf and negative temperatures in one check.
            temperature = self._read_numeric_env(
                "MODEL_TEMPERATURE", 0.1, float, lambda t: 0.0 <= t < float("inf")
            )
            max_tokens = self._read_numeric_env(
                "MODEL_MAX_TOKENS", 4096, int, lambda n: n > 0
            )

            logger.info(f"Loading Groq model: {model_name}")
            return ChatGroq(
                groq_api_key=api_key,
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except Exception as e:
            error_msg = f"Error in _load_groq_model -> {str(e)}"
            raise Exception(error_msg) from e

    def get_model_info(self) -> Dict[str, Any]:
        """Describe the loader's provider and, if loaded, the model's settings.

        Returns:
            ``{"provider", "loaded": False}`` before a model is loaded.
            Afterwards, the provider plus ``model_name``, ``temperature`` and
            ``max_tokens`` read from the model object (``"unknown"`` when an
            attribute is absent). This method never raises: on error it logs
            and returns a dict with ``"loaded": False`` and an ``"error"`` entry.
        """
        try:
            if self.llm is None:
                return {"provider": self.model_provider, "loaded": False}
            return {
                "provider": self.model_provider,
                "loaded": True,
                "model_name": getattr(self.llm, "model_name", "unknown"),
                "temperature": getattr(self.llm, "temperature", "unknown"),
                "max_tokens": getattr(self.llm, "max_tokens", "unknown"),
            }
        except Exception as e:
            logger.error(f"Error in get_model_info -> {str(e)}")
            return {"provider": self.model_provider, "loaded": False, "error": str(e)}