Instructions to use ksridhar/atari_2B_atari_airraid_1111 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sample-factory
How to use ksridhar/atari_2B_atari_airraid_1111 with sample-factory:
python -m sample_factory.huggingface.load_from_hub -r ksridhar/atari_2B_atari_airraid_1111 -d ./train_dir
- Notebooks
- Google Colab
- Kaggle
| diff --git a/README.md b/README.md | |
| index e51a12b..a6e1ca1 100644 | |
| --- a/README.md | |
| +++ b/README.md | |
| conda activate jat | |
| pip install -e .[dev] | |
| ``` | |
| +## REGENT fork of sample-factory: Installation | |
| +Following [this install link](https://www.samplefactory.dev/01-get-started/installation/) but for the fork: | |
| +```shell | |
| +git clone https://github.com/kaustubhsridhar/sample-factory.git | |
| +cd sample-factory | |
| +pip install -e .[dev,mujoco,atari,envpool,vizdoom] | |
| +``` | |
| + | |
| +## REGENT fork of sample-factory: Train Unseen Env Policies and Generate Datasets | |
| +Train policies using envpool's atari: | |
| +```shell | |
| +bash scripts_sample-factory/train_unseen_atari.sh | |
| +``` | |
| +Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2). | |
| + | |
| ## PREV Installation | |
| To get started with JAT, follow these steps: | |
| python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS | |
| ``` | |
| ### REGENT Analyze data | |
| +Necessary: | |
| ```shell | |
| -python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & | |
| - | |
| python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt & | |
| +``` | |
| +Already ran and output dict in code: | |
| +```shell | |
| python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt & | |
| + | |
| +python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt & | |
| +``` | |
| + | |
| +Optional: | |
| +```shell | |
| +python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & | |
| ``` | |
| ## PREV Dataset | |
| diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py | |
| deleted file mode 100644 | |
| index b2bd8bf..0000000 | |
| --- a/jat_regent/RandP.py | |
| +++ /dev/null | |
| -import warnings | |
| -from dataclasses import dataclass | |
| -from typing import List, Optional, Tuple, Union | |
| - | |
| -import numpy as np | |
| -import torch | |
| -import torch.nn.functional as F | |
| -from gymnasium import spaces | |
| -from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn | |
| -from transformers import GPTNeoModel, GPTNeoPreTrainedModel | |
| -from transformers.modeling_outputs import ModelOutput | |
| -from transformers.models.vit.modeling_vit import ViTPatchEmbeddings | |
| - | |
| -from jat.configuration_jat import JatConfig | |
| -from jat.processing_jat import JatProcessor | |
| - | |
| - | |
| -class RandP(): | |
| - def __init__(self, dataset) -> None: | |
| - self.steps = 0 | |
| - # create an index for retrieval in vector obs envs (OR) collect all images in Atari | |
| - | |
| - def reset_rl(self): | |
| - self.steps = 0 | |
| - | |
| - def get_next_action( | |
| - self, | |
| - processor: JatProcessor, | |
| - continuous_observation: Optional[List[float]] = None, | |
| - discrete_observation: Optional[List[int]] = None, | |
| - text_observation: Optional[str] = None, | |
| - image_observation: Optional[np.ndarray] = None, | |
| - action_space: Union[spaces.Box, spaces.Discrete] = None, | |
| - reward: Optional[float] = None, | |
| - deterministic: bool = False, | |
| - context_window: Optional[int] = None, | |
| - ): | |
| - pass | |
| \ No newline at end of file | |
| diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py | |
| deleted file mode 100644 | |
| index e69de29..0000000 | |
| diff --git a/jat_regent/utils.py b/jat_regent/utils.py | |
| index 56bfb44..36f6cca 100644 | |
| --- a/jat_regent/utils.py | |
| +++ b/jat_regent/utils.py | |
| from tqdm import tqdm | |
| from autofaiss import build_index | |
| +UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11 | |
| + | |
| +} | |
| + | |
| def myprint(str): | |
| - # check if first character of string is a newline character | |
| - if str[0] == '\n': | |
| - str_without_newline = str[1:] | |
| + # check if first characters of string are newline character | |
| + num_newlines = 0 | |
| + while str[num_newlines] == '\n': | |
| print() | |
| - else: | |
| - str_without_newline = str | |
| + num_newlines += 1 | |
| + str_without_newline = str[num_newlines:] | |
| print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}') | |
| def is_png_img(item): | |
| return isinstance(item, PngImagePlugin.PngImageFile) | |
| +def get_last_row_for_1M_states(task): | |
| + last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 
'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 
10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101} | |
| + return last_row_idx[task] | |
| + | |
| +def get_last_row_for_100k_states(task): | |
| + last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 
'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 
1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407} | |
| + return last_row_idx[task] | |
| + | |
| def get_obs_dim(task): | |
| assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") | |
| all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 
'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17} | |
| - return all_obs_dims[task] | |
| + return (all_obs_dims[task],) | |
| def get_act_dim(task): | |
| assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") | |
| def get_act_dim(task): | |
| elif task.startswith("mujoco"): | |
| all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6} | |
| return all_act_dims[task] | |
| - | |
| -def process_row_atari(attn_mask, row_of_obs, task): | |
| - """ | |
| - Example for selection with bools: | |
| - >>> a = np.array([0,1,2,3,4,5]) | |
| - >>> b = np.array([1,0,0,0,0,1]).astype(bool) | |
| - >>> a[b] | |
| - array([0, 5]) | |
| - """ | |
| - attn_mask = np.array(attn_mask).astype(bool) | |
| - row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) | |
| - row_of_obs = row_of_obs[attn_mask] | |
| +def get_task_info(task): | |
| + rew_key = 'rewards' | |
| + attn_key = 'attention_mask' | |
| + if task.startswith("atari"): | |
| + obs_key = 'image_observations' | |
| + act_key = 'discrete_actions' | |
| + B = 32 # half of 54 | |
| + obs_dim = (3, 4*84, 84) | |
| + elif task.startswith("babyai"): | |
| + obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) | |
| + act_key = 'discrete_actions' | |
| + B = 256 # half of 512 | |
| + obs_dim = get_obs_dim(task) | |
| + elif task.startswith("metaworld") or task.startswith("mujoco"): | |
| + obs_key = 'continuous_observations' | |
| + act_key = 'continuous_actions' | |
| + B = 256 | |
| + obs_dim = get_obs_dim(task) | |
| + | |
| + return rew_key, attn_key, obs_key, act_key, B, obs_dim | |
| + | |
| +def process_row_of_obs_atari_full_without_mask(row_of_obs): | |
| + | |
| + if not isinstance(row_of_obs, torch.Tensor): | |
| + row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) | |
| row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1] | |
| - assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84) | |
| + assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84) | |
| row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84) | |
| - row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side | |
| + row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side | |
| row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels | |
| - assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension | |
| - | |
| - return attn_mask, row_of_obs | |
| + assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension | |
| + | |
| + return row_of_obs | |
| -def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False): | |
| - attn_mask = np.array(attn_mask).astype(bool) | |
| +def collect_all_atari_data(dataset, all_row_idxs=None): | |
| + if all_row_idxs is None: | |
| + all_row_idxs = list(range(len(dataset['train']))) | |
| - row_of_obs = np.array(row_of_obs) | |
| - if not return_numpy: | |
| - row_of_obs = torch.tensor(row_of_obs) | |
| - row_of_obs = row_of_obs[attn_mask] | |
| - assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task)) | |
| - | |
| - return attn_mask, row_of_obs | |
| - | |
| -def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84) | |
| - dataset, # to retrieve from | |
| - all_rows_to_consider, # rows to consider | |
| - num_to_retrieve, # top-k | |
| + all_rows_of_obs = [] | |
| + all_attn_masks = [] | |
| + for row_idx in tqdm(all_row_idxs): | |
| + datarow = dataset['train'][row_idx] | |
| + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations']) | |
| + attn_mask = np.array(datarow['attention_mask']).astype(bool) | |
| + all_rows_of_obs.append(row_of_obs) # appending tensor | |
| + all_attn_masks.append(attn_mask) # appending np array | |
| + all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors | |
| + all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays | |
| + assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and | |
| + all_attn_masks.shape == (len(all_row_idxs), 32)) | |
| + return all_attn_masks, all_rows_of_obs | |
| + | |
| +def collect_all_data(dataset, task, obs_key): | |
| + last_row_idx = get_last_row_for_100k_states(task) | |
| + all_row_idxs = list(range(last_row_idx)) | |
| + if task.startswith("atari"): | |
| + myprint("Collecting all Atari images and Atari attention masks...") | |
| + all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs) | |
| + else: | |
| + datarows = dataset['train'][all_row_idxs] | |
| + all_rows_of_obs_OG = np.array(datarows[obs_key]) | |
| + all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool) | |
| + return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs | |
| + | |
| +def collect_subset(all_rows_of_obs_OG, | |
| + all_attn_masks_OG, | |
| + all_rows_to_consider, | |
| + kwargs | |
| + ): | |
| + """ | |
| + Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return. | |
| + Used in both retrieve_atari() and retrieve_vector() --> build_index_vector(). | |
| + """ | |
| + myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...') | |
| + # read kwargs | |
| + B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim'] | |
| + | |
| + # take subset based on all_rows_to_consider | |
| + myprint(f'Taking subset of data based on all_rows_to_consider...') | |
| + all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider] | |
| + all_attn_masks = all_attn_masks_OG[all_rows_to_consider] | |
| + assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and | |
| + all_attn_masks.shape == (len(all_rows_to_consider), B)) | |
| + | |
| + # reshape | |
| + myprint(f'Reshaping data...') | |
| + all_attn_masks = all_attn_masks.reshape(-1) | |
| + all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim) | |
| + all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks] | |
| + assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and | |
| + all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim)) | |
| + | |
| + # collect indices of data | |
| + myprint(f'Collecting indices of data...') | |
| + all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) | |
| + all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s | |
| + assert all_indices.shape == (np.sum(all_attn_masks), 2) | |
| + | |
| + myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}') | |
| + myprint(('-'*100) + '\n\n\n') | |
| + return all_indices, all_processed_rows_of_obs | |
| + | |
| +def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim) | |
| + all_processed_rows_of_obs, | |
| + all_indices, | |
| + num_to_retrieve, | |
| kwargs | |
| - ): | |
| + ): | |
| + """ | |
| + Retrieval for Atari with images, ssim distance, and on GPU. | |
| + """ | |
| assert isinstance(row_of_obs, torch.Tensor) | |
| # read kwargs # Note: B = len of row | |
| - B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'] | |
| + B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval'] | |
| # batch size of row_of_obs which can be <= B since we process before calling this function | |
| - row_B = row_of_obs.shape[0] | |
| - | |
| + xbdim = row_of_obs.shape[0] | |
| + | |
| + # collect subset of data that we can retrieve from | |
| + ydim = all_processed_rows_of_obs.shape[0] | |
| + | |
| # first argument for ssim | |
| - repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device) | |
| - assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84) | |
| + xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device) | |
| + assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84) | |
| - # iterate over all other rows | |
| + # iterate over data that we can retrieve from in batches | |
| all_ssim = [] | |
| - all_indices = [] | |
| - total = 0 | |
| - for other_row_idx in tqdm(all_rows_to_consider): | |
| - other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key]) | |
| - | |
| - # batch size of other_row_of_obs | |
| - other_row_B = other_row_of_obs.shape[0] | |
| - total += other_row_B | |
| - | |
| - # first argument for ssim: RECHECK | |
| - if other_row_B < B: # when other row has less observations than expected | |
| - repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device) | |
| - elif other_row_B == B: # otherwise just use the one created before the for loop | |
| - repeated_row = repeated_row_og | |
| - assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84) | |
| - | |
| + for j in range(0, ydim, batch_size_retrieval): | |
| # second argument for ssim | |
| - repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device) | |
| - assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84) | |
| + ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval] | |
| + ybdim = ybatch.shape[0] | |
| + ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device) | |
| + assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84) | |
| + | |
| + if ybdim < batch_size_retrieval: # for last batch | |
| + xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device) | |
| + assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84) | |
| # compare via ssim and updated all_ssim | |
| - ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False) | |
| - ssim_score = ssim_score.reshape(row_B, other_row_B) | |
| + ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False) | |
| + ssim_score = ssim_score.reshape(xbdim, ybdim) | |
| all_ssim.append(ssim_score) | |
| - # update all_indices | |
| - all_indices.extend([[other_row_idx, i] for i in range(other_row_B)]) | |
| - | |
| # concat | |
| all_ssim = torch.cat(all_ssim, dim=1) | |
| - assert all_ssim.shape == (row_B, total) | |
| + assert all_ssim.shape == (xbdim, ydim) | |
| - all_indices = np.array(all_indices) | |
| - assert all_indices.shape == (total, 2) | |
| + assert all_indices.shape == (ydim, 2) | |
| # get top-k indices | |
| topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True) | |
| topk_indices = topk_indices.cpu().numpy() | |
| - assert topk_indices.shape == (row_B, num_to_retrieve) | |
| + assert topk_indices.shape == (xbdim, num_to_retrieve) | |
| # convert topk indices to indices in the dataset | |
| - retrieved_indices = np.array(all_indices[topk_indices]) | |
| - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) | |
| - | |
| - # pad the above to expected B | |
| - if row_B < B: | |
| - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) | |
| - assert retrieved_indices.shape == (B, num_to_retrieve, 2) | |
| + retrieved_indices = all_indices[topk_indices] | |
| + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) | |
| return retrieved_indices | |
| -def build_index_vector(all_rows_of_obs_og, | |
| - all_attn_masks_og, | |
| +def build_index_vector(all_rows_of_obs_OG, | |
| + all_attn_masks_OG, | |
| all_rows_to_consider, | |
| kwargs | |
| - ): | |
| + ): | |
| + """ | |
| + Builds FAISS index for vector observation environments. | |
| + """ | |
| # read kwargs # Note: B = len of row | |
| - B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss'] | |
| - obs_dim = get_obs_dim(task) | |
| + nb_cores_autofaiss = kwargs['nb_cores_autofaiss'] | |
| - # take subset based on all_rows_to_consider | |
| - myprint(f'Taking subset') | |
| - all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider] | |
| - all_attn_masks = all_attn_masks_og[all_rows_to_consider] | |
| - assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and | |
| - all_attn_masks.shape == (len(all_rows_to_consider), B)) | |
| - | |
| - # reshape | |
| - all_attn_masks = all_attn_masks.reshape(-1) | |
| - all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim) | |
| - all_rows_of_obs = all_rows_of_obs[all_attn_masks] | |
| - assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim) | |
| + # take subset based on all_rows_to_consider, reshape, and save indices of data | |
| + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs) | |
| - # save indices of data to retrieve from | |
| - myprint(f'Saving indices of data to retrieve from') | |
| - all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) | |
| - all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s | |
| - assert all_indices.shape == (np.sum(all_attn_masks), 2) | |
| + # make sure input to build_index is float, otherwise you will get reading temp file error | |
| + all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float) | |
| # build index | |
| - myprint(f'Building index...') | |
| - knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! | |
| + myprint(('-'*100) + 'Building index...') | |
| + knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! | |
| save_on_disk=False, | |
| min_nearest_neighbors_to_retrieve=20, # default: 20 | |
| max_index_query_time_ms=10, # default: 10 | |
| def build_index_vector(all_rows_of_obs_og, | |
| metric_type='l2', | |
| nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command | |
| ) | |
| + myprint(('-'*100) + '\n\n\n') | |
| - return knn_index, all_indices | |
| + return all_indices, knn_index | |
| -def retrieve_vector(row_of_obs, # query: (row_B, dim) | |
| - dataset, # to retrieve from | |
| - all_rows_to_consider, # rows to consider | |
| - num_to_retrieve, # top-k | |
| +def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim) | |
| + knn_index, | |
| + all_indices, | |
| + num_to_retrieve, | |
| kwargs | |
| - ): | |
| + ): | |
| + """ | |
| + Retrieval for vector observation environments. | |
| + """ | |
| assert isinstance(row_of_obs, np.ndarray) | |
| # read few kwargs | |
| B = kwargs['B'] | |
| # batch size of row_of_obs which can be <= B since we process before calling this function | |
| - row_B = row_of_obs.shape[0] | |
| + xbdim = row_of_obs.shape[0] | |
| - # read dataset_tuple | |
| - all_rows_of_obs, all_attn_masks = dataset | |
| - | |
| - # create index and all_indices | |
| - knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs) | |
| - | |
| # retrieve | |
| myprint(f'Retrieving...') | |
| topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve) | |
| topk_indices = topk_indices.astype(int) | |
| - assert topk_indices.shape == (row_B, 10 * num_to_retrieve) | |
| + assert topk_indices.shape == (xbdim, 10 * num_to_retrieve) | |
| # remove -1s and crop to num_to_retrieve | |
| try: | |
| def retrieve_vector(row_of_obs, # query: (row_B, dim) | |
| print(f'-------------------------------------------------------------------------------------------------------------------------------------------') | |
| print(f'Leaving some -1s in topk_indices and continuing') | |
| topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices]) | |
| - assert topk_indices.shape == (row_B, num_to_retrieve) | |
| + assert topk_indices.shape == (xbdim, num_to_retrieve) | |
| # convert topk indices to indices in the dataset | |
| retrieved_indices = all_indices[topk_indices] | |
| - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) | |
| - | |
| - # pad the above to expected B | |
| - if row_B < B: | |
| - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) | |
| - assert retrieved_indices.shape == (B, num_to_retrieve, 2) | |
| + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) | |
| - myprint(f'Returning') | |
| return retrieved_indices | |
| \ No newline at end of file | |
| diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py | |
| index 07e545c..146b347 100755 | |
| --- a/scripts_regent/eval_RandP.py | |
| +++ b/scripts_regent/eval_RandP.py | |
| from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser | |
| from jat.eval.rl import TASK_NAME_TO_ENV_ID, make | |
| from jat.utils import normalize, push_to_hub, save_video_grid | |
| -from jat_regent.RandP import RandP | |
| +from jat_regent.modeling_RandP import RandP | |
| from datasets import load_from_disk | |
| from datasets.config import HF_DATASETS_CACHE | |
| +from jat_regent.utils import myprint | |
| @dataclass | |
| def eval_rl(model, processor, task, eval_args): | |
| scores = [] | |
| frames = [] | |
| for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False): | |
| + myprint(('-'*100) + f'{episode=}') | |
| observation, _ = env.reset() | |
| reward = None | |
| rewards = [] | |
| def eval_rl(model, processor, task, eval_args): | |
| frames.append(np.array(env.render(), dtype=np.uint8)) | |
| scores.append(sum(rewards)) | |
| + myprint(('-'*100) + '\n\n\n') | |
| env.close() | |
| raw_mean, raw_std = np.mean(scores), np.std(scores) | |
| def main(): | |
| tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)]) | |
| device = torch.device("cpu") if eval_args.use_cpu else get_default_device() | |
| - processor = None | |
| + processor = AutoProcessor.from_pretrained( | |
| + 'jat-project/jat', cache_dir=None, trust_remote_code=True | |
| + ) | |
| evaluations = {} | |
| video_list = [] | |
| def main(): | |
| for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True): | |
| if task in TASK_NAME_TO_ENV_ID.keys(): | |
| + myprint(('-'*100) + f'{task=}') | |
| dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}') | |
| - model = RandP(dataset) | |
| + model = RandP(task, | |
| + dataset, | |
| + device,) | |
| scores, frames, fps = eval_rl(model, processor, task, eval_args) | |
| evaluations[task] = scores | |
| # Save the video | |
| if eval_args.save_video: | |
| video_list.append(frames) | |
| input_fps.append(fps) | |
| + myprint(('-'*100) + '\n\n\n') | |
| else: | |
| warnings.warn(f"Task {task} is not supported.") | |
| diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py | |
| index c83d259..aad678a 100644 | |
| --- a/scripts_regent/offline_retrieval_jat_regent.py | |
| +++ b/scripts_regent/offline_retrieval_jat_regent.py | |
| import time | |
| from datetime import datetime | |
| from datasets import load_from_disk | |
| from datasets.config import HF_DATASETS_CACHE | |
| -from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector | |
| +from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector | |
| import logging | |
| logging.basicConfig(level=logging.DEBUG) | |
| def main(): | |
| parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices') | |
| parser.add_argument('--task', type=str, default='atari-alien', help='Task name') | |
| parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve') | |
| - parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments') | |
| + parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs') | |
| + parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari') | |
| args = parser.parse_args() | |
| # load dataset, map, device, for task | |
| def main(): | |
| dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}" | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| - rew_key = 'rewards' | |
| - attn_key = 'attention_mask' | |
| - if task.startswith("atari"): | |
| - obs_key = 'image_observations' | |
| - act_key = 'discrete_actions' | |
| - len_row_tokenized_known = 32 # half of 54 | |
| - process_row_fn = process_row_atari | |
| - retrieve_fn = retrieve_atari | |
| - elif task.startswith("babyai"): | |
| - obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) | |
| - act_key = 'discrete_actions' | |
| - len_row_tokenized_known = 256 # half of 512 | |
| - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) | |
| - retrieve_fn = retrieve_vector | |
| - elif task.startswith("metaworld") or task.startswith("mujoco"): | |
| - obs_key = 'continuous_observations' | |
| - act_key = 'continuous_actions' | |
| - len_row_tokenized_known = 256 | |
| - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) | |
| - retrieve_fn = retrieve_vector | |
| + rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task) | |
| dataset = load_from_disk(dataset_path) | |
| with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f: | |
| map_from_rows_to_episodes_for_tokenized = json.load(f) | |
| # setup kwargs | |
| - len_dataset = len(dataset['train']) | |
| - B = len_row_tokenized_known | |
| kwargs = {'B': B, | |
| - 'attn_key':attn_key, | |
| - 'obs_key':obs_key, | |
| - 'device':device, | |
| - 'task':task, | |
| - 'batch_size_retrieval':None, | |
| - 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss, | |
| - } | |
| + 'obs_dim': obs_dim, | |
| + 'attn_key': attn_key, | |
| + 'obs_key': obs_key, | |
| + 'device': device, | |
| + 'task': task, | |
| + 'batch_size_retrieval': args.batch_size_retrieval, | |
| + 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss, | |
| + } | |
| # collect all observations in a single array (this takes some time) for vector observation environments | |
| - if not task.startswith("atari"): | |
| - myprint("Collecting all observations/attn_masks in a single array") | |
| - all_rows_of_obs = np.array(dataset['train'][obs_key]) | |
| - all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool) | |
| + myprint("Collecting all observations/attn_masks") | |
| + all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key) | |
| # iterate over rows | |
| all_retrieved_indices = [] | |
| - for row_idx in range(len_dataset): | |
| - myprint(f"\nProcessing row {row_idx}/{len_dataset}") | |
| + for row_idx in all_row_idxs: | |
| + myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}") | |
| current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)] | |
| - attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task) | |
| + # get row_of_obs and attn_mask | |
| + datarow = dataset['train'][row_idx] | |
| + attn_mask = np.array(datarow[attn_key]).astype(bool) | |
| + if task.startswith("atari"): | |
| + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key]) | |
| + else: | |
| + row_of_obs = np.array(datarow[obs_key]) | |
| + row_of_obs = row_of_obs[attn_mask] | |
| + assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim) | |
| # compare with rows from all but the current episode | |
| - all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] | |
| + all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] | |
| # do the retrieval | |
| - retrieved_indices = retrieve_fn(row_of_obs=row_of_obs, | |
| - dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks), | |
| - all_rows_to_consider=all_other_rows, | |
| - num_to_retrieve=args.num_to_retrieve, | |
| - kwargs=kwargs, | |
| - ) | |
| + if task.startswith("atari"): | |
| + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG, | |
| + all_attn_masks_OG=all_attn_masks_OG, | |
| + all_rows_to_consider=all_row_idxs, | |
| + kwargs=kwargs) | |
| + retrieved_indices = retrieve_atari(row_of_obs=row_of_obs, | |
| + all_processed_rows_of_obs=all_processed_rows_of_obs, | |
| + all_indices=all_indices, | |
| + num_to_retrieve=args.num_to_retrieve, | |
| + kwargs=kwargs) | |
| + else: | |
| + all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG, | |
| + all_attn_masks_OG=all_attn_masks_OG, | |
| + all_rows_to_consider=all_other_row_idxs, | |
| + kwargs=kwargs) | |
| + retrieved_indices = retrieve_vector(row_of_obs=row_of_obs, | |
| + knn_index=knn_index, | |
| + all_indices=all_indices, | |
| + num_to_retrieve=args.num_to_retrieve, | |
| + kwargs=kwargs) | |
| + | |
| + # pad the above to expected B | |
| + xbdim = row_of_obs.shape[0] | |
| + if xbdim < B: | |
| + retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0) | |
| + assert retrieved_indices.shape == (B, args.num_to_retrieve, 2) | |
| # collect retrieved indices | |
| all_retrieved_indices.append(retrieved_indices) | |
| # concat | |
| all_retrieved_indices = np.stack(all_retrieved_indices, axis=0) | |
| - assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2) | |
| + assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2) | |
| # save arrays as bin for easy memmap access and faster loading | |
| - all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin") | |
| + all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin") | |
| if __name__ == "__main__": | |
| main() | |
| \ No newline at end of file | |