AR-VLA
This model was developed by INSAIT and KU Leuven.
Code and model weights for AR-VLA models are free to use under the Gemma license.
This repository provides model weights fine-tuned for a WidowX setup with a single external camera.
The weights work out of the box in SimplerEnv and on a real WidowX robot in a similar toy-kitchen scene.
We provide a fully AutoModel-compatible implementation of AR-VLA that can be used via Transformers.
The current implementation requires the following additional dependencies: roma, timm, flash-attn.
Here is a snippet to set up a working environment for inference via uv:
wget -qO- https://github.com/astral-sh/uv/releases/download/0.7.5/uv-installer.sh | sh
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install setuptools psutil
uv pip install --no-build-isolation flash-attn==2.7.3
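After running the commands above, a quick way to confirm that the environment matches the pinned versions is a small standard-library check. This is a sketch and not part of the AR-VLA repository: the `check_pins` helper is hypothetical, and stripping the `+cu126`-style local version suffix assumes the PyTorch wheels come from the CUDA index selected by `--torch-backend`.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the uv commands above. flash-attn is listed under its
# dist-info name (flash_attn), which importlib.metadata can always find.
PINNED = {
    "roma": "1.5.0",
    "numpy": "2.2.4",
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "transformers": "4.47.0",
    "timm": "1.0.15",
    "flash_attn": "2.7.3",
}


def check_pins(pins: dict[str, str]) -> dict[str, str]:
    """Return {package: problem} for every package that is missing or off-pin."""
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # CUDA wheels report e.g. "2.6.0+cu126"; compare only the public version.
        if installed.split("+")[0] != wanted:
            problems[name] = f"installed {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    for name, problem in issues.items():
        print(f"{name}: {problem}")
    print("environment OK" if not issues else f"{len(issues)} issue(s) found")
```

An empty result means every pinned package is present at the expected version.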
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
model_id = "you2who/ar-vla-bridge"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device="cuda").eval()
# Load a single RGB frame from the external ("main") camera.
image = Image.open("path/to/main_image.png").convert("RGB")
# Build model inputs from the instruction, the image and the current robot state
# (zeros here as placeholders).
batch = processor.preprocess_inputs(
    chat=["pick up the cup", ""],
    images={"main": [image]},
    ee_pose_translation=np.zeros((1, 1, 3), dtype=np.float32),
    ee_pose_rotation=np.array([[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32),
    gripper=np.zeros((1, 1), dtype=np.float32),
    joints=np.zeros((1, 1, 7), dtype=np.float32),
    dataset_name=np.array(["bridge"]),
    inference_mode=True,
)
with torch.inference_mode():
    # Reset the test-time cache, e.g. at the start of a new episode.
    model.reset_test_time_cache()
    # Refresh the VLM context with the current instruction, image and robot state.
    model.refresh_test_time_vlm(
        input_ids=batch["input_ids"].to("cuda"),
        attention_mask=torch.ones_like(batch["input_ids"], dtype=torch.bool).to("cuda"),
        images={k: v.to("cuda") for k, v in batch["images"].items()},
        ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
        ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
        gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
        joints=batch["joints"].to("cuda"),
        control_tokens_ids=batch["control_tokens_ids"],
    )
    # Predict the next action.
    action = model.next_test_time_action(
        input_ids=batch["input_ids"].to("cuda"),
        ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
        ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
        gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
        joints=batch["joints"].to("cuda"),
        control_tokens_ids=batch["control_tokens_ids"],
    )
print(action.translation.shape, action.rotation.shape, action.gripper.shape)
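The example above passes zero placeholders for the robot state. On a real robot these arrays would be filled from the current end-effector pose, gripper reading and joint angles. The sketch below shows one way to pack such a reading into arrays with the same shapes; it is not part of the AR-VLA code. The helper names are hypothetical, and it assumes that the rotation is an (x, y, z, w) quaternion (consistent with `[0, 0, 0, 1]` acting as the identity in the example, and with roma's convention), that the robot reports orientation as roll/pitch/yaw in the ZYX convention, and that 7 joint values are expected, as in the placeholder shape. The scale of the gripper value is also an assumption to check against the training data.

```python
import math

import numpy as np


def rpy_to_quat_xyzw(roll: float, pitch: float, yaw: float) -> np.ndarray:
    """Convert roll/pitch/yaw (radians, ZYX convention) to an (x, y, z, w) quaternion.

    Assumption: the model expects (x, y, z, w) order; verify against your setup.
    """
    cr, sr = math.cos(roll / 2), math.sin(roll / 2)
    cp, sp = math.cos(pitch / 2), math.sin(pitch / 2)
    cy, sy = math.cos(yaw / 2), math.sin(yaw / 2)
    return np.array(
        [
            sr * cp * cy - cr * sp * sy,  # x
            cr * sp * cy + sr * cp * sy,  # y
            cr * cp * sy - sr * sp * cy,  # z
            cr * cp * cy + sr * sp * sy,  # w
        ],
        dtype=np.float32,
    )


def make_state_inputs(xyz, rpy, gripper_value: float, joints) -> dict[str, np.ndarray]:
    """Hypothetical helper: pack one robot reading into the placeholder shapes."""
    return {
        "ee_pose_translation": np.asarray(xyz, dtype=np.float32).reshape(1, 1, 3),
        "ee_pose_rotation": rpy_to_quat_xyzw(*rpy).reshape(1, 1, 4),
        # Assumption: the gripper scale matches what the model saw in training.
        "gripper": np.array([[gripper_value]], dtype=np.float32),
        # Assumption: 7 joint values, matching the (1, 1, 7) placeholder.
        "joints": np.asarray(joints, dtype=np.float32).reshape(1, 1, 7),
    }
```

The returned dictionary could then replace the zero arrays, e.g. `processor.preprocess_inputs(chat=..., images=..., **state, dataset_name=np.array(["bridge"]), inference_mode=True)`.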