
Commit 5eaa5c6

Merge pull request #92 from octo-models/new_release

Update with many small changes/fixes

2 parents: cab7f94 + d53b1ed

37 files changed: +1990 −1443 lines

README.md

Lines changed: 22 additions & 13 deletions
@@ -1,5 +1,5 @@
 # Octo
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z0vELj_lX9OWeoMG_WvXnQs43aPOEAhz?usp=sharing)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/octo-models/octo/blob/main/examples/01_inference_pretrained.ipynb)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://octo-models.github.io/)
 ![](https://github.com/rail-berkeley/octo/workflows/run-debug/badge.svg)
@@ -15,7 +15,7 @@ for an inference example.
 
 ```python
 from octo.model.octo_model import OctoModel
-model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base")
+model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")
 print(model.get_pretty_spec())
 ```
 
@@ -48,7 +48,7 @@ See the [Jax Github page](https://github.com/google/jax) for more details on ins
 
 Test the installation by finetuning on the debug dataset:
 ```bash
-python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
+python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug
 ```
 
 ## Checkpoints
@@ -99,7 +99,7 @@ We provide a [minimal example](examples/02_finetune_new_observation_action.py) f
 We also provide a more advanced finetuning script that allows you to change hyperparameters via a config file and logs finetuning
 metrics. To run advanced finetuning, use:
 ```bash
-python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
+python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5
 ```
 
 We offer three finetuning modes depending on the parts of the model that are kept frozen: ```head_only```, ```head_mlp_only```, and ```full``` to finetune the full model.
@@ -114,9 +114,9 @@ Loading and running a trained Octo model is as easy as:
 ```python
 from octo.model import OctoModel
 
-model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small")
+model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")
 task = model.create_tasks(texts=["pick up the spoon"])
-action = model.sample_action(observation, task, rng=jax.random.PRNGKey(0))
+action = model.sample_actions(observation, task, rng=jax.random.PRNGKey(0))
 ```
 
 We provide examples for evaluating Octo [in a simulated Gym environment](examples/03_eval_finetuned.py) as well
@@ -140,20 +140,29 @@ To evaluate on your own environment, simply wrap it in a Gym interface and follo
 | Visualization | [visualization_lib.py](octo/utils/visualization_lib.py) | Utilities for offline qualitative & quantitative eval. |
 
 ## FAQ
-#### What is the `pad_mask` in the observation dictionary?
-The `pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, pad_mask should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `pad_mask` key will be added to the observation dictionary for you.
+#### What is the `timestep_pad_mask` in the observation dictionary?
+The `timestep_pad_mask` indicates which observations should be attended to, which is important when using multiple timesteps of observation history. Octo was trained with a history window size of 2, meaning the model can predict an action using both the current observation and the previous observation. However, at the very beginning of the trajectory, there is no previous observation, so we need to set `timestep_pad_mask=False` at the corresponding index. If you use Octo with a window size of 1, `timestep_pad_mask` should always just be `[True]`, indicating that the one and only observation in the window should be attended to. Note that if you wrap your robot environment with the `HistoryWrapper` (see [gym_wrappers.py](octo/utils/gym_wrappers.py)), the `timestep_pad_mask` key will be added to the observation dictionary for you.
 #### What is `pad_mask_dict` in the observation dictionary?
-While `pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key.
+While `timestep_pad_mask` indicates which observations should be attended to on a timestep level, `pad_mask_dict` indicates which elements of the observation should be attended to within a single timestep. For example, for datasets without language labels, `pad_mask_dict["language_instruction"]` is set to `False`. For datasets without a wrist camera, `pad_mask_dict["image_wrist"]` is set to `False`. For convenience, if a key is missing from the observation dict, it is equivalent to setting `pad_mask_dict` to `False` for that key.
 #### Does `model.sample_actions([...])` return the full trajectory to solve a task?
 Octo was pretrained with an action chunking size of 4, meaning it predicts the next 4 actions at once. You can choose to execute all these actions before sampling new ones, or only execute the first action before sampling new ones (also known as receding horizon control). You can also do something more advanced like [temporal ensembling](octo/utils/gym_wrappers.py).
 
+## Updates for Version 1.5
+- Improved cross-attention between visual and language tokens by repeating language tokens at every timestep in the context window.
+- Augmented the language instructions in the data with rephrasings from GPT-3.5.
+- Bug fixes:
+  - Turned off dropout in the diffusion head due to incompatibility with layer norm.
+  - Fixed an off-by-one error with the attention mask.
+  - Fixed an issue where different image augmentations did not get fresh random seeds.
+
 ## Citation
 
 ```
-@misc{octo_2023,
+@inproceedings{octo_2023,
     title={Octo: An Open-Source Generalist Robot Policy},
-    author = {{Octo Model Team} and Dibya Ghosh and Homer Walke and Karl Pertsch and Kevin Black and Oier Mees and Sudeep Dasari and Joey Hejna and Charles Xu and Jianlan Luo and Tobias Kreiman and {You Liang} Tan and Dorsa Sadigh and Chelsea Finn and Sergey Levine},
-    howpublished = {\url{https://octo-models.github.io}},
-    year = {2023},
+    author = {{Octo Model Team} and Dibya Ghosh and Homer Walke and Karl Pertsch and Kevin Black and Oier Mees and Sudeep Dasari and Joey Hejna and Charles Xu and Jianlan Luo and Tobias Kreiman and {You Liang} Tan and Pannag Sanketi and Quan Vuong and Ted Xiao and Dorsa Sadigh and Chelsea Finn and Sergey Levine},
+    booktitle = {Proceedings of Robotics: Science and Systems},
+    address = {Delft, Netherlands},
+    year = {2024},
 }
 ```
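
To make the new `timestep_pad_mask` FAQ concrete, here is a minimal sketch of the observation dict at the very first step of a trajectory with a window size of 2. The key names follow the FAQ above; the image shape and the zero-filled history slot are illustrative, and the leading batch dimension expected by `sample_actions` is added at the end.

```python
import numpy as np

# First step of a trajectory with window size 2: there is no previous frame,
# so history index 0 is padding and must be masked out.
current_frame = np.zeros((256, 256, 3), dtype=np.uint8)  # placeholder camera image

observation = {
    # history axis first: [window_size, height, width, channels]
    "image_primary": np.stack([np.zeros_like(current_frame), current_frame]),
    # attend only to the real (current) observation
    "timestep_pad_mask": np.array([False, True]),
}

# add the leading batch dimension before calling model.sample_actions
observation = {k: v[None] for k, v in observation.items()}
```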

docs/assets/teaser.jpg

Binary file changed (−73.4 KB).

examples/01_inference_pretrained.ipynb

Lines changed: 55 additions & 135 deletions
Large diffs are not rendered by default.

examples/02_finetune_new_observation_action.py

Lines changed: 8 additions & 10 deletions
@@ -4,7 +4,7 @@
 
 To run this example, first download and extract the dataset from here: https://rail.eecs.berkeley.edu/datasets/example_sim_data.zip
 
-python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small --data_dir=...
+python examples/02_finetune_new_observation_action.py --pretrained_path=hf://rail-berkeley/octo-small-1.5 --data_dir=...
 """
 from absl import app, flags, logging
 import flax
@@ -15,7 +15,6 @@
 import wandb
 
 from octo.data.dataset import make_single_dataset
-from octo.data.utils.data_utils import NormalizationType
 from octo.model.components.action_heads import L1ActionHead
 from octo.model.components.tokenizers import LowdimObsTokenizer
 from octo.model.octo_model import OctoModel
@@ -70,14 +69,12 @@ def main(_):
             name="aloha_sim_cube_scripted_dataset",
             data_dir=FLAGS.data_dir,
             image_obs_keys={"primary": "top"},
-            state_obs_keys=["state"],
+            proprio_obs_key="state",
             language_key="language_instruction",
-            action_proprio_normalization_type=NormalizationType.NORMAL,
-            absolute_action_mask=[True] * 14,
         ),
         traj_transform_kwargs=dict(
             window_size=1,
-            future_action_window_size=49,  # so we get 50 actions for our action chunk
+            action_horizon=50,
         ),
         frame_transform_kwargs=dict(
             resize_size={"primary": (256, 256)},
@@ -116,10 +113,10 @@ def process_batch(batch):
         high=2.0,
         obs_keys=["proprio"],
     )
-    # Fully override the old action head with a new one (for smaller changes, you can use update_module_config)
+    # Fully override the old action head with a new one (for smaller changes, you can use update_config)
     config["model"]["heads"]["action"] = ModuleSpec.create(
         L1ActionHead,
-        pred_horizon=50,
+        action_horizon=50,
         action_dim=14,
         readout_key="readout_action",
     )
@@ -162,13 +159,14 @@ def loss_fn(params, batch, rng, train=True):
         transformer_embeddings = bound_module.octo_transformer(
             batch["observation"],
             batch["task"],
-            batch["observation"]["pad_mask"],
+            batch["observation"]["timestep_pad_mask"],
             train=train,
         )
         action_loss, action_metrics = bound_module.heads["action"].loss(
             transformer_embeddings,  # Action head knows to pull out the action readout_key
             batch["action"],
-            pad_mask=batch["observation"]["pad_mask"],
+            batch["observation"]["timestep_pad_mask"],
+            batch["action_pad_mask"],
             train=train,
         )
         return action_loss, action_metrics
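
For anyone updating their own finetuning script against these renames, a sketch of the dataset call after this change. The keyword arguments are exactly the ones shown in the diff above; the surrounding call structure and the placeholder data path are assumptions based on this example script.

```python
from octo.data.dataset import make_single_dataset

dataset = make_single_dataset(
    dataset_kwargs=dict(
        name="aloha_sim_cube_scripted_dataset",
        data_dir="path/to/example_sim_data",  # placeholder; pass --data_dir in practice
        image_obs_keys={"primary": "top"},
        proprio_obs_key="state",  # replaces state_obs_keys=["state"]
        language_key="language_instruction",
    ),
    traj_transform_kwargs=dict(
        window_size=1,
        action_horizon=50,  # replaces future_action_window_size=49
    ),
    frame_transform_kwargs=dict(
        resize_size={"primary": (256, 256)},
    ),
)
```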

examples/03_eval_finetuned.py

Lines changed: 18 additions & 13 deletions
@@ -1,6 +1,6 @@
 """
 This script demonstrates how to load and rollout a finetuned Octo model.
-We use the Octo model finetuned on ALOHA sim data from the examples/finetune_new_observation_action.py script.
+We use the Octo model finetuned on ALOHA sim data from the examples/02_finetune_new_observation_action.py script.
 
 For installing the ALOHA sim environment, clone: https://github.com/tonyzhaozh/act
 Then run:
@@ -15,6 +15,7 @@
 cd examples
 python3 03_eval_finetuned.py --finetuned_path=<path_to_finetuned_aloha_checkpoint>
 """
+from functools import partial
 import sys
 
 from absl import app, flags, logging
@@ -25,10 +26,12 @@
 
 sys.path.append("path/to/your/act")
 
-from envs.aloha_sim_env import AlohaGymEnv  # keep this to register ALOHA sim env
+# keep this to register ALOHA sim env
+from envs.aloha_sim_env import AlohaGymEnv  # noqa
 
 from octo.model.octo_model import OctoModel
-from octo.utils.gym_wrappers import HistoryWrapper, RHCWrapper, UnnormalizeActionProprio
+from octo.utils.gym_wrappers import HistoryWrapper, NormalizeProprio, RHCWrapper
+from octo.utils.train_callbacks import supply_rng
 
 FLAGS = flags.FLAGS
 
@@ -49,27 +52,31 @@ def main(_):
    ##################################################################################################################
    # environment needs to implement standard gym interface + return observations of the following form:
    #   obs = {
-   #     "image_0": ...
-   #     "image_1": ...
+   #     "image_primary": ...
    #   }
    # it should also implement an env.get_task() function that returns a task dict with goal and/or language instruct.
    #   task = {
    #     "language_instruction": "some string"
    #     "goal": {
-   #       "image_0": ...
-   #       "image_1": ...
+   #       "image_primary": ...
    #     }
    #   }
    ##################################################################################################################
    env = gym.make("aloha-sim-cube-v0")
 
+   # wrap env to normalize proprio
+   env = NormalizeProprio(env, model.dataset_statistics)
+
    # add wrappers for history and "receding horizon control", i.e. action chunking
    env = HistoryWrapper(env, horizon=1)
    env = RHCWrapper(env, exec_horizon=50)
 
-   # wrap env to handle action/proprio normalization -- match normalization type to the one used during finetuning
-   env = UnnormalizeActionProprio(
-       env, model.dataset_statistics, normalization_type="normal"
+   # the supply_rng wrapper supplies a new random key to sample_actions every time it's called
+   policy_fn = supply_rng(
+       partial(
+           model.sample_actions,
+           unnormalization_statistics=model.dataset_statistics["action"],
+       ),
    )
 
    # running rollouts
@@ -85,9 +92,7 @@
        episode_return = 0.0
        while len(images) < 400:
            # model returns actions of shape [batch, pred_horizon, action_dim] -- remove batch
-           actions = model.sample_actions(
-               jax.tree_map(lambda x: x[None], obs), task, rng=jax.random.PRNGKey(0)
-           )
+           actions = policy_fn(jax.tree_map(lambda x: x[None], obs), task)
            actions = actions[0]
 
            # step env -- info contains full "chunk" of observations for logging
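
Both eval examples now import `supply_rng` from `octo.utils.train_callbacks` rather than defining it inline. For reference, this is the helper exactly as it appears in the inline version removed from example 04 below: it splits off a fresh JAX PRNG key on every call, so repeated sampling is properly randomized.

```python
import jax

def supply_rng(f, rng=jax.random.PRNGKey(0)):
    # split off a fresh key on every call so repeated sampling is not deterministic
    def wrapped(*args, **kwargs):
        nonlocal rng
        rng, key = jax.random.split(rng)
        return f(*args, rng=key, **kwargs)

    return wrapped
```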

examples/04_eval_finetuned_on_robot.py

Lines changed: 17 additions & 27 deletions
@@ -22,11 +22,8 @@
 from widowx_envs.widowx_env_service import WidowXClient, WidowXConfigs, WidowXStatus
 
 from octo.model.octo_model import OctoModel
-from octo.utils.gym_wrappers import (
-    HistoryWrapper,
-    TemporalEnsembleWrapper,
-    UnnormalizeActionProprio,
-)
+from octo.utils.gym_wrappers import HistoryWrapper, TemporalEnsembleWrapper
+from octo.utils.train_callbacks import supply_rng
 
 np.set_printoptions(suppress=True)
 
@@ -50,9 +47,10 @@
 flags.DEFINE_integer("im_size", None, "Image size", required=True)
 flags.DEFINE_string("video_save_path", None, "Path to save video")
 flags.DEFINE_integer("num_timesteps", 120, "num timesteps")
-flags.DEFINE_integer("horizon", 1, "Observation history length")
-flags.DEFINE_integer("pred_horizon", 1, "Length of action sequence from model")
-flags.DEFINE_integer("exec_horizon", 1, "Length of action sequence to execute")
+flags.DEFINE_integer("window_size", 2, "Observation history length")
+flags.DEFINE_integer(
+    "action_horizon", 4, "Length of action sequence to execute/ensemble"
+)
 
 
 # show image flag
@@ -64,10 +62,9 @@
 Bridge data was collected with non-blocking control and a step duration of 0.2s.
 However, we relabel the actions to make it look like the data was collected with
 blocking control and we evaluate with blocking control.
-We also use a step duration of 0.4s to reduce the jerkiness of the policy.
-Be sure to change the step duration back to 0.2 if evaluating with non-blocking control.
+Be sure to use a step duration of 0.2 if evaluating with non-blocking control.
 """
-STEP_DURATION = 0.4
+STEP_DURATION = 0.2
 STICKY_GRIPPER_NUM_STEPS = 1
 WORKSPACE_BOUNDS = [[0.1, -0.15, -0.01, -1.57, 0], [0.45, 0.25, 0.25, 1.57, 0]]
 CAMERA_TOPICS = [{"name": "/blue/image_raw"}]
@@ -107,16 +104,12 @@ def main(_):
    )
 
    # wrap the robot environment
-   env = UnnormalizeActionProprio(
-       env, model.dataset_statistics["bridge_dataset"], normalization_type="normal"
-   )
-   env = HistoryWrapper(env, FLAGS.horizon)
-   env = TemporalEnsembleWrapper(env, FLAGS.pred_horizon)
+   env = HistoryWrapper(env, FLAGS.window_size)
+   env = TemporalEnsembleWrapper(env, FLAGS.action_horizon)
    # switch TemporalEnsembleWrapper with RHCWrapper for receding horizon control
-   # env = RHCWrapper(env, FLAGS.exec_horizon)
+   # env = RHCWrapper(env, FLAGS.action_horizon)
 
-   # create policy function
-   @jax.jit
+   # create policy functions
    def sample_actions(
        pretrained_model: OctoModel,
        observations,
@@ -129,22 +122,19 @@ def sample_actions(
            observations,
            tasks,
            rng=rng,
+           unnormalization_statistics=pretrained_model.dataset_statistics[
+               "bridge_dataset"
+           ]["action"],
        )
        # remove batch dim
        return actions[0]
 
-   def supply_rng(f, rng=jax.random.PRNGKey(0)):
-       def wrapped(*args, **kwargs):
-           nonlocal rng
-           rng, key = jax.random.split(rng)
-           return f(*args, rng=key, **kwargs)
-
-       return wrapped
-
    policy_fn = supply_rng(
        partial(
            sample_actions,
            model,
+           argmax=FLAGS.deterministic,
+           temperature=FLAGS.temperature,
        )
    )
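The script keeps `TemporalEnsembleWrapper` as the default and leaves `RHCWrapper` commented out as the receding-horizon alternative. As a rough illustration of the idea behind temporal ensembling, the sketch below averages the overlapping predictions that successive action chunks make for the same timestep; the exponential weighting here is an assumption for illustration, and the actual implementation lives in [gym_wrappers.py](octo/utils/gym_wrappers.py).

```python
import numpy as np

def ensemble_action(chunks: list, t: int, alpha: float = 0.1) -> np.ndarray:
    """Blend the predictions that past chunks made for timestep t.

    chunks[i] is the action chunk sampled at timestep i, with shape
    [action_horizon, action_dim]; older chunks receive smaller weights
    (the decay scheme is illustrative, not Octo's exact one).
    """
    preds, weights = [], []
    for i, chunk in enumerate(chunks):
        offset = t - i  # where timestep t falls inside chunk i
        if 0 <= offset < len(chunk):
            preds.append(chunk[offset])
            weights.append(np.exp(-alpha * offset))  # decay older predictions
    weights = np.array(weights) / np.sum(weights)
    return np.sum([w * p for w, p in zip(weights, preds)], axis=0)
```
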
examples/05_dataloading.ipynb

Lines changed: 18 additions & 253 deletions
Large diffs are not rendered by default.

examples/06_pytorch_oxe_dataloader.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def __len__(self):
    traj_transform_kwargs=dict(
        goal_relabeling_strategy="uniform",
        window_size=2,
-       future_action_window_size=3,
+       action_horizon=4,
        subsample_length=100,
    ),
    frame_transform_kwargs=dict(
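
Note the off-by-one in this rename: `action_horizon` counts the current action plus the future ones, so `future_action_window_size=3` becomes `action_horizon=4`, consistent with the 49 → 50 change in example 02 above.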

examples/envs/widowx_env.py

Lines changed: 0 additions & 2 deletions
@@ -42,14 +42,12 @@ def convert_obs(obs, im_size):
     # NOTE: assume image_1 is not available
     return {
         "image_primary": image_obs,
-        "proprio": proprio,
     }
 
 
 def null_obs(img_size):
     return {
         "image_primary": np.zeros((img_size, img_size, 3), dtype=np.uint8),
-        "proprio": np.zeros((8,), dtype=np.float64),
     }
 
 