File size: 3,506 Bytes
1bd1563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199800a
 
 
 
 
 
 
 
 
 
1bd1563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Model loader for the Enterprise AI Assistant."""

from typing import Any, Dict

from langchain_groq import ChatGroq

from logger.logging import get_logger
from utils.config_loader import ConfigLoader

# Module-wide logger obtained from the project's get_logger helper, named after this module.
logger = get_logger(__name__)


class ModelLoader:
    """Load and cache the chat language model used by the assistant.

    Only the ``"groq"`` provider is implemented. The model instance is created
    on the first successful :meth:`load_llm` call and reused afterwards.

    Failures in ``__init__``, ``load_llm`` and ``_load_groq_model`` are
    re-raised as a plain ``Exception`` whose message is prefixed with the
    failing method's name; the original error is attached as ``__cause__``
    so the root traceback is preserved.
    """

    def __init__(self, model_provider: str = "groq"):
        """Create a loader for *model_provider* (matched case-insensitively).

        Raises:
            Exception: if the configuration loader cannot be created or
                *model_provider* does not support ``.lower()``.
        """
        try:
            self.config = ConfigLoader()
            # Normalise so "Groq"/"GROQ" match the provider check in load_llm.
            self.model_provider = model_provider.lower()
            # Cached model instance; populated by load_llm().
            self.llm = None
            logger.info("ModelLoader initialized")

        except Exception as e:
            raise Exception(f"Error in ModelLoader Initialization -> {e}") from e

    def load_llm(self) -> Any:
        """Return the configured language model, creating it on first use.

        Returns:
            The cached model instance (a ``ChatGroq`` for the ``"groq"``
            provider).

        Raises:
            Exception: if the provider is unsupported or the model cannot be
                created; the message is prefixed with ``"Error in load_llm -> "``.
        """
        try:
            if self.llm is not None:
                logger.info(f"Returning cached {self.model_provider} model")
                return self.llm

            if self.model_provider == "groq":
                self.llm = self._load_groq_model()
            else:
                raise Exception(f"Unsupported model provider: {self.model_provider}")

            logger.info(f"Successfully loaded {self.model_provider} model")
            return self.llm

        except Exception as e:
            # Chain explicitly so callers/tracebacks see the real root cause.
            raise Exception(f"Error in load_llm -> {e}") from e

    def _env_number(self, name: str, default: Any, convert: Any) -> Any:
        """Read setting *name* through the config and convert it with *convert*.

        The config lookup receives ``str(default)`` as its fallback. When the
        lookup or the conversion raises ``ValueError``/``TypeError`` (e.g.
        ``"abc"`` for a float or ``"1.5"`` for an int), a warning is logged and
        *default* is returned unchanged.
        """
        try:
            return convert(self.config.get_env(name, str(default)))
        except (ValueError, TypeError):
            logger.warning(f"Invalid {name}, using default {default}")
            return default

    def _load_groq_model(self) -> "ChatGroq":
        """Build a ``ChatGroq`` client from the Groq API key and MODEL_* settings.

        Reads ``MODEL_NAME`` (default ``"llama-3.1-8b-instant"``),
        ``MODEL_TEMPERATURE`` (default 0.1) and ``MODEL_MAX_TOKENS``
        (default 4096) through the config loader; an unparsable numeric value
        falls back to its default with a warning.

        Raises:
            Exception: if no Groq API key is configured or the client cannot be
                constructed; the message is prefixed with
                ``"Error in _load_groq_model -> "``.
        """
        try:
            api_key = self.config.get_api_key("groq")
            if not api_key:
                error_msg = "Groq API key not found. Please set GROQ_API_KEY environment variable"
                logger.error(error_msg)
                raise Exception(error_msg)

            model_name = self.config.get_env("MODEL_NAME", "llama-3.1-8b-instant")
            temperature = self._env_number("MODEL_TEMPERATURE", 0.1, float)
            max_tokens = self._env_number("MODEL_MAX_TOKENS", 4096, int)

            logger.info(f"Loading Groq model: {model_name}")

            return ChatGroq(
                groq_api_key=api_key,
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
            )

        except Exception as e:
            raise Exception(f"Error in _load_groq_model -> {e}") from e

    def get_model_info(self) -> Dict[str, Any]:
        """Describe the current model without loading it.

        Returns:
            ``{"provider": ..., "loaded": False}`` until :meth:`load_llm` has
            succeeded; afterwards also ``model_name``, ``temperature`` and
            ``max_tokens`` read from the model (``"unknown"`` when an attribute
            is absent). On an unexpected error the error is logged and
            ``loaded: False`` plus an ``"error"`` string is returned instead of
            raising.
        """
        try:
            if self.llm is None:
                return {"provider": self.model_provider, "loaded": False}

            return {
                "provider": self.model_provider,
                "loaded": True,
                "model_name": getattr(self.llm, "model_name", "unknown"),
                "temperature": getattr(self.llm, "temperature", "unknown"),
                "max_tokens": getattr(self.llm, "max_tokens", "unknown"),
            }

        except Exception as e:
            logger.error(f"Error in get_model_info -> {str(e)}")
            return {"provider": self.model_provider, "loaded": False, "error": str(e)}