import imghdr
import os
import shutil
import threading
import warnings
from dataclasses import dataclass

import cv2
import numpy as np
import gradio as gr

from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, _L, MODEL_DIR, TMP_DIR
|
|
|
|
@dataclass()
class Cfg:
    """Settings bundle handed to ``mivolo.predictor.Predictor``.

    Only the two weight paths are required; this module constructs it with
    those two positional arguments and leaves the remaining flags at their
    defaults.
    """

    # Path to the object-detector weights (this module passes a
    # ``yolov8x_person_face.pt`` file).
    detector_weights: str
    # Path to the age/gender model checkpoint.
    checkpoint: str
    # Device identifier forwarded to the predictor.
    device: str = "cpu"
    # The flags below are consumed by mivolo's Predictor; their exact
    # semantics are defined in that library — NOTE(review): confirm there.
    with_persons: bool = True
    disable_faces: bool = False
    # Presumably asks the predictor to return an annotated image
    # (ValidImgDetector uses the second value of recognize() as an image).
    draw: bool = True
|
|
|
|
class ValidImgDetector:
    """Report whether a photo contains a child, a woman and/or a man.

    Wraps a MiVOLO ``Predictor`` built from the detector and age/gender
    checkpoints stored under ``MODEL_DIR``.
    """

    # Replaced by a real Predictor in __init__.
    predictor = None

    # Maps each UI detection mode to (use_persons, disable_faces).
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        """Load the detector and age/gender checkpoints from ``MODEL_DIR``."""
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple[np.ndarray, str, str, str]:
        """Run ``predictor`` on ``image`` and summarise the detections.

        Args:
            image: Image as loaded by ``cv2.imread`` (OpenCV BGR order).
            score_threshold: Detector confidence threshold (``conf``).
            iou_threshold: Detector NMS IoU threshold (``iou``).
            mode: One of the keys of ``_MODES``.
            predictor: The MiVOLO predictor to configure and run.

        Returns:
            ``(annotated_rgb_image, has_child, has_female, has_male)`` where
            each flag is ``_L("是")`` or ``_L("否")``. "Child" means the
            youngest estimated age is below 18.

        Raises:
            ValueError: if ``mode`` is not a known detection mode.
        """
        # Validate the mode before mutating any shared predictor state.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        yes, no = _L("是"), _L("否")
        # Defensive: skip objects that did not receive an age estimate
        # (NOTE(review): confirm whether mivolo can leave entries as None).
        ages = [age for age in detected_objects.ages if age is not None]
        genders = detected_objects.genders
        # Always return the localized strings (never bare booleans) so the
        # UI textboxes show a consistent answer even with no detections.
        has_child = yes if ages and min(ages) < 18 else no
        has_female = yes if "female" in genders else no
        has_male = yes if "male" in genders else no

        # Reverse the channel axis: OpenCV BGR -> RGB for display.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load the image at ``img_path`` and run detection on it.

        The keyword parameters default to the values this app has always
        used; see ``_detect`` for the return value.

        Raises:
            ValueError: if OpenCV cannot read the file, or ``mode`` is unknown.
        """
        image = cv2.imread(img_path)
        # cv2.imread signals an unreadable/missing file by returning None.
        if image is None:
            raise ValueError("请正确输入图片")
        return self._detect(image, score_threshold, iou_threshold, mode, self.predictor)
|
|
|
|
# Lazily-built detector shared across requests; loading the models is slow.
_detector = None
# Serialises model construction and inference: both Gradio tabs may call
# `infer` concurrently, and the underlying models are not documented as
# thread-safe.
_detector_lock = threading.Lock()


def _get_detector() -> "ValidImgDetector":
    """Return the shared ValidImgDetector, building it on first use.

    Must be called with ``_detector_lock`` held.
    """
    global _detector
    if _detector is None:
        _detector = ValidImgDetector()
    return _detector


def infer(photo: str):
    """Gradio callback: detect child/female/male presence in a photo.

    Args:
        photo: A local file path (upload tab) or an image URL (online tab);
            may be empty/None when the user submitted nothing.

    Returns:
        ``(status, result_image, child, female, male)``. ``status`` is
        ``"Success"`` or the error text; the other values are ``None`` on
        failure.
    """
    status = "Success"
    result = child = female = male = None
    try:
        if photo and is_url(photo):
            # Start from a clean temp directory for each download.
            if os.path.exists(TMP_DIR):
                shutil.rmtree(TMP_DIR)
            photo = download_file(photo, f"{TMP_DIR}/download.jpg")

        # Validate the input before touching the (expensive) models.
        # NOTE(review): `imghdr` is removed in Python 3.13 — replace it.
        if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
            raise ValueError("请正确输入图片")

        with _detector_lock:
            result, child, female, male = _get_detector().valid_img(photo)

    except Exception as e:
        # UI boundary: surface any failure as the status text.
        status = f"{e}"

    return status, result, child, female, male
|
|
|
|
if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    def _result_outputs():
        """Build the status/result/flag output components used by both tabs."""
        return [
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Image(label=_L("检测结果"), type="numpy", show_share_button=False),
            gr.Textbox(label=_L("存在儿童")),
            gr.Textbox(label=_L("存在女性")),
            gr.Textbox(label=_L("存在男性")),
        ]

    with gr.Blocks() as iface:
        gr.Markdown(_L("# 性别年龄检测器"))

        # Upload tab: gradio passes `infer` the uploaded file's path.
        with gr.Tab(_L("上传模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label=_L("上传照片"), type="filepath"),
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                flagging_mode="never",
                cache_examples=False,
            )

        # Online tab: `infer` receives the text typed into the URL box.
        with gr.Tab(_L("在线模式")):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(label=_L("网络图片链接"), show_copy_button=True),
                outputs=_result_outputs(),
                flagging_mode="never",
            )

    iface.launch()
|
|