Instructions to use aardvark-labs/stp-classifier-4-1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aardvark-labs/stp-classifier-4-1 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-classification", model="aardvark-labs/stp-classifier-4-1")

# Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("aardvark-labs/stp-classifier-4-1")
model = AutoModelForSequenceClassification.from_pretrained("aardvark-labs/stp-classifier-4-1")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import json | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig | |
# Model / tokenizer loading
def model_fn(model_dir):
    """Load the classifier, its tokenizer and an optional label map.

    Entry point SageMaker calls to load the model.

    Args:
        model_dir (str): Directory holding the model artifacts. ``config.json``
            and the weights/tokenizer files are read from it, plus
            ``label_map.json`` when that file exists.

    Returns:
        dict: Keys ``model``, ``tokenizer``, ``config``, ``device`` and
        ``label_map``. ``label_map`` is an empty dict when no
        ``label_map.json`` is present.
    """
    # Optional: turn off the HF tokenizers library's internal parallelism.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    config = AutoConfig.from_pretrained(os.path.join(model_dir, "config.json"))

    print(f"Loading model from {model_dir}")
    print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    # The label map is optional; without it predictions stay numeric indices.
    label_map_path = os.path.join(model_dir, "label_map.json")
    if not os.path.exists(label_map_path):
        label_map = {}
        print("No label map found. Using numeric indices as labels.")
    else:
        with open(label_map_path, "r", encoding="utf-8") as handle:
            label_map = json.load(handle)
        print(f"Loaded label map from {label_map_path}")

    # Weights are loaded as float16 when CUDA is available, float32 otherwise.
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=config,
        torch_dtype=weight_dtype,
    )

    # Move to the GPU when one is available and switch to inference mode.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return dict(
        model=model,
        tokenizer=tokenizer,
        config=config,
        device=device,
        label_map=label_map,
    )
# Request deserialization
def input_fn(request_body, request_content_type):
    """Deserialize a SageMaker request body into the payload for predict_fn.

    Args:
        request_body (bytes | str): Raw request payload.
        request_content_type (str): The request's Content-Type value. Any
            media-type parameters (e.g. ``; charset=utf-8``) are ignored and
            the media type is compared case-insensitively.

    Returns:
        For ``application/json``: the decoded JSON value, wrapped as
        ``{"text": value}`` when the document is a bare JSON string (other
        JSON values, including lists, are returned unchanged).
        For ``text/plain``: ``{"text": body}``, with a bytes body decoded as
        UTF-8 and a str body used as-is.

    Raises:
        ValueError: For any other content type (json.JSONDecodeError, a
            ValueError subclass, is raised for malformed JSON).
    """
    # Clients commonly send "application/json; charset=utf-8"; only the media
    # type itself decides how the body is parsed.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()

    if media_type == "application/json":
        input_data = json.loads(request_body)
        # A bare JSON string is treated as the text to classify.
        if isinstance(input_data, str):
            return {"text": input_data}
        return input_data

    if media_type == "text/plain":
        # The body may already arrive as str; only bytes need decoding.
        if isinstance(request_body, (bytes, bytearray)):
            return {"text": request_body.decode("utf-8")}
        return {"text": request_body}

    raise ValueError(f"지원되지 않는 콘텐츠 타입: {request_content_type}")
# Prediction
def predict_fn(input_data, model_dict):
    """Run the classifier on the request text and build the response dict.

    Args:
        input_data (dict): Parsed request. Must contain ``"text"``. It may
            also override the tokenizer options ``"max_length"`` (default
            512), ``"padding"`` (default ``"max_length"``) and
            ``"truncation"`` (default ``True``).
        model_dict (dict): Dictionary from ``model_fn``. Reads the
            ``"model"``, ``"tokenizer"``, ``"device"`` and ``"label_map"``
            keys.

    Returns:
        dict: For a two-class model, ``prediction`` (1 when the class-1
        probability exceeds 0.5, else 0), ``positive_probability`` and
        ``negative_probability``. Otherwise, ``prediction`` (argmax index)
        and ``probabilities`` (one float per class). ``label`` is added when
        the label map has an entry for the predicted index. Multi-class
        results also get ``label_probabilities`` whenever a label map is
        present. Only the first row of the tokenized batch is reported.

    Raises:
        ValueError: If ``input_data`` is not a dict or has no ``"text"`` key.
    """
    model = model_dict["model"]
    tokenizer = model_dict["tokenizer"]
    device = model_dict["device"]
    label_map = model_dict["label_map"]

    # input_fn passes arbitrary JSON through (lists, numbers, ...); reject
    # non-dicts here instead of failing later with a TypeError.
    if not isinstance(input_data, dict) or "text" not in input_data:
        raise ValueError("입력 데이터에 'text' 필드가 없습니다")
    text = input_data["text"]

    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=input_data.get("padding", "max_length"),
        truncation=input_data.get("truncation", True),
        max_length=input_data.get("max_length", 512),
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # The model may be loaded in float16; upcast before the softmax so the
    # reported probabilities are not quantized to half precision.
    probabilities = torch.softmax(logits.float(), dim=-1)[0]
    probs = probabilities.tolist()
    is_binary = len(probs) == 2

    if is_binary:
        negative_prob, positive_prob = probs
        prediction = 1 if positive_prob > 0.5 else 0
        result = {
            "prediction": prediction,
            "positive_probability": positive_prob,
            "negative_probability": negative_prob,
        }
    else:
        prediction = int(torch.argmax(probabilities).item())
        result = {"prediction": prediction, "probabilities": probs}

    if label_map:
        # Label-map keys are strings (JSON object keys).
        pred_key = str(prediction)
        if pred_key in label_map:
            result["label"] = label_map[pred_key]
        if not is_binary:
            # Map every class to its label, falling back to the index string.
            result["label_probabilities"] = {
                label_map.get(str(idx), str(idx)): prob
                for idx, prob in enumerate(probs)
            }

    return result
# Response serialization
def output_fn(prediction, response_content_type):
    """Serialize the predict_fn result into the response body.

    Args:
        prediction (dict): Result returned by ``predict_fn``.
        response_content_type (str): Requested response type (Accept value).
            Media-type parameters such as ``; charset=utf-8`` are ignored and
            the comparison is case-insensitive.

    Returns:
        str: JSON text, written with ``ensure_ascii=False`` so non-ASCII
        characters are kept as-is rather than escaped.

    Raises:
        ValueError: If anything other than ``application/json`` is requested.
    """
    # Strip parameters like "; charset=utf-8" before comparing the media type.
    media_type = (response_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"지원되지 않는 콘텐츠 타입: {response_content_type}")
    return json.dumps(prediction, ensure_ascii=False)