Commit e212e5f: Add files via upload

7 files changed, +232 -0 lines

README.md

# Inference experiments with large language models
## The effect of temperature
The importance of statistical approaches to the study of large language models has recently been emphasized by researchers at [Anthropic](https://www.anthropic.com/research/statistical-approach-to-model-evals). Large language models are usually stochastic: given a fixed prompt, a model generates a different response at each inference call when the temperature is non-zero. This property helps increase creativity in open-ended generation tasks, but it is not clear how a non-zero temperature helps in solving precise mathematical problems. To understand this better, we set up simple statistical experiments in this repo. The experiments are performed with [vLLM](https://docs.vllm.ai/en/latest/) on the [MATH](https://paperswithcode.com/dataset/math) dataset.
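
For reference, repeated sampling at a fixed temperature takes only a few lines with vLLM. The sketch below is illustrative and is not the repo's actual generation script (iLLM/generate/MATH.py is not shown in this commit); the model checkpoint, prompt format, and parameter values are assumptions.

```python
# Minimal sketch: repeated sampling at a fixed temperature with vLLM.
# Model name, prompt, and parameter values are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B")  # assumed checkpoint
params = SamplingParams(
    n=16,              # number of inference samples per prompt
    temperature=0.6,   # non-zero temperature => stochastic decoding
    top_p=0.95,
    max_tokens=1024,
)

prompt = "Problem: What is the 100th term of the arithmetic sequence 6, 10, 14, 18, ...?\nSolution:"
outputs = llm.generate([prompt], params)

# generate() returns one RequestOutput per prompt, holding n completions.
for completion in outputs[0].outputs:
    print(completion.text.strip()[:80])
```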
### [Llama 8B](https://ai.meta.com/blog/meta-llama-3/)
<center>
<img alt="fig1" width="800px" src="acc-vs-samples-Llama8B.png">
</center>
The experiment is run with max_tokens = 1024, num_few_shot = 2, top_p = 0.95, and temperature = 0.6 on Llama 8B, on a randomly chosen problem (problem 5): "What is the 100th term of the arithmetic sequence 6, 10, 14, 18, ...?" We clearly see that the measured accuracy varies widely when only a small number of inference samples is used. An earlier notable study of such questions is by [Matthew Renze and Erhan Guven](https://arxiv.org/pdf/2402.05201v1); in their work, 10 inference samples are considered for each problem. Given our results above, conclusions drawn from so few samples are not reliable. To overcome this difficulty, we draw 1000 inference samples at each value of temperature and plot the mean accuracy below:
<center>
<img alt="fig2" width="800px" src="acc-vs-temp-Llama8B.png">
</center>
We clearly see that the accuracy on problem 5 drops significantly as the temperature increases. On the other hand, for the average over the MATH dataset we see an initial oscillation and then a decrease in accuracy as the temperature increases. This initial oscillation suggests that accuracy may depend in an interesting way on the difficulty of the problem: for example, problems of a certain type might initially gain accuracy as the temperature is raised. To understand this phenomenon better, we run experiments on a refined subset of the MATH dataset consisting of only algebra level 1 problems, using Gemma 7B.
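
A temperature sweep amounts to repeating the sampling shown earlier at several temperatures and averaging a 0/1 correctness score over the samples. Below is a minimal sketch under assumed names and values, with a deliberately naive grader; the repo's actual evaluation lives in iLLM/evaluate/math_datasets.py and is more careful.

```python
# Minimal sketch of a temperature sweep: mean accuracy per temperature.
# Model, prompt, grader, and sample counts are illustrative assumptions.
from vllm import LLM, SamplingParams

def is_correct(text: str, answer: str) -> bool:
    # Placeholder grader: checks whether the reference answer appears in the response.
    return answer in text

llm = LLM(model="meta-llama/Meta-Llama-3-8B")  # assumed checkpoint
prompt = "Problem: What is the 100th term of the arithmetic sequence 6, 10, 14, 18, ...?\nSolution:"
answer = "402"  # 6 + 99 * 4

for temperature in (0.2, 0.4, 0.6, 0.8, 1.0):
    params = SamplingParams(n=1000, temperature=temperature, top_p=0.95, max_tokens=1024)
    outputs = llm.generate([prompt], params)
    scores = [is_correct(o.text, answer) for o in outputs[0].outputs]
    print(f"T={temperature:.1f}  mean accuracy={sum(scores) / len(scores):.3f}")
```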
### [Gemma 7B](https://ai.google.dev/gemma)
<center>
<img alt="fig3" width="800px" src="acc-vs-temp-Gemma7B.png">
</center>
To our surprise, we find that there is a critical temperature at which the accuracy flattens to a local maximum. We would like to emphasize that we have increased the number of inference samples significantly, reducing the per-problem variance to a negligible value. However, when we look across problems, the variance also receives a contribution from the dataset size. Notably, the MATH dataset contains only 273 algebra level 1 problems, on which we experimented above. To obtain a statistically more robust prediction in the future, we would need access to a much larger dataset.
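
To make the dataset-size contribution concrete: even with the per-problem variance driven to zero by many inference samples, the mean accuracy over N = 273 problems still carries a binomial sampling error of roughly sqrt(p(1-p)/N). A quick back-of-the-envelope check, with the accuracy value assumed purely for illustration:

```python
# Rough dataset-size error bar for the algebra level 1 subset (N = 273 problems).
# The accuracy value p is illustrative, not a measured result.
import math

N = 273          # number of algebra level 1 problems in MATH
p = 0.5          # assumed mean accuracy, for illustration only
stderr = math.sqrt(p * (1 - p) / N)
print(f"standard error from dataset size alone: about {stderr:.3f}")  # ~0.030
```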

acc-vs-samples-Llama8B.png

45.4 KB

acc-vs-temp-Gemma7B.png

34.4 KB

acc-vs-temp-Llama8B.png

48.7 KB

environment.yml

name: env_vLLM # give name of the environment
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2024.9.24=h06a4308_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libuuid=1.0.3=h7f8727e_2
  - ncurses=6.3=h7f8727e_2
  - openssl=1.1.1w=h7f8727e_0
  - python=3.10.4=h12debd9_0
  - readline=8.1.2=h7f8727e_1
  - sqlite=3.38.5=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - pip:
    - absl-py==2.1.0
    - accelerate==1.0.1
    - aiohappyeyeballs==2.4.3
    - aiohttp==3.10.10
    - aiosignal==1.3.1
    - annotated-types==0.7.0
    - antlr4-python3-runtime==4.11.0
    - anyio==4.6.2.post1
    - async-timeout==4.0.3
    - attrs==24.2.0
    - certifi==2024.8.30
    - chardet==5.2.0
    - charset-normalizer==3.4.0
    - click==8.1.7
    - cloudpickle==3.1.0
    - cmake==3.30.5
    - colorama==0.4.6
    - dataproperty==1.0.1
    - datasets==2.20.0
    - dill==0.3.8
    - diskcache==5.6.3
    - distro==1.9.0
    - docker==7.1.0
    - evaluate==0.4.3
    - exceptiongroup==1.2.2
    - fastapi==0.115.4
    - filelock==3.16.1
    - frozenlist==1.5.0
    - fsspec==2024.5.0
    - h11==0.14.0
    - httpcore==1.0.6
    - httptools==0.6.4
    - httpx==0.27.2
    - huggingface-hub==0.26.1
    - idna==3.10
    - interegular==0.3.3
    - jinja2==3.1.4
    - jiter==0.6.1
    - joblib==1.4.2
    - jsonlines==4.0.0
    - jsonschema==4.23.0
    - jsonschema-specifications==2024.10.1
    - lark==1.2.2
    - llvmlite==0.43.0
    - lm-eval==0.4.3
    - lm-format-enforcer==0.10.3
    - lxml==5.3.0
    - markupsafe==3.0.2
    - mbstrdecoder==1.1.3
    - more-itertools==10.5.0
    - mpmath==1.3.0
    - msgpack==1.1.0
    - multidict==6.1.0
    - multiprocess==0.70.16
    - nest-asyncio==1.6.0
    - networkx==3.4.2
    - ninja==1.11.1.1
    - nltk==3.9.1
    - numba==0.60.0
    - numexpr==2.10.1
    - numpy==1.26.4
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==9.1.0.70
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-ml-py==12.560.30
    - nvidia-nccl-cu12==2.20.5
    - nvidia-nvjitlink-cu12==12.6.77
    - nvidia-nvtx-cu12==12.1.105
    - openai==1.52.2
    - outlines==0.0.46
    - packaging==24.1
    - pandas==2.2.3
    - pathvalidate==3.2.1
    - peft==0.13.2
    - pillow==11.0.0
    - pip==24.2
    - portalocker==2.10.1
    - prometheus-client==0.21.0
    - prometheus-fastapi-instrumentator==7.0.0
    - propcache==0.2.0
    - protobuf==5.28.3
    - psutil==6.1.0
    - py-cpuinfo==9.0.0
    - pyairports==2.1.1
    - pyarrow==18.0.0
    - pyarrow-hotfix==0.6
    - pybind11==2.13.6
    - pycountry==24.6.1
    - pydantic==2.9.2
    - pydantic-core==2.23.4
    - pydra-config==0.0.1
    - pytablewriter==1.2.0
    - python-dateutil==2.9.0.post0
    - python-dotenv==1.0.1
    - pytz==2024.2
    - pyyaml==6.0.2
    - pyzmq==26.2.0
    - ray==2.38.0
    - referencing==0.35.1
    - regex==2024.9.11
    - requests==2.32.3
    - rouge-score==0.1.2
    - rpds-py==0.20.0
    - sacrebleu==2.4.3
    - safetensors==0.4.5
    - scikit-learn==1.5.2
    - scipy==1.14.1
    - sentencepiece==0.2.0
    - setuptools==75.1.0
    - six==1.16.0
    - sniffio==1.3.1
    - sqlitedict==2.1.0
    - starlette==0.41.2
    - sympy==1.13.3
    - tabledata==1.3.3
    - tabulate==0.9.0
    - tcolorpy==0.1.6
    - threadpoolctl==3.5.0
    - tiktoken==0.8.0
    - tokenizers==0.20.1
    - torch==2.4.0
    - torchvision==0.19.0
    - tqdm==4.66.3
    - tqdm-multiprocess==0.0.11
    - transformers==4.46.0
    - triton==3.0.0
    - typepy==1.3.2
    - typing-extensions==4.12.2
    - tzdata==2024.2
    - urllib3==2.2.3
    - uvicorn==0.32.0
    - uvloop==0.21.0
    - vllm==0.5.4
    - vllm-flash-attn==2.6.1
    - watchfiles==0.24.0
    - websockets==13.1
    - wheel==0.44.0
    - word2number==1.1
    - xformers==0.0.27.post2
    - xxhash==3.5.0
    - yarl==1.16.0
    - zstandard==0.23.0
prefix: indranilhalder/env/env_vLLM # give path to the environment

evaluate_response_job.sh

#!/bin/bash
#
#SBATCH --job-name=gpt2-eval
#SBATCH --out="gpt2-eval-%A_%a.out"
#SBATCH --cpus-per-task=32
#SBATCH --mem=32G
#SBATCH --nodes=1
#SBATCH --time=1:00:00
#SBATCH --array=0
#SBATCH --gres=gpu:1
#SBATCH --partition=kempner_h100
#SBATCH --account=kempner_pehlevan_lab

export SAVE_DIR=samples/gpt2_samples1000_temp1-2_shot2

module load python/3.10.12-fasrc01 # update
source activate /n/netscratch/pehlevan_lab/Everyone/indranilhalder/env/env_vLLM # update the location of env
python iLLM/evaluate/math_datasets.py samples_dir=$SAVE_DIR/math_samples save_dir=$SAVE_DIR/math_eval dset=math
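
The evaluation internals live in iLLM/evaluate/math_datasets.py, which is not part of this commit. Purely as an illustration of a typical MATH-style check (the function names and matching rule below are assumptions, not the repo's code), one extracts the final \boxed{...} answer from a response and compares it with the reference:

```python
# Illustrative MATH-style answer check (not the repo's actual evaluation code).
import re

def extract_boxed(text: str) -> str | None:
    """Return the contents of the last \\boxed{...} in a response, if any."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def is_correct(response: str, reference: str) -> bool:
    predicted = extract_boxed(response)
    return predicted is not None and predicted == reference.strip()

print(is_correct(r"... so the answer is \boxed{402}.", "402"))  # True
```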

generate_response_job.sh

#!/bin/bash
#
#SBATCH --job-name=gpt2-generate
#SBATCH --out="gpt2-generate-%A_%a.out"
#SBATCH --cpus-per-task=32
#SBATCH --mem=32G
#SBATCH --nodes=1
#SBATCH --time=1:00:00
#SBATCH --array=0
#SBATCH --gres=gpu:1
#SBATCH --partition=kempner_h100
#SBATCH --account=kempner_pehlevan_lab

export SAVE_DIR=samples/gpt2_samples1000_temp1-2_shot2

module load python/3.10.12-fasrc01 # update
source activate /n/netscratch/pehlevan_lab/Everyone/indranilhalder/env/env_vLLM # update the location of env
python iLLM/generate/MATH.py model=gpt2 save_dir=$SAVE_DIR/math_samples temperature=1.2 num_samples=10 num_few_shot=2 num_workers=32 --list vllm_args --disable-log-requests list-- --list stop_strings Problem: list--
