| import gradio as gr |
| import os |
| import random |
| import uuid |
| import csv |
| from datetime import datetime |
| from pathlib import Path |
| from PIL import Image |
| from huggingface_hub import CommitScheduler, snapshot_download |
|
|
| |
# Hugging Face dataset repository used both as the image source and as the
# destination for uploaded logs.
DATASET_REPO_ID = "Emilyxml/moveit"

# Local directory the study images and instruction files are read from.
DATA_FOLDER = "data"

# Local directory for per-user CSV logs; created up front so later writes can
# assume it exists.
LOG_FOLDER = Path("logs")
LOG_FOLDER.mkdir(exist_ok=True, parents=True)

# Access token taken from the environment; None when HF_TOKEN is not set.
TOKEN = os.getenv("HF_TOKEN")
|
|
| |
# Download the dataset only when no usable local copy exists, i.e. the data
# folder is missing or empty. A failed download is reported but does not stop
# the app from starting.
if not (os.path.exists(DATA_FOLDER) and os.listdir(DATA_FOLDER)):
    try:
        print("🚀 正在从 Dataset 下载数据...")
        snapshot_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            local_dir=DATA_FOLDER,
            token=TOKEN,
            # Only images and instruction text files are needed locally.
            allow_patterns=["*.jpg", "*.png", "*.jpeg", "*.webp", "*.txt"],
        )
    except Exception as err:
        print(f"⚠️ 下载失败: {err}")
    else:
        print("✅ 数据下载完成!")
|
|
| |
# Scheduler that uploads the contents of LOG_FOLDER to the "logs" directory of
# the dataset repository. Its `lock` is held by save_and_next while a CSV row
# is being written.
# NOTE(review): huggingface_hub documents `every` in minutes, so every=1 would
# mean one push per minute — confirm against the installed version.
scheduler = CommitScheduler(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    folder_path=LOG_FOLDER,
    path_in_repo="logs",
    every=1,
    token=TOKEN
)
|
|
| |
def _read_instruction(file_path):
    """Return the text of an instruction file, trying UTF-8 first, then GBK."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except UnicodeDecodeError:
        # Not valid UTF-8: fall back to GBK. Undecodable bytes are replaced
        # rather than raised so one badly encoded file cannot abort loading.
        with open(file_path, "r", encoding="gbk", errors="replace") as f:
            return f.read()


def load_data(data_folder=None):
    """Scan the data folder and group its files into study questions.

    Files are grouped by the first five characters of their name. Within a
    group, an image whose name contains ``_origin`` (case-insensitive) is the
    reference image, every other image is a candidate, and a ``.txt`` file
    supplies the instruction text. Hidden entries (leading ``.``) and
    directories are ignored.

    Args:
        data_folder: Directory to scan. Defaults to the module-level
            ``DATA_FOLDER`` when omitted.

    Returns:
        A ``(groups, group_ids)`` tuple. ``groups`` maps a group id to a dict
        with the keys ``"origin"``, ``"candidates"`` and ``"instruction"``;
        ``group_ids`` lists the same ids in random order. Groups with neither
        a reference image nor any candidate are dropped. Returns ``({}, [])``
        when the folder does not exist.
    """
    if data_folder is None:
        data_folder = DATA_FOLDER

    if not os.path.exists(data_folder):
        return {}, []

    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    groups = {}

    for filename in os.listdir(data_folder):
        if filename.startswith('.'):
            continue
        file_path = os.path.join(data_folder, filename)
        if not os.path.isfile(file_path):
            # Sub-directories cannot be shown as images or read as text.
            continue

        group = groups.setdefault(
            filename[:5],
            {"origin": None, "candidates": [], "instruction": "暂无说明"},
        )

        lowered = filename.lower()
        if lowered.endswith(image_suffixes):
            if "_origin" in lowered:
                group["origin"] = file_path
            else:
                group["candidates"].append(file_path)
        elif lowered.endswith('.txt'):
            group["instruction"] = _read_instruction(file_path)

    # Keep only groups that have at least one image to show.
    valid_groups = {
        group_id: group
        for group_id, group in groups.items()
        if group["origin"] is not None or group["candidates"]
    }

    group_ids = list(valid_groups)
    random.shuffle(group_ids)
    print(f"Loaded {len(group_ids)} groups.")
    return valid_groups, group_ids
|
|
# Build the question set once at import time. ALL_GROUP_IDS holds the group ids
# in the shuffled order produced by load_data; questions are looked up by
# indexing into it.
ALL_GROUPS, ALL_GROUP_IDS = load_data()
|
|
| |
def optimize_image(image_path, max_width=800):
    """Load an image and shrink it to at most ``max_width`` pixels wide.

    Images wider than ``max_width`` are resized with the aspect ratio kept;
    narrower images are returned at their original size.

    Args:
        image_path: Path of the image file. A falsy value (``None``, ``""``)
            yields ``None``.
        max_width: Maximum width in pixels of the returned image.

    Returns:
        A PIL image, or ``None`` when no path was given or the file could not
        be read or resized (the error is printed, not raised).
    """
    if not image_path:
        return None
    try:
        # Image.open reads lazily and keeps the file open; the context manager
        # releases the handle on every path, including a failed resize.
        with Image.open(image_path) as img:
            if img.width > max_width:
                ratio = max_width / img.width
                # Never let rounding produce a zero height for very wide,
                # very short images.
                new_height = max(1, int(img.height * ratio))
                return img.resize((max_width, new_height), Image.LANCZOS)
            # Read the pixel data now so the image is still usable once the
            # file handle has been closed.
            img.load()
            return img
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None
|
|
| |
|
|
def get_next_question(user_state):
    """Build the component updates for the question at the user's current index.

    Args:
        user_state: Dict holding at least ``"index"``, the position of the
            current question in ``ALL_GROUP_IDS``.

    Returns:
        A 9-tuple matching the handlers' ``outputs`` list: updates for the
        reference image, the candidate gallery, the checkbox group, the
        instruction markdown, the submit button, the "none" button and the
        end-of-study markdown, followed by ``user_state`` and the list of
        ``{"label", "path"}`` dicts describing the candidates shown. When the
        index is past the last question, everything except the end message is
        hidden and the candidate list is empty.
    """
    idx = user_state["index"]
    total = len(ALL_GROUP_IDS)

    if idx >= total:
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="## 🎉 测试结束!感谢您的参与。", visible=True),
            user_state,
            []
        )

    group_data = ALL_GROUPS[ALL_GROUP_IDS[idx]]

    origin_img = optimize_image(group_data["origin"], max_width=600)

    # Shuffle a copy so the stored candidate order is untouched and the
    # position of each candidate varies between displays.
    candidates = group_data["candidates"].copy()
    random.shuffle(candidates)

    gallery_items = []
    choices = []
    candidates_info = []

    for path in candidates:
        optimized_img = optimize_image(path, max_width=600)
        if optimized_img is None:
            # optimize_image returns None when a file cannot be read. Skip it
            # so the gallery only receives real images and the option labels
            # stay contiguous with what is actually displayed.
            continue

        label = f"Option {chr(65 + len(choices))}"
        gallery_items.append((optimized_img, label))
        choices.append(label)
        candidates_info.append({"label": label, "path": path})

    instruction = f"### 任务 ({idx + 1} / {total})\n\n{group_data['instruction']}"

    return (
        # Hide the reference panel when the group has no loadable origin image.
        gr.update(value=origin_img, visible=origin_img is not None),
        gr.update(value=gallery_items, visible=True),
        gr.update(choices=choices, value=[], visible=True),
        gr.update(value=instruction, visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
        user_state,
        candidates_info
    )
|
|
def _method_from_path(path):
    """Return the method name in a candidate file name: the stem after its first underscore."""
    name = os.path.splitext(os.path.basename(path))[0]
    _, separator, method = name.partition('_')
    return method if separator else name


def save_and_next(user_state, candidates_info, selected_options, is_none=False):
    """Record the answer to the current question and advance to the next one.

    One row is appended to ``LOG_FOLDER / "user_<user_id>.csv"`` (a header row
    is written first when the file is new), then ``user_state["index"]`` is
    incremented in place.

    Args:
        user_state: Dict with ``"user_id"`` and ``"index"``.
        candidates_info: List of ``{"label", "path"}`` dicts for the
            candidates currently displayed.
        selected_options: Labels ticked by the user.
        is_none: When true, the answer is recorded as a rejection of all
            candidates and ``selected_options`` is ignored.

    Returns:
        The tuple produced by ``get_next_question`` for the following question.

    Raises:
        gr.Error: If ``is_none`` is false and no option was selected.
    """
    current_idx = user_state["index"]
    if current_idx >= len(ALL_GROUP_IDS):
        # The study is already finished (for example a duplicate click that
        # arrived after the last answer): there is no question to record, so
        # just show the end screen instead of failing on the index lookup.
        return get_next_question(user_state)
    group_id = ALL_GROUP_IDS[current_idx]

    if is_none:
        choice_str = "Rejected All"
        method_str = "None_Satisfied"
    else:
        if not selected_options:
            raise gr.Error("请至少勾选一个选项,或点击“都不满意”")

        choice_str = "; ".join(selected_options)

        # Map each label to its file once; setdefault keeps the first entry
        # for a label, matching a first-match scan of candidates_info.
        path_by_label = {}
        for info in candidates_info:
            path_by_label.setdefault(info["label"], info["path"])

        # Labels without a matching candidate contribute no method name.
        method_str = "; ".join(
            _method_from_path(path_by_label[option])
            for option in selected_options
            if option in path_by_label
        )

    user_file = LOG_FOLDER / f"user_{user_state['user_id']}.csv"
    # Hold the scheduler's lock while the file is being modified.
    with scheduler.lock:
        exists = user_file.exists()
        with open(user_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if not exists:
                writer.writerow(["user_id", "timestamp", "group_id", "choices", "methods"])
            writer.writerow([
                user_state["user_id"],
                datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                group_id,
                choice_str,
                method_str
            ])

    user_state["index"] += 1
    return get_next_question(user_state)
|
|
| |
with gr.Blocks(title="User Study") as demo:

    # State initialised from a callable: a dict with a short random user id
    # and a starting question index of 0.
    state_user = gr.State(lambda: {"user_id": str(uuid.uuid4())[:8], "index": 0})
    # Label/path records of the candidates currently on screen.
    state_candidates_info = gr.State([])

    with gr.Row():
        md_instruction = gr.Markdown("Loading...")

    with gr.Row():
        with gr.Column(scale=1):
            img_origin = gr.Image(
                label="Reference (参考原图)",
                interactive=False,
                height=400,
                format="jpeg",
            )

        with gr.Column(scale=2):
            gallery_candidates = gr.Gallery(
                label="Candidates (候选结果)",
                columns=[2],
                height="auto",
                object_fit="contain",
                interactive=False,
                format="jpeg",
            )

            gr.Markdown("👇 **请在下方勾选您认为最好的结果(可多选):**")

            checkbox_options = gr.CheckboxGroup(
                choices=[],
                label="您的选择",
                info="对应上方图片的标签 (Option A, B...)",
            )

    with gr.Row():
        btn_submit = gr.Button("🚀 提交 (Submit)", variant="primary")
        btn_none = gr.Button("🚫 都不满意 (None)", variant="stop")

    md_end = gr.Markdown(visible=False)

    # Every handler refreshes the same components, in the order of the tuple
    # returned by get_next_question.
    question_outputs = [
        img_origin,
        gallery_candidates,
        checkbox_options,
        md_instruction,
        btn_submit,
        btn_none,
        md_end,
        state_user,
        state_candidates_info,
    ]
    # Both buttons pass the same inputs to save_and_next.
    answer_inputs = [state_user, state_candidates_info, checkbox_options]

    # Show the first question when the page loads.
    demo.load(
        fn=get_next_question,
        inputs=[state_user],
        outputs=question_outputs,
    )

    # Submit records the ticked options.
    btn_submit.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=False),
        inputs=answer_inputs,
        outputs=question_outputs,
    )

    # "None" records a rejection of every candidate.
    btn_none.click(
        fn=lambda state, info, picked: save_and_next(state, info, picked, is_none=True),
        inputs=answer_inputs,
        outputs=question_outputs,
    )


if __name__ == "__main__":
    demo.launch()