Added BigBird #26

Open
wants to merge 2 commits into base: main
4 changes: 2 additions & 2 deletions openpmcvl/experiment/configs/__init__.py
@@ -15,7 +15,7 @@
from openpmcvl.experiment.datasets.pmcvl import PMCVL
from openpmcvl.experiment.datasets.quilt1m import Quilt
from openpmcvl.experiment.datasets.roco import ROCO
from openpmcvl.experiment.modules.encoders import BiomedCLIPText, BiomedCLIPVision
from openpmcvl.experiment.modules.encoders import BiomedCLIPText, BiomedCLIPVision, BigBirdText
from openpmcvl.experiment.modules.pmc_clip import (
    PmcClipText,
    PmcClipVision,
@@ -28,7 +28,7 @@
    PubmedClipVision,
)
from openpmcvl.experiment.modules.scheduler import CosineAnnealingWarmupLR
from openpmcvl.experiment.modules.tokenizer import OpenClipTokenizerWrapper
from openpmcvl.experiment.modules.tokenizer import OpenClipTokenizerWrapper, BigBirdTokenizerWrapper
from openpmcvl.experiment.modules.zero_shot_retrieval import (
    ZeroShotCrossModalRetrievalEfficient,
)
126 changes: 126 additions & 0 deletions openpmcvl/experiment/configs/experiment/vitb16_bigbird_pmcoa.yaml
@@ -0,0 +1,126 @@
# @package _global_

defaults:
  - /datasets@datasets.train.pmcoa: PMCOA
  - /datasets/transforms@datasets.train.pmcoa.transform: med_clip_vision_transform
  - /datasets@datasets.val.pmcoa: PMCOA
  - /datasets/transforms@datasets.val.pmcoa.transform: med_clip_vision_transform
  - /datasets@datasets.test.pmcoa: PMCOA
  - /datasets/transforms@datasets.test.pmcoa.transform: med_clip_vision_transform
  - /datasets/tokenizers@dataloader.train.collate_fn.batch_processors.text: BigBirdTokenizerWrapper
  - /datasets/tokenizers@dataloader.val.collate_fn.batch_processors.text: BigBirdTokenizerWrapper
  - /datasets/tokenizers@dataloader.test.collate_fn.batch_processors.text: BigBirdTokenizerWrapper
  - /modules/encoders@task.encoders.text: BigBirdText
  - /modules/encoders@task.encoders.rgb: BiomedCLIPVision
  - /modules/layers@task.postprocessors.norm_and_logit_scale.norm: L2Norm
  - /modules/layers@task.postprocessors.norm_and_logit_scale.logit_scale: LearnableLogitScaling
  - /modules/losses@task.loss: CLIPLoss
  - /modules/optimizers@task.optimizer: AdamW
  - /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingWarmupLR
  - /eval_task@task.evaluation_tasks.retrieval.task: ZeroShotCrossModalRetrievalEfficient
  - /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor
  - /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint
  - /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping
  - /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary
  - /trainer/logger@trainer.logger.wandb: WandbLogger
  - override /task: ContrastivePretraining
  - _self_
seed: 0

datasets:
  train:
    pmcoa:
      split: train
  val:
    pmcoa:
      split: valid
      transform:
        job_type: eval
  test:
    pmcoa:
      split: test
      transform:
        job_type: eval

dataloader:
  train:
    batch_size: 256
    num_workers: 4
  val:
    batch_size: 32
    num_workers: 4
  test:
    num_workers: 4

task:
  postprocessors:
    norm_and_logit_scale:
      norm:
        dim: -1
      logit_scale:
        learnable: True
  modality_module_mapping:
    text:
      postprocessor_key: norm_and_logit_scale
    rgb:
      postprocessor_key: norm_and_logit_scale
  optimizer:
    betas:
      - 0.9
      - 0.98
    lr: 5.0e-4
    weight_decay: 0.2
    eps: 1.0e-6
  lr_scheduler:
    scheduler:
      t_max: 104_671  # make sure to change this if max_epochs or accumulate_grad_batches is changed
      warmup_length: 2000
    extras:
      interval: step
  loss:
    gather_with_grad: True
    local_loss: True
  evaluation_tasks:
    retrieval:
      task:
        task_specs:
          - query_modality: text
            target_modality: rgb
            top_k: [1, 5, 10]
          - query_modality: rgb
            target_modality: text
            top_k: [1, 5, 10]
      run_on_validation: false
      run_on_test: true

trainer:
  max_epochs: 64
  precision: bf16-mixed
  deterministic: False
  benchmark: True
  sync_batchnorm: False  # set to True if using DDP with batchnorm
  log_every_n_steps: 100
  accumulate_grad_batches: 4
  check_val_every_n_epoch: 1
  callbacks:
    model_checkpoint:
      monitor: val/loss
      save_top_k: 1
      save_last: True
      every_n_epochs: 1
      dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID}  # only works on Vector SLURM environment
    early_stopping:
      monitor: val/loss
      patience: 5
      mode: min
    model_summary:
      max_depth: 2

tags:
  - ${experiment_name}
  - contrastive pretraining
  - rgb
  - text
  - clip
  - pmcvl
  - openpmcvl
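Note on the scheduler setting above: with interval: step, t_max should roughly match the total number of optimizer steps, which depends on the training-set size, per-GPU batch size, world size, accumulate_grad_batches, and max_epochs. The sketch below is illustrative only and not part of the PR; the sample count and world size are placeholder assumptions, and the exact step accounting depends on how the trainer counts steps.

import math

# Placeholders: substitute the real PMC-OA training-set size and your GPU count.
num_train_samples = 1_600_000    # assumption, not the actual dataset size
per_gpu_batch_size = 256         # dataloader.train.batch_size (assumed per GPU)
world_size = 16                  # e.g. 4 nodes x 4 GPUs, as in the launch script
accumulate_grad_batches = 4      # trainer.accumulate_grad_batches
max_epochs = 64                  # trainer.max_epochs

batches_per_epoch = math.ceil(num_train_samples / (per_gpu_batch_size * world_size))
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
t_max = optimizer_steps_per_epoch * max_epochs
print(t_max)  # recompute whenever max_epochs or accumulate_grad_batches changes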
84 changes: 84 additions & 0 deletions openpmcvl/experiment/modules/encoders.py
@@ -5,6 +5,7 @@

import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel
from huggingface_hub import hf_hub_download
from mmlearn.conf import external_store
from mmlearn.datasets.core import Modalities
@@ -247,3 +248,86 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor]:
        features = F.normalize(features, dim=-1) if self.normalize else features

        return (features,)




@external_store(
    group="modules/encoders",
    provider="openpmcvl",
    model_name_or_path="google/bigbird-pegasus-large-pubmed",
)
class BigBirdText(nn.Module):
    """Wrapper around the Big Bird text encoder loaded via Hugging Face.

    Parameters
    ----------
    model_name_or_path : str
        The Hugging Face model name or a local path from which to load the model.
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    use_all_token_embeddings : bool, default=False
        Whether to use all token embeddings for the text. If `False`, the pooled output
        (mean over token embeddings) will be used.
    normalize : bool, default=False
        Whether to normalize output features of the encoder.
    model_config_kwargs : Dict[str, Any], optional, default=None
        Additional keyword arguments passed when loading the model or its configuration.
    """

    def __init__(
        self,
        model_name_or_path: str = "google/bigbird-pegasus-large-pubmed",
        pretrained: bool = True,
        use_all_token_embeddings: bool = False,
        normalize: bool = False,
        model_config_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the model."""
        super().__init__()

        model_config_kwargs = model_config_kwargs or {}
        if pretrained:
            # Load pretrained weights from the Hugging Face Hub or a local path.
            self.model = AutoModel.from_pretrained(
                model_name_or_path, **model_config_kwargs
            )
        else:
            # Build the architecture from its configuration only (random weights).
            config = AutoConfig.from_pretrained(
                model_name_or_path, **model_config_kwargs
            )
            self.model = AutoModel.from_config(config)

        # Model configuration
        self.use_all_token_embeddings = use_all_token_embeddings
        self.normalize = normalize
        self.emb_dim = self.model.config.hidden_size  # hidden size of the loaded text model

        # Linear layer to project the text embeddings to the 512-dim shared space.
        self.projection = nn.Linear(self.emb_dim, 512)

    def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor]:
        """Run the forward pass.

        Parameters
        ----------
        inputs : Dict[str | Modality, Any]
            The input data. The `input_ids` and `attention_mask` will be expected
            under the `Modalities.TEXT.name` and `"attention_mask"` keys, respectively.

        Returns
        -------
        Tuple[torch.Tensor]
            The text embeddings. Will be a tuple with a single element.
        """
        input_ids = inputs[Modalities.TEXT.name]
        attention_mask = inputs["attention_mask"]

        # Extract token features. Note that the default checkpoint is a Pegasus-style
        # encoder-decoder model, so `AutoModel` loads the full seq2seq model here.
        features = self.model(input_ids=input_ids, attention_mask=attention_mask)[
            "last_hidden_state"
        ]

        # Mean pooling over the token embeddings if `use_all_token_embeddings` is False.
        if not self.use_all_token_embeddings:
            features = features.mean(dim=1)

        features = F.normalize(features, dim=-1) if self.normalize else features

        # Apply the linear projection.
        projected_features = self.projection(features)

        return (projected_features,)
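For orientation, a minimal usage sketch of the new text encoder outside the mmlearn training loop. This is illustrative only and not part of the PR; it assumes the google/bigbird-pegasus-large-pubmed checkpoint can be downloaded and builds the input dict by hand with the same keys the encoder's forward expects.

import torch
from mmlearn.datasets.core import Modalities

from openpmcvl.experiment.modules.encoders import BigBirdText

encoder = BigBirdText(pretrained=True)  # downloads the checkpoint on first use
encoder.eval()

# Dummy batch: 2 sequences of 16 token ids, all positions attended.
input_ids = torch.randint(0, encoder.model.config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
inputs = {Modalities.TEXT.name: input_ids, "attention_mask": attention_mask}

with torch.no_grad():
    (embeddings,) = encoder(inputs)
print(embeddings.shape)  # expected: torch.Size([2, 512]) after the projection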
55 changes: 54 additions & 1 deletion openpmcvl/experiment/modules/tokenizer.py
@@ -1,9 +1,14 @@
"""Wrapper to load BiomedCLIP tokenizer from open_clip."""

from typing import Any, List, Union
from typing import Any, List, Union, Dict

from mmlearn.conf import external_store
from open_clip import get_tokenizer
import torch
from transformers import AutoTokenizer
from mmlearn.datasets.core import Modalities


@external_store(group="datasets/tokenizers", provider="openpmcvl")
@@ -24,3 +29,51 @@ def __init__(
    def __call__(self, x: Union[str, List[str]]) -> Any:
        """Pass any input to loaded tokenizer."""
        return self.tokenizer(x)


@external_store(
    group="datasets/tokenizers",
    provider="openpmcvl",
    model_name_or_path="google/bigbird-pegasus-large-pubmed",
)
class BigBirdTokenizerWrapper:
    """Wrapper for the Big Bird tokenizer.

    Parameters
    ----------
    model_name_or_path : str
        The Hugging Face model name or a local path from which to load the tokenizer.
    max_length : int, default=512
        The maximum sequence length for the tokenizer.
    """

    def __init__(
        self,
        model_name_or_path: str = "google/bigbird-pegasus-large-pubmed",
        max_length: int = 512,
    ) -> None:
        """Initialize the tokenizer."""
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.max_length = max_length

    def __call__(self, texts: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Tokenize input texts.

        Parameters
        ----------
        texts : Union[str, List[str]]
            The input text or a list of texts to tokenize.

        Returns
        -------
        Dict[str, torch.Tensor]
            A dictionary with the tokenized `input_ids` stored under the text-modality
            key (`Modalities.TEXT.name`) and the attention mask under `"attention_mask"`.
        """
        tokenized = self.tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        return {
            Modalities.TEXT.name: tokenized["input_ids"],
            "attention_mask": tokenized["attention_mask"],
        }
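A quick, illustrative sanity check for the tokenizer wrapper (not part of the PR). The text key in the returned dictionary comes from mmlearn's Modalities.TEXT.name, which is what BigBirdText.forward reads:

from openpmcvl.experiment.modules.tokenizer import BigBirdTokenizerWrapper

tokenizer = BigBirdTokenizerWrapper(max_length=64)
batch = tokenizer(
    ["Chest X-ray showing bilateral infiltrates.", "Axial CT of the thorax."]
)

# Both tensors are padded/truncated to max_length.
print(sorted(batch.keys()))            # the text-modality key and "attention_mask"
print(batch["attention_mask"].shape)   # torch.Size([2, 64])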
29 changes: 29 additions & 0 deletions openpmcvl/experiment/scripts/train/pmc_oa_2/vitb16_bigbird.sh
@@ -0,0 +1,29 @@
# bs=25
# a100
mmlearn_run --multirun hydra.launcher.mem_gb=0 \
hydra.launcher.qos=a100_arashaf \
hydra.launcher.partition=a100 \
hydra.launcher.gres=gpu:4 \
hydra.launcher.cpus_per_task=4 \
hydra.launcher.tasks_per_node=4 \
hydra.launcher.nodes=4 \
hydra.launcher.stderr_to_stdout=true \
hydra.launcher.timeout_min=828 \
'+hydra.launcher.additional_parameters={export: ALL}' \
'hydra.searchpath=[pkg://openpmcvl.experiment.configs]' \
+experiment=vitb16_bigbird_pmcoa \
experiment_name=vitb16_bigbird_pmcoa \
dataloader.train.batch_size=25 \
dataloader.val.batch_size=16 \
dataloader.train.num_workers=4 \
dataloader.val.num_workers=4 \
task.encoders.text.pretrained=False \
task.encoders.rgb.pretrained=False \
task.lr_scheduler.scheduler.t_max=823 \
task.lr_scheduler.scheduler.warmup_length=100 \
trainer.num_nodes=4 \
trainer.devices=[0,1,2,3] \
strict_loading=False \
resume_from_checkpoint="path/to/checkpoint" \
trainer.logger.wandb.id="" \
trainer.logger.wandb.resume="must"