AR-VLA model card

This model was developed by INSAIT and KU Leuven.

Code and model weights for AR-VLA models are free to use under the Gemma license.

This repo provides model weights fine-tuned for a WidowX setup with one external camera.

The weights work out of the box in SimplerEnv and on a real WidowX robot in a similar toy-kitchen scene.

Use with 🤗 Transformers

We provide a fully AutoModel-compatible implementation of AR-VLA that can be used via Transformers.

Environment setup

The current implementation requires the following additional dependencies: roma, timm, flash-attn.

Here is a snippet to set up a working environment for inference via uv:

  1. Install uv:
wget -qO- https://github.com/astral-sh/uv/releases/download/0.7.5/uv-installer.sh | sh
  2. Create a virtualenv and resolve the dependencies:
uv venv --python 3.10.12
source .venv/bin/activate
uv pip install --torch-backend=cu126 roma==1.5.0 numpy==2.2.4 torch==2.6.0 torchvision==0.21.0 transformers==4.47.0 timm==1.0.15
uv pip install --no-build-isolation setuptools psutil flash-attn==2.7.3
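
To confirm the environment resolved correctly, you can run a quick sanity check such as the following (a minimal sketch; it only verifies that the pinned packages import and that a CUDA device is visible, which flash-attn and the GPU inference below require):

import torch
import roma
import timm
import transformers
import flash_attn

# Verify the pinned versions resolved and a GPU is visible before loading the model
print("torch:", torch.__version__, "| transformers:", transformers.__version__)
print("flash-attn:", flash_attn.__version__, "| timm:", timm.__version__)
assert torch.cuda.is_available(), "AR-VLA inference expects a CUDA device"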

Async Inference

We provide an interface to update the Vision-Language (VL) KV cache independently of action prediction. While the VL context remains cached, the action-specific KV cache and RoPE (Rotary Positional Embedding) steps are managed internally. This enables efficient sequential calls and true streaming behavior.

Workflow Example:

# Initialize/Update the persistent VLM context
model.refresh_test_time_vlm()

# Predict actions based on incoming states without recomputing the VLM backbone
action_1 = model.next_test_time_action(state_1)
action_2 = model.next_test_time_action(state_2)
action_3 = model.next_test_time_action(state_3)

# Refresh the VLM context when a new observation or instruction is received
model.refresh_test_time_vlm()
action_4 = model.next_test_time_action(state_4)
action_5 = model.next_test_time_action(state_5)
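
In a robot control loop this maps naturally onto a camera callback: refresh the VL cache whenever a new frame or instruction arrives, and keep predicting actions from proprioceptive state in between. The sketch below is illustrative only; new_frame_available(), read_robot_state(), and apply_action() are hypothetical placeholders for your own robot interface, and the simplified call signatures follow the workflow above (the full keyword arguments are shown under Example usage).

# Illustrative streaming loop; new_frame_available(), read_robot_state(), and
# apply_action() are hypothetical stand-ins for your camera/robot interface.
while True:
    if new_frame_available():
        # New observation: rebuild the VL context once, then reuse it
        model.refresh_test_time_vlm()
    # Cheap per-step call: only the action-specific KV cache advances
    state = read_robot_state()
    action = model.next_test_time_action(state)
    apply_action(action)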

Example usage

import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_id = "INSAIT-Institute/arvla-bridge"

model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device="cuda").eval()

image = Image.open("path/to/main_image.png").convert("RGB")

batch = processor.preprocess_inputs(
  chat=["pick up the cup", ""],
  images={"main": [image]},
  ee_pose_translation=np.zeros((1, 1, 3), dtype=np.float32),
  ee_pose_rotation=np.array([[[0.0, 0.0, 0.0, 1.0]]], dtype=np.float32),
  gripper=np.zeros((1, 1), dtype=np.float32),
  joints=np.zeros((1, 1, 7), dtype=np.float32),
  dataset_name=np.array(["bridge"]),
  inference_mode=True,
)

with torch.inference_mode():
  model.reset_test_time_cache()
  model.refresh_test_time_vlm(
    input_ids=batch["input_ids"].to("cuda"),
    attention_mask=torch.ones_like(batch["input_ids"], dtype=torch.bool).to("cuda"),
    images={k: v.to("cuda") for k, v in batch["images"].items()},
    ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
    ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
    gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
    joints=batch["joints"].to("cuda"),
    control_tokens_ids=batch["control_tokens_ids"],
  )
  action = model.next_test_time_action(
    input_ids=batch["input_ids"].to("cuda"),
    ee_pose_translation=batch["ee_pose_translation"].to("cuda"),
    ee_pose_rotation=batch["ee_pose_rotation"].to("cuda"),
    gripper=batch["gripper"].unsqueeze(-1).to("cuda"),
    joints=batch["joints"].to("cuda"),
    control_tokens_ids=batch["control_tokens_ids"],
  )

print(action.translation.shape, action.rotation.shape, action.gripper.shape)
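
The returned action exposes translation, rotation, and gripper tensors, as the print above shows. A minimal sketch for handing them to a downstream controller, assuming the leading dimension is a batch of size one, might look like:

# Move the predicted chunk to numpy for a downstream controller.
# The squeeze(0) assumes a batch size of one; check the printed shapes first.
translation = action.translation.squeeze(0).float().cpu().numpy()
rotation = action.rotation.squeeze(0).float().cpu().numpy()
gripper = action.gripper.squeeze(0).float().cpu().numpy()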

Summary

  • Model type: Vision-Language-Action with autoregressive action generation
  • Model id: INSAIT-Institute/arvla-bridge
  • License: Gemma Terms of Use
  • Model size: 3B parameters (safetensors, F32)