from typing import Optional, Sequence

import numpy as np
import torch
from sports.common.team import TeamClassifier
|
|
# Location of the serialized team classifier checkpoint.
MODEL_PATH = "team_classifier.pth"

# The checkpoint is a whole pickled Python object (it is used via
# ``model.predict`` below), not a plain tensor state_dict -- presumably a
# TeamClassifier instance, whose class is imported above so unpickling can
# resolve it (TODO confirm). PyTorch >= 2.6 defaults ``weights_only=True``,
# which refuses to unpickle such objects, so opt out explicitly.
# NOTE(security): weights_only=False runs arbitrary pickle code; only load
# checkpoints from a trusted source. (``weights_only`` requires torch >= 1.13.)
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
def predict_teams(
    crops: Sequence[np.ndarray],
    *,
    classifier: Optional["TeamClassifier"] = None,
) -> np.ndarray:
    """Predict team assignments for a list of player crops.

    Args:
        crops: Player crops as numpy arrays; passed unchanged to the
            classifier's ``predict`` method.
        classifier: Any object exposing ``predict(crops)``. When omitted,
            the module-level ``model`` is used, so existing callers behave
            exactly as before. Pass another instance to use a different
            checkpoint or a test double.

    Returns:
        np.ndarray: Predicted team labels (0 or 1), as returned by the
        classifier's ``predict``.
    """
    # Fall back to the module-level model only when nothing was injected.
    active = model if classifier is None else classifier
    return active.predict(crops)
|
|
if __name__ == "__main__":
    # Smoke test: classify a single all-zero three-channel 224x224 uint8
    # image and print whatever labels come back.
    blank_frame = np.zeros(shape=(224, 224, 3), dtype=np.uint8)
    labels = predict_teams([blank_frame])
    print("Predicted team labels:", labels)
|
|