| from typing import Generator, Iterable, List, TypeVar |
|
|
| import numpy as np |
| import supervision as sv |
| import torch |
| import umap |
| from sklearn.cluster import KMeans |
| from tqdm import tqdm |
| from transformers import AutoProcessor, SiglipVisionModel |
|
|
# Generic element type: lets `create_batches` express that the batches it
# yields contain the same element type as the input sequence.
V = TypeVar("V")


# Model identifier passed to `from_pretrained` for both the SigLIP vision
# model and its matching processor in `TeamClassifier.__init__`.
SIGLIP_MODEL_PATH = 'google/siglip-base-patch16-224'
|
|
|
|
def create_batches(
    sequence: Iterable[V], batch_size: int
) -> Generator[List[V], None, None]:
    """
    Split an iterable into consecutive lists of at most `batch_size` items.

    A `batch_size` below 1 is treated as 1. Every batch except possibly the
    last holds exactly `batch_size` elements; the last one holds whatever
    remains. Nothing is yielded for an empty input.

    Args:
        sequence (Iterable[V]): The elements to group, consumed in order.
        batch_size (int): The maximum number of elements per batch.

    Yields:
        List[V]: The next batch of elements, in input order.
    """
    chunk_len = max(batch_size, 1)
    pending: List[V] = []
    for item in sequence:
        if len(pending) != chunk_len:
            pending.append(item)
            continue
        # The buffer is full: hand it out, then start the next batch with
        # the element that did not fit.
        yield pending
        pending = [item]
    if len(pending) > 0:
        yield pending
|
|
|
|
class TeamClassifier:
    """
    A classifier that uses a pre-trained SiglipVisionModel for feature extraction,
    UMAP for dimensionality reduction, and KMeans for clustering.

    Image crops are embedded with the vision model, projected to 3 dimensions
    with UMAP and grouped into 2 clusters with KMeans. Call `fit` before
    `predict`: `predict` reuses the reducer and cluster model fitted there.
    """
    def __init__(self, device: str = 'cpu', batch_size: int = 32):
        """
        Initialize the TeamClassifier with device and batch size.

        Loads the vision model and its processor from `SIGLIP_MODEL_PATH` and
        creates the (still unfitted) UMAP reducer and KMeans model.

        Args:
            device (str): The device to run the model on ('cpu' or 'cuda').
            batch_size (int): The batch size for processing images.
        """
        self.device = device
        self.batch_size = batch_size
        self.features_model = SiglipVisionModel.from_pretrained(
            SIGLIP_MODEL_PATH).to(device)
        self.processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
        self.reducer = umap.UMAP(n_components=3)
        self.cluster_model = KMeans(n_clusters=2)

    def extract_features(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Extract features from a list of image crops using the pre-trained
        SiglipVisionModel.

        Each crop's embedding is the mean of the model's `last_hidden_state`
        over dimension 1 (presumably the patch/token axis).

        Args:
            crops (List[np.ndarray]): List of image crops, in the format
                accepted by `sv.cv2_to_pillow`.

        Returns:
            np.ndarray: Extracted features, one row per crop. For empty input
                an array of shape `(0, hidden_size)` is returned.
        """
        # `len(...) == 0` rather than truthiness so that a numpy array of
        # crops is handled as well as a list.
        if len(crops) == 0:
            # np.concatenate raises on an empty list, so return an explicit
            # empty feature matrix. float32 assumes the model's default
            # weight dtype.
            hidden_size = self.features_model.config.hidden_size
            return np.empty((0, hidden_size), dtype=np.float32)

        # Same clamping as `create_batches`, so the progress-bar total
        # matches the number of batches actually produced.
        batch_size = max(self.batch_size, 1)
        num_batches = (len(crops) + batch_size - 1) // batch_size

        # Convert lazily so that only one batch of PIL images is alive at a
        # time instead of a full converted copy of `crops`.
        images = (sv.cv2_to_pillow(crop) for crop in crops)
        batches = create_batches(images, batch_size)

        data = []
        with torch.no_grad():
            for batch in tqdm(
                batches, total=num_batches, desc='Embedding extraction'
            ):
                inputs = self.processor(
                    images=batch, return_tensors="pt").to(self.device)
                outputs = self.features_model(**inputs)
                embeddings = torch.mean(outputs.last_hidden_state, dim=1)
                data.append(embeddings.cpu().numpy())

        return np.concatenate(data)

    def fit(self, crops: List[np.ndarray]) -> None:
        """
        Fit the classifier model on a list of image crops.

        Fits the UMAP reducer on the extracted features and then fits KMeans
        on the resulting projections.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Raises:
            ValueError: If `crops` is empty.
        """
        if len(crops) == 0:
            # Fail early with a clear message instead of an opaque error from
            # deeper in the pipeline.
            raise ValueError(
                "Cannot fit TeamClassifier on an empty list of crops.")

        data = self.extract_features(crops)
        projections = self.reducer.fit_transform(data)
        self.cluster_model.fit(projections)

    def predict(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Predict the cluster labels for a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Returns:
            np.ndarray: Predicted cluster labels, one per crop. An empty
                integer array is returned for empty input.
        """
        if len(crops) == 0:
            # Integer dtype so the empty result is usable wherever the
            # non-empty label array is (e.g. as indices or class ids).
            return np.array([], dtype=int)

        data = self.extract_features(crops)
        projections = self.reducer.transform(data)
        return self.cluster_model.predict(projections)
|
|