diff --git a/Dockerfile b/Dockerfile index b7b6285..59d66c1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,9 @@ FROM python:3.11.8 +# Dependencies for opencv +RUN apt update && apt upgrade -y && apt install -y ffmpeg + # Set up a new user named "user" with user ID 1000 RUN useradd -m -u 1000 user @@ -22,6 +25,9 @@ RUN pip install --no-cache-dir --upgrade pip # Copy the current directory contents into the container at $HOME/app setting the owner to the user COPY --chown=user . $HOME/app +# https://github.com/huggingface/lerobot/issues/105 +RUN pip install --no-cache-dir --upgrade cmake + # Install requirements.txt RUN pip install --no-cache-dir --upgrade -r requirements.txt diff --git a/app.py b/app.py index 8b66e7d..171ab88 100644 --- a/app.py +++ b/app.py @@ -18,8 +18,9 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from gradio_huggingfacehub_search import HuggingfaceHubSearch +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from dataset_conversion import log_dataset_to_rerun +from dataset_conversion import log_dataset_to_rerun, log_lerobot_dataset_to_rerun CUSTOM_PATH = "/" @@ -50,14 +51,17 @@ def show_dataset(dataset_id: str, episode_index: int) -> str: rr.save(filename.as_posix()) - dataset = load_dataset(dataset_id, split="train", streaming=True) - - # This is for LeRobot datasets (https://huggingface.co/lerobot): - ds_subset = dataset.filter( - lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index - ) - - log_dataset_to_rerun(ds_subset) + if "/" in dataset_id and dataset_id.split("/")[0] == "lerobot": + dataset = LeRobotDataset(dataset_id) + log_lerobot_dataset_to_rerun(dataset, episode_index) + else: + dataset = load_dataset(dataset_id, split="train", streaming=True) + + # This is for LeRobot datasets (https://huggingface.co/lerobot): + ds_subset = dataset.filter( + lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index + ) + log_dataset_to_rerun(ds_subset) return filename.as_posix() diff --git a/dataset_conversion.py b/dataset_conversion.py index ed2928d..265b8f4 100644 --- a/dataset_conversion.py +++ b/dataset_conversion.py @@ -1,17 +1,55 @@ from __future__ import annotations import logging +from pathlib import PosixPath from typing import Any +import cv2 import numpy as np import rerun as rr +import torch +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from PIL import Image from tqdm import tqdm logger = logging.getLogger(__name__) -def to_rerun(column_name: str, value: Any) -> Any: +def get_frame( + video_path: PosixPath, timestamp: float, video_cache: dict[PosixPath, tuple[np.ndarray, float]] | None = None +) -> np.ndarray: + """ + Extracts a specific frame from a video. + + `video_path`: path to the video. + `timestamp`: timestamp of the wanted frame. + `video_cache`: cache to prevent reading the same video file twice. + """ + + if video_cache is None: + video_cache = {} + if video_path not in video_cache: + cap = cv2.VideoCapture(str(video_path)) + frames = [] + while cap.isOpened(): + success, frame = cap.read() + if success: + frames.append(frame) + else: + break + frame_rate = cap.get(cv2.CAP_PROP_FPS) + video_cache[video_path] = (frames, frame_rate) + + frames, frame_rate = video_cache[video_path] + return frames[int(timestamp * frame_rate)] + + +def to_rerun( + column_name: str, + value: Any, + video_cache: dict[PosixPath, tuple[np.ndarray, float]] | None = None, + videos_dir: PosixPath | None = None, +) -> Any: """Do our best to interpret the value and convert it to a Rerun-compatible archetype.""" if isinstance(value, Image.Image): if "depth" in column_name: @@ -27,22 +65,47 @@ def to_rerun(column_name: str, value: Any) -> Any: return rr.TextDocument(str(value)) # Fallback to text elif isinstance(value, float) or isinstance(value, int): return rr.Scalar(value) + elif isinstance(value, torch.Tensor): + if value.dim() == 0: + return rr.Scalar(value.item()) + elif value.dim() == 1: + return rr.BarChart(value) + elif value.dim() == 2 and "depth" in column_name: + return rr.DepthImage(value) + elif value.dim() == 2: + return rr.Image(value) + elif value.dim() == 3 and (value.shape[2] == 3 or value.shape[2] == 4): + return rr.Image(value) # Treat it as a RGB or RGBA image + else: + return rr.Tensor(value) + elif isinstance(value, dict) and "path" in value and "timestamp" in value: + path = (videos_dir or PosixPath("./")) / PosixPath(value["path"]) + timestamp = value["timestamp"] + return rr.Image(get_frame(path, timestamp, video_cache=video_cache)) else: return rr.TextDocument(str(value)) # Fallback to text -def log_dataset_to_rerun(dataset: Any) -> None: - # Special time-like columns for LeRobot datasets (https://huggingface.co/datasets/lerobot/): +def log_lerobot_dataset_to_rerun(dataset: LeRobotDataset, episode_index: int) -> None: + # Special time-like columns for LeRobot datasets (https://huggingface.co/lerobot/): TIME_LIKE = {"index", "frame_id", "timestamp"} # Ignore these columns (again, LeRobot-specific): IGNORE = {"episode_data_index_from", "episode_data_index_to", "episode_id"} - for row in tqdm(dataset): + hf_ds_subset = dataset.hf_dataset.filter( + lambda frame: "episode_index" not in frame or frame["episode_index"] == episode_index + ) + + video_cache: dict[PosixPath, tuple[np.ndarray, float]] = {} + + for row in tqdm(hf_ds_subset): # Handle time-like columns first, since they set a state (time is an index in Rerun): for column_name in TIME_LIKE: if column_name in row: cell = row[column_name] + if isinstance(cell, torch.Tensor) and cell.dim() == 0: + cell = cell.item() if isinstance(cell, int): rr.set_time_sequence(column_name, cell) elif isinstance(cell, float): @@ -54,5 +117,30 @@ def log_dataset_to_rerun(dataset: Any) -> None: for column_name, cell in row.items(): if column_name in TIME_LIKE or column_name in IGNORE: continue + else: + rr.log( + column_name, + to_rerun(column_name, cell, video_cache=video_cache, videos_dir=dataset.videos_dir.parent), + ) + + +def log_dataset_to_rerun(dataset: Any) -> None: + TIME_LIKE = {"index", "frame_id", "timestamp"} + + for row in tqdm(dataset): + # Handle time-like columns first, since they set a state (time is an index in Rerun): + for column_name in TIME_LIKE: + if column_name in row: + cell = row[column_name] + if isinstance(cell, int): + rr.set_time_sequence(column_name, cell) + elif isinstance(cell, float): + rr.set_time_seconds(column_name, cell) # assume seconds + else: + print(f"Unknown time-like column {column_name} with value {cell}") + # Now log actual data columns: + for column_name, cell in row.items(): + if column_name in TIME_LIKE: + continue rr.log(column_name, to_rerun(column_name, cell)) diff --git a/requirements.txt b/requirements.txt index 58af223..b6ca5e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ gradio_huggingfacehub_search pillow rerun-sdk>=0.15.0,<0.16.0 tqdm +opencv-python webdataset +git+https://github.com/huggingface/lerobot@7bb5b15f4c0393ba16b73f6482611892301401d7#egg=lerobot