| | import os |
| | import sys |
| | import cv2 |
| | import math |
| | import json |
| | import torch |
| | import argparse |
| | import numpy as np |
| | from PIL import Image |
| | from PIL import ImageOps |
| | from pathlib import Path |
| | import multiprocessing as mp |
| | from vitra.models import VITRA_Paligemma, load_model |
| | from vitra.utils.data_utils import resize_short_side_to_target, load_normalizer, recon_traj |
| | from vitra.utils.config_utils import load_config |
| | from vitra.datasets.human_dataset import pad_state_human, pad_action |
| | from scipy.spatial.transform import Rotation as R |
| | from vitra.datasets.dataset_utils import ( |
| | compute_new_intrinsics_resize, |
| | calculate_fov, |
| | ActionFeature, |
| | StateFeature, |
| | ) |
| |
|
| | repo_root = Path(__file__).parent.parent |
| | sys.path.insert(0, str(repo_root)) |
| |
|
| | from visualization.visualize_core import HandVisualizer, normalize_camera_intrinsics, save_to_video, Renderer, process_single_hand_labels |
| | from visualization.visualize_core import Config as HandConfig |
| |
|
def main():
    """
    Main execution function for hand action prediction and visualization.

    This function uses a multi-process architecture to separate hand reconstruction
    and VLA inference into independent processes, preventing CUDA conflicts.

    Workflow:
        1. Parse command-line arguments and load model configurations
        2. Initialize persistent services:
           - HandReconstructionService: Runs HAWOR + MOGE models in separate process
           - VLAInferenceService: Runs VLA model in separate process
        3. Load or reconstruct hand state:
           - Uses precomputed .npy file if available (same stem as image)
           - Otherwise runs hand reconstruction service
        4. Prepare input data:
           - Load and resize image
           - Extract hand state (translation, rotation, pose) for left/right hands
           - Create state and action masks based on which hands to predict
        5. Run VLA inference to predict future hand actions (multiple samples for diversity)
        6. Reconstruct absolute hand trajectories from relative actions
        7. Visualize predicted hand motions using MANO hand model
        8. Generate grid layout video showing all samples and save to file
        9. Cleanup: Shutdown persistent services and free GPU memory
    """
    parser = argparse.ArgumentParser(description="Hand VLA inference and visualization.")

    # --- Model configuration ---
    parser.add_argument('--config_path', type=str, required=True, help='Path to model configuration JSON file')
    parser.add_argument('--model_path', type=str, default=None, help='Path to model checkpoint (overrides config)')
    parser.add_argument('--statistics_path', type=str, default=None, help='Path to normalization statistics JSON (overrides config)')

    # --- Input / output paths ---
    parser.add_argument('--image_path', type=str, required=True, help='Path to input image file')
    parser.add_argument('--hand_path', type=str, default=None, help='Path to hand state .npy file (optional, will run reconstruction if not provided)')
    parser.add_argument('--video_path', type=str, default='./example_human_inf.mp4', help='Path to save output visualization video')

    # --- Hand reconstruction model weights ---
    parser.add_argument('--hawor_model_path', type=str, default='./weights/hawor/checkpoints/hawor.ckpt', help='Path to HAWOR model weights')
    parser.add_argument('--detector_path', type=str, default='./weights/hawor/external/detector.pt', help='Path to hand detector model')
    parser.add_argument('--moge_model_name', type=str, default='Ruicheng/moge-2-vitl', help='MOGE model name from Hugging Face')
    parser.add_argument('--mano_path', type=str, default='/home/t-qixiuli/repo/VITRA/weights/mano', help='Path to MANO model files')

    # --- Inference options ---
    parser.add_argument('--use_left', action='store_true', help='Enable left hand prediction')
    parser.add_argument('--use_right', action='store_true', help='Enable right hand prediction')
    parser.add_argument('--instruction', type=str, default="Left hand: Put the trash into the garbage. Right hand: None.", help='Text instruction for hand motion')
    parser.add_argument('--sample_times', type=int, default=4, help='Number of action samples to generate for diversity')
    parser.add_argument('--fps', type=int, default=8, help='Frames per second for output video')

    parser.add_argument('--save_state_local', action='store_true', help='Save hand state locally as .npy file')

    # Avoid tokenizer fork warnings/deadlocks once worker processes are spawned.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parser.parse_args()

    if not args.use_left and not args.use_right:
        raise ValueError("At least one of --use_left or --use_right must be specified.")

    configs = load_config(args.config_path)

    # Command-line paths take precedence over the config file values.
    if args.model_path is not None:
        configs['model_load_path'] = args.model_path
    if args.statistics_path is not None:
        configs['statistics_path'] = args.statistics_path

    # A precomputed hand state is looked up next to the image, same stem, .npy suffix.
    image_path_obj = Path(args.image_path)
    npy_path = image_path_obj.with_suffix('.npy')

    print("Initializing services...")
    if npy_path.exists():
        print(f"Found precomputed hand state results: {npy_path}. Using the state instead of running hand recon.")
        hand_data = np.load(npy_path, allow_pickle=True).item()
        # Cached state available: no reconstruction process is started.
        hand_recon_service = None
    else:
        print(f"No precomputed hand state .npy found at {npy_path}. Starting hand reconstruction service.")
        hand_recon_service = HandReconstructionService(args)
        hand_data = None

    # VLA inference runs in its own process to keep CUDA contexts separate.
    vla_service = VLAInferenceService(configs)

    hand_config = HandConfig(args)
    hand_config.FPS = args.fps
    visualizer = HandVisualizer(hand_config, render_gradual_traj=False)

    try:
        if hand_data is None:
            print("Running hand reconstruction...")
            hand_data = hand_recon_service.reconstruct(args.image_path)
            if args.save_state_local:
                np.save(npy_path, hand_data, allow_pickle=True)
                print(f"Saved reconstructed hand state to {npy_path}")

        image = Image.open(args.image_path)
        ori_w, ori_h = image.size

        # Respect EXIF orientation; best-effort since some images carry broken EXIF.
        try:
            image = ImageOps.exif_transpose(image)
        except Exception:
            pass

        image_resized = resize_short_side_to_target(image, target=224)
        w, h = image_resized.size

        use_right = args.use_right
        use_left = args.use_left

        # Extract the t=0 state for each requested hand. fov_x (degrees) is shared
        # by both hands, so overwriting it on the second call is harmless.
        current_state_left = None
        current_state_right = None

        if use_right:
            current_state_right, beta_right, fov_x, _ = get_state(hand_data, hand_side='right')
        if use_left:
            current_state_left, beta_left, fov_x, _ = get_state(hand_data, hand_side='left')

        # Convert fov_x to radians and derive the vertical FOV via the original
        # image's focal length: f = W / (2 * tan(fov_x / 2)).
        fov_x = fov_x * np.pi /180
        f_ori = ori_w / np.tan(fov_x / 2) /2
        fov_y = 2 * np.arctan(ori_h / (2 * f_ori))

        # Pinhole intrinsics for the resized image (principal point at the center).
        f = w / np.tan(fov_x / 2) /2
        intrinsics = np.array([
            [f, 0, w/2],
            [0, f, h/2],
            [0, 0, 1]
        ])

        if current_state_left is None and current_state_right is None:
            raise ValueError("Both current_state_left and current_state_right are None")

        # A disabled hand is zero-filled so the concatenated state keeps a fixed
        # layout: [left_state(51), left_beta, right_state(51), right_beta].
        state_left = current_state_left if use_left else np.zeros_like(current_state_right)
        beta_left = beta_left if use_left else np.zeros_like(beta_right)
        state_right = current_state_right if use_right else np.zeros_like(current_state_left)
        beta_right = beta_right if use_right else np.zeros_like(beta_left)

        state = np.concatenate([state_left, beta_left, state_right, beta_right], axis=0)
        state_mask = np.array([use_left, use_right], dtype=bool)

        # One (left, right) mask row per predicted future step.
        chunk_size = configs.get('fwd_pred_next_n', 16)
        action_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (chunk_size, 1))

        fov = np.array([fov_x, fov_y], dtype=np.float32)

        image_resized_np = np.array(image_resized)

        instruction = args.instruction

        # Predict `sample_times` diverse action samples in a single service call.
        print(f"Running VLA inference...")
        sample_times = args.sample_times
        unnorm_action = vla_service.predict(
            image=image_resized_np,
            instruction=instruction,
            state=state,
            state_mask=state_mask,
            action_mask=action_mask,
            fov=fov,
            num_ddim_steps=10,
            cfg_scale=5.0,
            sample_times=sample_times,
        )

        fx_exo = intrinsics[0, 0]
        fy_exo = intrinsics[1, 1]
        renderer = Renderer(w, h, (fx_exo, fy_exo), 'cuda')

        # Trajectory length = current state frame + predicted action steps.
        T = len(action_mask) + 1
        traj_right_list = np.zeros((sample_times, T, 51), dtype=np.float32)
        traj_left_list = np.zeros((sample_times, T, 51), dtype=np.float32)

        traj_mask = np.tile(np.array([[use_left, use_right]], dtype=bool), (T, 1))
        left_hand_mask = traj_mask[:, 0]
        right_hand_mask = traj_mask[:, 1]

        hand_mask = (left_hand_mask, right_hand_mask)

        all_rendered_frames = []

        for i in range(sample_times):
            # Zero-filled placeholders remain in effect for any disabled hand.
            traj_right = traj_right_list[i]
            traj_left = traj_left_list[i]

            # Integrate relative actions into absolute trajectories; the
            # unnormalized action packs left in [0:51] and right in [51:102].
            if use_left:
                traj_left = recon_traj(
                    state=state_left,
                    rel_action=unnorm_action[i, :, 0:51],
                )
            if use_right:
                traj_right = recon_traj(
                    state=state_right,
                    rel_action=unnorm_action[i, :, 51:102],
                )

            # Each trajectory row is [transl(3), global_orient euler(3), pose euler(45)];
            # convert the euler parts into MANO-ready rotation matrices.
            left_hand_labels = {
                'transl_worldspace': traj_left[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_left[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_left[:, 6:51], T),
                'beta': beta_left,
            }
            right_hand_labels = {
                'transl_worldspace': traj_right[:, 0:3],
                'global_orient_worldspace': R.from_euler('xyz', traj_right[:, 3:6]).as_matrix(),
                'hand_pose': euler_traj_to_rotmat_traj(traj_right[:, 6:51], T),
                'beta': beta_right,
            }

            verts_left_worldspace, _ = process_single_hand_labels(left_hand_labels, left_hand_mask, visualizer.mano, is_left=True)
            verts_right_worldspace, _ = process_single_hand_labels(right_hand_labels, right_hand_mask, visualizer.mano, is_left=False)

            hand_traj_wordspace = (verts_left_worldspace, verts_right_worldspace)

            # Identity extrinsics: the camera frame is used as the world frame here.
            R_w2c = np.broadcast_to(np.eye(3), (T, 3, 3)).copy()
            t_w2c = np.zeros((T, 3, 1), dtype=np.float32)

            extrinsics = (R_w2c, t_w2c)

            # Render the predicted motion over T copies of the input frame.
            # Channel flip RGB->BGR; presumably the renderer expects BGR — TODO confirm.
            image_bgr = image_resized_np[..., ::-1]
            resize_video_frames = [image_bgr] * T
            save_frames = visualizer._render_hand_trajectory(
                resize_video_frames,
                hand_traj_wordspace,
                hand_mask,
                extrinsics,
                renderer,
                mode='first'
            )

            all_rendered_frames.append(save_frames)

        num_frames = len(all_rendered_frames[0])

        # Near-square grid layout holding one cell per sample.
        grid_cols = math.ceil(math.sqrt(sample_times))
        grid_rows = math.ceil(sample_times / grid_cols)

        combined_frames = []
        for frame_idx in range(num_frames):
            sample_frames = [all_rendered_frames[i][frame_idx] for i in range(sample_times)]

            # Pad with black frames so the grid is fully populated.
            while len(sample_frames) < grid_rows * grid_cols:
                black_frame = np.zeros_like(sample_frames[0])
                sample_frames.append(black_frame)

            # Concatenate cells horizontally per row, then rows vertically.
            rows = []
            for row_idx in range(grid_rows):
                row_frames = sample_frames[row_idx * grid_cols:(row_idx + 1) * grid_cols]
                row_concat = np.concatenate(row_frames, axis=1)
                rows.append(row_concat)

            combined_frame = np.concatenate(rows, axis=0)
            combined_frames.append(combined_frame)

        save_to_video(combined_frames, f'{args.video_path}', fps=hand_config.FPS)
        print(f"Combined video with {sample_times} samples saved to {args.video_path}")

    finally:
        # Always tear down worker processes, even when inference/rendering failed.
        print("Shutting down services...")
        if hand_recon_service is not None:
            hand_recon_service.shutdown()
        vla_service.shutdown()
        print("All services shut down successfully")
| | |
| |
|
def get_state(hand_data, hand_side='right'):
    """
    Extract the initial (t=0) state of one hand from reconstruction output.

    Args:
        hand_data (dict): Reconstruction results; ``hand_data[hand_side]`` is a
            per-frame sequence of dicts with 'hand_pose' (15 rotation matrices),
            'global_orient' (rotation matrix), 'transl' (3,) and 'beta' (10,),
            and ``hand_data['fov_x']`` is the horizontal FOV in degrees.
        hand_side (str): 'left' or 'right'. Defaults to 'right'.

    Returns:
        tuple: (state_t0, beta, fov_x, None) where state_t0 is a (51,) array
        laid out as [translation(3), global orientation euler(3),
        hand pose euler(45)], beta are the MANO shape parameters, fov_x is the
        horizontal FOV in degrees, and None is a placeholder slot.

    Raises:
        ValueError: If ``hand_side`` is neither 'left' nor 'right'.
    """
    if hand_side not in ('left', 'right'):
        raise ValueError(f"hand_side must be 'left' or 'right', got '{hand_side}'")

    frame0 = hand_data[hand_side][0]

    # Rotation matrices -> intrinsic xyz euler angles (radians).
    pose_euler = R.from_matrix(frame0['hand_pose']).as_euler('xyz', degrees=False)
    orient_euler = R.from_matrix(frame0['global_orient']).as_euler('xyz', degrees=False)

    # Pack [transl(3), orient(3), pose(45)] into a flat 51-dim state vector.
    state_t0 = np.concatenate([frame0['transl'], orient_euler, pose_euler.reshape(-1)])

    return state_t0, frame0['beta'], hand_data['fov_x'], None
| |
|
def euler_traj_to_rotmat_traj(euler_traj, T=None):
    """
    Convert Euler angle trajectory to rotation matrix trajectory.

    Converts a sequence of hand poses represented as Euler angles into
    rotation matrices suitable for MANO model input.

    Args:
        euler_traj (np.ndarray): Hand pose trajectory as Euler angles.
            Shape: [T, 45] where T is number of timesteps
            and 45 = 15 joints * 3 Euler angles per joint
        T (int, optional): Number of timesteps in the trajectory. If None,
            it is inferred from ``euler_traj.shape[0]`` (backward compatible:
            callers may still pass it explicitly).

    Returns:
        np.ndarray: Rotation matrix trajectory. Shape: [T, 15, 3, 3]
            where each [3, 3] block is a rotation matrix for one joint
    """
    euler_traj = np.asarray(euler_traj)
    if T is None:
        # The leading axis of the [T, 45] input is the timestep count.
        T = euler_traj.shape[0]

    # Flatten to one row per joint so scipy converts all rotations in one call.
    hand_pose = euler_traj.reshape(-1, 3)
    pose_matrices = R.from_euler('xyz', hand_pose).as_matrix()

    return pose_matrices.reshape(T, 15, 3, 3)
| |
|
| |
|
def _hand_reconstruction_worker(args_dict, task_queue, result_queue):
    """
    Persistent worker for hand reconstruction that runs in a separate process.

    Keeps the reconstructor loaded and serves 'reconstruct' requests from
    ``task_queue`` until a 'shutdown' message arrives.

    Args:
        args_dict (dict): Attribute name -> value mapping of the CLI options the
            reconstructor Config needs (model/weight paths).
        task_queue (mp.Queue): Incoming messages; either
            {'type': 'reconstruct', 'image_path': str} or {'type': 'shutdown'}.
        result_queue (mp.Queue): Outgoing messages; a single {'type': 'ready'}
            after initialization, one {'type': 'result', 'success': bool, ...}
            per reconstruct task, or {'type': 'error', ...} on fatal failure.
    """
    from types import SimpleNamespace

    from data.tools.hand_recon_core import Config, HandReconstructor

    hand_reconstructor = None

    try:
        # Config expects an attribute-style args object; SimpleNamespace replaces
        # the previous hand-rolled empty class + setattr loop.
        args_obj = SimpleNamespace(**args_dict)

        print("[HandRecon Process] Initializing hand reconstructor...")
        config = Config(args_obj)
        hand_reconstructor = HandReconstructor(config=config, device='cuda')
        print("[HandRecon Process] Hand reconstructor ready")

        # Tell the parent process that initialization succeeded.
        result_queue.put({'type': 'ready'})

        # Serve requests until a shutdown message arrives.
        while True:
            task = task_queue.get()

            if task['type'] == 'shutdown':
                print("[HandRecon Process] Received shutdown signal")
                break

            elif task['type'] == 'reconstruct':
                try:
                    image_path = task['image_path']
                    image = cv2.imread(image_path)
                    if image is None:
                        raise ValueError(f"Failed to load image from {image_path}")

                    image_list = [image]
                    recon_results = hand_reconstructor.recon(image_list)

                    result_queue.put({
                        'type': 'result',
                        'success': True,
                        'data': recon_results
                    })

                except Exception as e:
                    # Per-task failures are reported but do not kill the worker.
                    import traceback
                    result_queue.put({
                        'type': 'result',
                        'success': False,
                        'error': str(e),
                        'traceback': traceback.format_exc()
                    })

    except Exception as e:
        # Fatal failure (e.g. during initialization): report and exit.
        import traceback
        result_queue.put({
            'type': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        })

    finally:
        # Free GPU memory before the process exits.
        if hand_reconstructor is not None:
            del hand_reconstructor
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("[HandRecon Process] Cleaned up and exiting")
| |
|
| |
|
def _vla_inference_worker(configs_dict, task_queue, result_queue):
    """
    Persistent worker for VLA model inference that runs in a separate process.
    Keeps model loaded and processes multiple requests until shutdown signal.

    Message protocol (over the two multiprocessing queues):
      - emits {'type': 'ready'} once the model and normalizer are loaded
      - consumes {'type': 'predict', ...} and {'type': 'shutdown'} tasks
      - answers each predict task with {'type': 'result', 'success': bool, ...}
      - emits {'type': 'error', ...} on fatal failure (e.g. model load)
    """
    # Imports happen inside the child process ('spawn' start method).
    from vitra.models import load_model
    from vitra.utils.data_utils import load_normalizer
    from vitra.datasets.human_dataset import pad_state_human, pad_action
    from vitra.datasets.dataset_utils import ActionFeature, StateFeature

    model = None
    normalizer = None

    try:
        print("[VLA Process] Loading VLA model...")
        model = load_model(configs_dict).cuda()
        model.eval()
        normalizer = load_normalizer(configs_dict)
        print(f"[VLA Process] VLA model ready.")

        # Tell the parent process that initialization succeeded.
        result_queue.put({'type': 'ready'})

        # Serve requests until a shutdown message arrives.
        while True:
            task = task_queue.get()

            if task['type'] == 'shutdown':
                print("[VLA Process] Received shutdown signal")
                break

            elif task['type'] == 'predict':
                try:
                    image = task['image']
                    instruction = task['instruction']
                    state = task['state']
                    state_mask = task['state_mask']
                    action_mask = task['action_mask']
                    fov = task['fov']
                    num_ddim_steps = task.get('num_ddim_steps', 10)
                    cfg_scale = task.get('cfg_scale', 5.0)
                    sample_times = task.get('sample_times', 1)

                    # Normalize the raw state (copy so the caller's array is untouched).
                    norm_state = normalizer.normalize_state(state.copy())

                    # Unified (padded) feature dimensions expected by the model.
                    unified_action_dim = ActionFeature.ALL_FEATURES[1]
                    unified_state_dim = StateFeature.ALL_FEATURES[1]

                    # Pad state/action masks from the human-hand layout up to the
                    # unified feature layout.
                    unified_state, unified_state_mask = pad_state_human(
                        state=norm_state,
                        state_mask=state_mask,
                        action_dim=normalizer.action_mean.shape[0],
                        state_dim=normalizer.state_mean.shape[0],
                        unified_state_dim=unified_state_dim,
                    )
                    _, unified_action_mask = pad_action(
                        actions=None,
                        action_mask=action_mask.copy(),
                        action_dim=normalizer.action_mean.shape[0],
                        unified_action_dim=unified_action_dim
                    )

                    # Add the batch dimension expected by the model.
                    fov = torch.from_numpy(fov).unsqueeze(0)
                    unified_state = unified_state.unsqueeze(0)
                    unified_state_mask = unified_state_mask.unsqueeze(0)
                    unified_action_mask = unified_action_mask.unsqueeze(0)

                    norm_action = model.predict_action(
                        image=image,
                        instruction=instruction,
                        current_state=unified_state,
                        current_state_mask=unified_state_mask,
                        action_mask_torch=unified_action_mask,
                        num_ddim_steps=num_ddim_steps,
                        cfg_scale=cfg_scale,
                        fov=fov,
                        sample_times=sample_times,
                    )

                    # Keep only the first 102 channels (left 51 + right 51 hand
                    # actions) and map back to unnormalized units.
                    norm_action = norm_action[:, :, :102]
                    unnorm_action = normalizer.unnormalize_action(norm_action)

                    # Convert to a CPU numpy array before sending through the queue.
                    if isinstance(unnorm_action, torch.Tensor):
                        unnorm_action_np = unnorm_action.cpu().numpy()
                    else:
                        unnorm_action_np = np.array(unnorm_action)

                    result_queue.put({
                        'type': 'result',
                        'success': True,
                        'data': unnorm_action_np
                    })

                except Exception as e:
                    # Per-task failures are reported but the worker keeps serving.
                    import traceback
                    result_queue.put({
                        'type': 'result',
                        'success': False,
                        'error': str(e),
                        'traceback': traceback.format_exc()
                    })

    except Exception as e:
        # Fatal failure (e.g. during model loading): report and exit.
        import traceback
        result_queue.put({
            'type': 'error',
            'error': str(e),
            'traceback': traceback.format_exc()
        })

    finally:
        # Release the model and GPU memory before the process exits.
        if model is not None:
            del model
        if normalizer is not None:
            del normalizer
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print("[VLA Process] Cleaned up and exiting")
| |
|
| |
|
class HandReconstructionService:
    """Client-side handle for the persistent hand reconstruction process."""

    def __init__(self, args):
        # 'spawn' gives the worker its own clean interpreter and CUDA context.
        self.ctx = mp.get_context('spawn')
        self.task_queue = self.ctx.Queue()
        self.result_queue = self.ctx.Queue()

        # Snapshot only the CLI options the worker needs, as a plain dict.
        worker_args = {
            key: getattr(args, key)
            for key in ('hawor_model_path', 'detector_path', 'moge_model_name', 'mano_path')
        }

        self.process = self.ctx.Process(
            target=_hand_reconstruction_worker,
            args=(worker_args, self.task_queue, self.result_queue),
        )
        self.process.start()

        # Block until the worker reports that its models are loaded.
        ready_msg = self.result_queue.get()
        if ready_msg['type'] == 'ready':
            print("Hand reconstruction service initialized")
        elif ready_msg['type'] == 'error':
            raise RuntimeError(f"Failed to initialize hand reconstruction: {ready_msg['error']}")

    def reconstruct(self, image_path):
        """Request hand reconstruction for an image"""
        self.task_queue.put({'type': 'reconstruct', 'image_path': image_path})

        reply = self.result_queue.get()
        if reply['type'] == 'result' and reply['success']:
            return reply['data']
        raise RuntimeError(f"Hand reconstruction failed: {reply.get('error', 'Unknown error')}")

    def shutdown(self):
        """Shutdown the persistent process"""
        self.task_queue.put({'type': 'shutdown'})
        # Give the worker 10s to exit cleanly, then force-terminate.
        self.process.join(timeout=10)
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
| |
|
| |
|
class VLAInferenceService:
    """Client-side handle for the persistent VLA inference process."""

    def __init__(self, configs):
        # 'spawn' gives the worker its own clean interpreter and CUDA context.
        self.ctx = mp.get_context('spawn')
        self.task_queue = self.ctx.Queue()
        self.result_queue = self.ctx.Queue()

        self.process = self.ctx.Process(
            target=_vla_inference_worker,
            args=(configs, self.task_queue, self.result_queue),
        )
        self.process.start()

        # Block until the worker reports that the model is loaded.
        ready_msg = self.result_queue.get()
        if ready_msg['type'] == 'ready':
            print("VLA inference service initialized")
        elif ready_msg['type'] == 'error':
            raise RuntimeError(f"Failed to initialize VLA model: {ready_msg['error']}")

    def predict(self, image, instruction, state, state_mask, action_mask,
                fov, num_ddim_steps=10, cfg_scale=5.0, sample_times=1):
        """Request action prediction with state normalization and padding"""
        request = {
            'type': 'predict',
            'image': image,
            'instruction': instruction,
            'state': state,
            'state_mask': state_mask,
            'action_mask': action_mask,
            'fov': fov,
            'num_ddim_steps': num_ddim_steps,
            'cfg_scale': cfg_scale,
            'sample_times': sample_times,
        }
        self.task_queue.put(request)

        reply = self.result_queue.get()
        if reply['type'] == 'result' and reply['success']:
            return reply['data']
        raise RuntimeError(f"VLA inference failed: {reply.get('error', 'Unknown error')}")

    def shutdown(self):
        """Shutdown the persistent process"""
        self.task_queue.put({'type': 'shutdown'})
        # Give the worker 10s to exit cleanly, then force-terminate.
        self.process.join(timeout=10)
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
| |
|
| |
|
if __name__ == "__main__":
    # Force the 'spawn' start method so each service process starts with a
    # fresh interpreter (required for the separate CUDA contexts used above).
    mp.set_start_method('spawn', force=True)
    main()