# AR-VLA model card
This model was developed by INSAIT and KU Leuven.
Code and model weights for AR-VLA models are free to use under the Gemma license.
This repo provides model weights fine-tuned for a WidowX setup with one external camera. The weights work out of the box in SimplerEnv and on a real WidowX robot in a similar toy-kitchen scene.
## Use with 🤗 Transformers
We provide a fully AutoModel-compatible implementation of AR-VLA that can be used directly via 🤗 Transformers.
### Environment setup
The current implementation requires the following additional dependencies: `roma`, `timm`, and `flash-attn`.

Here is a snippet to set up a working environment for inference via `uv`:
- Install `uv`:

```shell
wget -qO- https://github.com/astral-sh/uv/releases/download/0.7.5/uv-installer.sh | sh
```

- Create a virtualenv and resolve the dependencies:

```shell
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
```
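After installation, a quick sanity check (not part of the repo, just a suggestion) can confirm that the extra dependencies resolved in the active environment:

```python
# Verify that the optional dependencies are importable in the current venv.
import importlib.util

for pkg in ("roma", "timm", "flash_attn"):
    found = importlib.util.find_spec(pkg) is not None
    print(pkg, "ok" if found else "missing")
```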
## Async Inference
We provide an interface to update the Vision-Language (VL) KV cache independently from action prediction. While the VL context remains cached, the action-specific KV cache and RoPE (Rotary Positional Embedding) steps are managed internally. This architecture enables efficient sequential calls and facilitates true streaming behavior.
Workflow example:

```python
# Initialize/update the persistent VLM context
model.refresh_test_time_vlm()

# Predict actions based on incoming states without recomputing the VLM backbone
action_1 = model.next_test_time_action(state_1)
action_2 = model.next_test_time_action(state_2)
action_3 = model.next_test_time_action(state_3)

# Refresh the VLM context when a new observation or instruction is received
model.refresh_test_time_vlm()
action_4 = model.next_test_time_action(state_4)
action_5 = model.next_test_time_action(state_5)
```
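The refresh/step contract above can be illustrated with a toy state machine (all names hypothetical; this is a sketch of the calling pattern, not the real AR-VLA implementation): the expensive Vision-Language pass runs only on refresh, while per-step action calls reuse the cached context.

```python
# Toy illustration of the async-inference pattern: one expensive VL pass
# per refresh, many cheap action-prediction calls against the cached context.
class ToyAsyncVLA:
    def __init__(self):
        self.vl_cache = None   # stands in for the persistent VL KV cache
        self.vl_refreshes = 0  # counts "backbone" passes

    def refresh_vlm(self, observation):
        """Recompute the VL context; in the real model this runs the VLM backbone."""
        self.vl_refreshes += 1
        self.vl_cache = ("ctx", observation)
        self.step = 0          # action cache / RoPE offset resets on refresh

    def next_action(self, state):
        """Predict one action against the cached VL context (cheap call)."""
        assert self.vl_cache is not None, "call refresh_vlm first"
        self.step += 1
        return (self.vl_cache[1], state, self.step)

model = ToyAsyncVLA()
model.refresh_vlm("obs_0")
a1 = model.next_action("s1")
a2 = model.next_action("s2")
model.refresh_vlm("obs_1")  # new observation -> one more backbone pass
a3 = model.next_action("s3")
print(model.vl_refreshes)   # 2 backbone passes for 3 action predictions
```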
## Example usage
```python
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "INSAIT-Institute/arvla-bridge"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device="cuda").eval()

image = Image.open("path/to/main_image.png").convert("RGB")
batch = processor.preprocess_inputs(
    chat=["pick up the cup", ""],
    images={"main": [image]},
    ee_pose_translation=np.zeros((1, 1, 3), dtype=np.float32),
    ee_pose_rotation=np.array([[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32),
    gripper=np.zeros((1, 1), dtype=np.float32),
    joints=np.zeros((1, 1, 7), dtype=np.float32),
    dataset_name=np.array(["bridge"]),
    inference_mode=True,
)

with torch.inference_mode():
    model.reset_test_time_cache()
    model.refresh_test_time_vlm(
        input_ids=batch["input_ids"].to("cuda"),
        attention_mask=torch.ones_like(batch["input_ids"], dtype=torch.bool).to("cuda"),
        images={k: v.to("cuda") for k, v in batch["images"].items()},
        ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
        ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
        gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
        joints=batch["joints"].to("cuda"),
        control_tokens_ids=batch["control_tokens_ids"],
    )
    action = model.next_test_time_action(
        input_ids=batch["input_ids"].to("cuda"),
        ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
        ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
        gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
        joints=batch["joints"].to("cuda"),
        control_tokens_ids=batch["control_tokens_ids"],
    )

print(action.translation.shape, action.rotation.shape, action.gripper.shape)
```
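The preprocessing call above takes state arrays with leading `(batch, time)` axes. A minimal helper for packing a single raw robot state into that layout (the helper name and argument names are hypothetical, not part of the processor API):

```python
import numpy as np

def pack_state(translation, quaternion_xyzw, gripper_open, joint_angles):
    """Pack one robot state into the (batch=1, time=1, ...) layout used by
    processor.preprocess_inputs in the example above."""
    return {
        "ee_pose_translation": np.asarray(translation, dtype=np.float32).reshape(1, 1, 3),
        "ee_pose_rotation": np.asarray(quaternion_xyzw, dtype=np.float32).reshape(1, 1, 4),
        "gripper": np.asarray([[gripper_open]], dtype=np.float32),  # shape (1, 1)
        "joints": np.asarray(joint_angles, dtype=np.float32).reshape(1, 1, 7),
    }

# Identity pose, closed gripper, zero joint angles (as in the example above):
state = pack_state([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], 0.0, [0.0] * 7)
print(state["ee_pose_translation"].shape)  # (1, 1, 3)
```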
## Summary
- Model type: Vision-Language-Action with autoregressive action generation
- Model id: `you2who/ar-vla-bridge`
- License: Gemma Terms of Use