Commit 206237e

Add test train code for infini gemma
1 parent fc91088 commit 206237e

File tree

1 file changed: +150 -0 lines changed


test_train.small.gemma.infini.py (+150)
@@ -0,0 +1,150 @@
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # TODO: set the GPU device
os.environ["WANDB_PROJECT"] = "InfiniTransformer"
# os.environ["WANDB_MODE"] = "offline"


from itertools import chain

import torch
from datasets import load_dataset

from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
    default_data_collator,
)
from infini_gemma import GemmaForCausalLM, GemmaConfig

set_seed(42)

print("Torch Version:", torch.__version__)
print("CUDA:", torch.cuda.is_available())

if torch.cuda.is_available():
    device = "cuda:0"  # set the GPU device via CUDA_VISIBLE_DEVICES
else:
    device = "cpu"

if os.path.exists("./models/gemma-2b"):
    # Reuse a locally saved checkpoint if one exists.
    model = GemmaForCausalLM.from_pretrained(
        "./models/gemma-2b", torch_dtype="auto", device_map="auto"
    )
    config = model.config
    print(config)
    print(model)
else:
    config = GemmaConfig.from_pretrained(
        "google/gemma-2b",
        attn_implementation="eager",
    )
    # config.max_position_embeddings = 128
    config.use_cache = False
    config.segment_size = config.max_position_embeddings

    print(config)

    # Create the Gemma model with Infini-attention
    model = GemmaForCausalLM(config)
    # model = model.from_pretrained("google/gemma-2b")
    pretrained_model = GemmaForCausalLM.from_pretrained(
        "google/gemma-2b", torch_dtype="auto"
    )
    # Step 4: Transfer weights
    # Note: this is a simplified example; each parameter's dimensions must match.
    pretrained_state_dict = pretrained_model.state_dict()
    for name, param in model.named_parameters():
        if name in pretrained_state_dict:
            # Only assign the weights when the dimensions match.
            if param.size() == pretrained_state_dict[name].size():
                param.data = pretrained_state_dict[name].data.clone()
            else:
                print(f"Skipping {name} due to size mismatch.")
    print(model)

# model = model.to(torch.bfloat16)
model = model.to(device)

# wiki = load_dataset("wikipedia", "20220301.en", split="train[:20000]")
wiki = load_dataset("wikitext", "wikitext-2-raw-v1")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")


def tokenize_function(examples):
    return tokenizer(examples["text"])


try:
    column_names = list(wiki["train"].features)
except KeyError:
    column_names = list(wiki.features)
tokenized_datasets = wiki.map(
    tokenize_function, remove_columns=column_names, batched=True
)


block_size = config.segment_size * 4  # will be 32768
print("block_size:", block_size)


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; if total_length < block_size, this batch is excluded
    # and an empty dict is returned. Padding could be used instead of this drop if
    # the model supported it; customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
)

print(lm_datasets)
# print(lm_datasets["train"]["input_ids"][0])

training_args = TrainingArguments(
    output_dir="./models/gemma-2b-wikitext",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,  # to test the batch dim
    save_total_limit=1,
    report_to="wandb",  # "none" if you don't want to report to wandb
    run_name="gemma-2b-wikitext",
    optim="adafactor",
    learning_rate=1e-4,
    bf16=True,
    logging_first_step=True,
    logging_steps=1,
    save_strategy="epoch",
    # warmup_ratio=0.1,
    max_grad_norm=1.0,
    gradient_checkpointing=True,  # Reduces VRAM 69G -> 43G
)

try:
    train_dataset = lm_datasets["train"]
except KeyError:
    train_dataset = lm_datasets

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    # eval_dataset=lm_datasets["validation"],
    data_collator=default_data_collator,
)

trainer.train()
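
A quick sanity check one might run after training (not part of this commit): load the trained weights and generate a short continuation. This is a minimal sketch that assumes the final model is exported to the output_dir, e.g. via trainer.save_model() or by pointing at a checkpoint-* subfolder, and that the custom GemmaForCausalLM keeps the standard transformers generate() API; the prompt, paths, and generation settings are placeholders.

# Hypothetical follow-up snippet; paths and generation settings are assumptions.
import torch
from transformers import AutoTokenizer
from infini_gemma import GemmaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = GemmaForCausalLM.from_pretrained(
    "./models/gemma-2b-wikitext", torch_dtype="auto", device_map="auto"
)
model.eval()

prompt = "The WikiText dataset is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    # Assumes the Infini-attention model supports the standard generate() path.
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))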
