Skip to content

Commit 7ccecd2

Browse files
authored
Merge pull request #291 from huangshiyu13/main
uodate
2 parents 7419e43 + c545c71 commit 7ccecd2

File tree

4 files changed

+202
-89
lines changed

4 files changed

+202
-89
lines changed

openrl/supports/opengpu/manager.py

Lines changed: 85 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -20,89 +20,91 @@
2020
import traceback
2121
from typing import List, Union
2222

23-
from openrl.supports.opengpu.gpu_info import get_local_GPU_info, get_remote_GPU_info
24-
25-
26-
class RemoteGPUManager:
27-
def __init__(self, pytorch_config=None, check: bool = False):
28-
self.gpu_info_dict = get_remote_GPU_info()
29-
self.pytorch_config = pytorch_config
30-
self.server_list = []
31-
if self.pytorch_config is not None:
32-
for server_address in self.pytorch_config.GPU_usage_dict:
33-
self.server_list.append(server_address)
34-
35-
if check:
36-
self.check_gpus()
37-
38-
self.cal_learner_number()
39-
40-
def check_gpus(self):
41-
assert self.pytorch_config is not None
42-
assert len(self.server_list) > 0
43-
44-
bad_gpus = []
45-
for server_address in self.server_list:
46-
assert (
47-
server_address in self.gpu_info_dict
48-
), "can not get gpu info from {}".format(server_address)
49-
assert len(self.gpu_info_dict[server_address]["gpu_infos"]) > 0
50-
51-
for gpu_info in self.gpu_info_dict[server_address]["gpu_infos"]:
52-
if (
53-
self.pytorch_config.GPU_usage_dict[server_address]["gpus"] == "all"
54-
or gpu_info["gpu"]
55-
in self.pytorch_config.GPU_usage_dict[server_address]["gpus"]
56-
):
57-
if (
58-
gpu_info["memory"]["total"] - gpu_info["memory"]["used"]
59-
< self.pytorch_config.min_memory_per_gpu
60-
):
61-
bad_gpus.append(
62-
{
63-
"server": server_address,
64-
"gpu": gpu_info["gpu"],
65-
"free": (
66-
gpu_info["memory"]["total"]
67-
- gpu_info["memory"]["used"]
68-
),
69-
}
70-
)
71-
if len(bad_gpus) > 0:
72-
for bad_gpu in bad_gpus:
73-
print(
74-
"server:{} GPU:{}, minimal memory {}GB, but only get {}GB free"
75-
" memory.".format(
76-
bad_gpu["server"],
77-
bad_gpu["gpu"],
78-
self.pytorch_config.min_memory_per_gpu,
79-
bad_gpu["free"],
80-
)
81-
)
82-
assert False, "GPUs not satisfy."
83-
84-
def cal_learner_number(self):
85-
self.server_gpu_mapping = {}
86-
gpu_num = 0
87-
for server_address in self.server_list:
88-
gpu_mapping = {}
89-
for gpu_info in self.gpu_info_dict[server_address]["gpu_infos"]:
90-
if (
91-
self.pytorch_config.GPU_usage_dict[server_address]["gpus"] == "all"
92-
or gpu_info["gpu"]
93-
in self.pytorch_config.GPU_usage_dict[server_address]["gpus"]
94-
):
95-
gpu_mapping[gpu_info["gpu"]] = gpu_num
96-
gpu_num += 1
97-
self.server_gpu_mapping[server_address] = gpu_mapping
98-
self.learner_num = gpu_num
99-
100-
def get_gpu_info(self, server_list: list):
101-
gpu_infos = {}
102-
for server_address in server_list:
103-
if server_address in self.gpu_info_dict:
104-
gpu_infos[server_address] = self.gpu_info_dict[server_address]
105-
return gpu_infos
23+
from openrl.supports.opengpu.gpu_info import get_local_GPU_info
24+
25+
# from openrl.supports.opengpu.gpu_info import get_remote_GPU_info
26+
27+
28+
# class RemoteGPUManager:
29+
# def __init__(self, pytorch_config=None, check: bool = False):
30+
# self.gpu_info_dict = get_remote_GPU_info()
31+
# self.pytorch_config = pytorch_config
32+
# self.server_list = []
33+
# if self.pytorch_config is not None:
34+
# for server_address in self.pytorch_config.GPU_usage_dict:
35+
# self.server_list.append(server_address)
36+
#
37+
# if check:
38+
# self.check_gpus()
39+
#
40+
# self.cal_learner_number()
41+
#
42+
# def check_gpus(self):
43+
# assert self.pytorch_config is not None
44+
# assert len(self.server_list) > 0
45+
#
46+
# bad_gpus = []
47+
# for server_address in self.server_list:
48+
# assert (
49+
# server_address in self.gpu_info_dict
50+
# ), "can not get gpu info from {}".format(server_address)
51+
# assert len(self.gpu_info_dict[server_address]["gpu_infos"]) > 0
52+
#
53+
# for gpu_info in self.gpu_info_dict[server_address]["gpu_infos"]:
54+
# if (
55+
# self.pytorch_config.GPU_usage_dict[server_address]["gpus"] == "all"
56+
# or gpu_info["gpu"]
57+
# in self.pytorch_config.GPU_usage_dict[server_address]["gpus"]
58+
# ):
59+
# if (
60+
# gpu_info["memory"]["total"] - gpu_info["memory"]["used"]
61+
# < self.pytorch_config.min_memory_per_gpu
62+
# ):
63+
# bad_gpus.append(
64+
# {
65+
# "server": server_address,
66+
# "gpu": gpu_info["gpu"],
67+
# "free": (
68+
# gpu_info["memory"]["total"]
69+
# - gpu_info["memory"]["used"]
70+
# ),
71+
# }
72+
# )
73+
# if len(bad_gpus) > 0:
74+
# for bad_gpu in bad_gpus:
75+
# print(
76+
# "server:{} GPU:{}, minimal memory {}GB, but only get {}GB free"
77+
# " memory.".format(
78+
# bad_gpu["server"],
79+
# bad_gpu["gpu"],
80+
# self.pytorch_config.min_memory_per_gpu,
81+
# bad_gpu["free"],
82+
# )
83+
# )
84+
# assert False, "GPUs not satisfy."
85+
#
86+
# def cal_learner_number(self):
87+
# self.server_gpu_mapping = {}
88+
# gpu_num = 0
89+
# for server_address in self.server_list:
90+
# gpu_mapping = {}
91+
# for gpu_info in self.gpu_info_dict[server_address]["gpu_infos"]:
92+
# if (
93+
# self.pytorch_config.GPU_usage_dict[server_address]["gpus"] == "all"
94+
# or gpu_info["gpu"]
95+
# in self.pytorch_config.GPU_usage_dict[server_address]["gpus"]
96+
# ):
97+
# gpu_mapping[gpu_info["gpu"]] = gpu_num
98+
# gpu_num += 1
99+
# self.server_gpu_mapping[server_address] = gpu_mapping
100+
# self.learner_num = gpu_num
101+
#
102+
# def get_gpu_info(self, server_list: list):
103+
# gpu_infos = {}
104+
# for server_address in server_list:
105+
# if server_address in self.gpu_info_dict:
106+
# gpu_infos[server_address] = self.gpu_info_dict[server_address]
107+
# return gpu_infos
106108

107109

108110
class LocalGPUManager:

tests/test_supports/test_opendata/test_opendata.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,45 @@
1818

1919
import os
2020
import sys
21+
from pathlib import Path
2122

2223
import pytest
2324

24-
from openrl.supports.opendata.utils.opendata_utils import data_abs_path
25+
from openrl.supports.opendata.utils.opendata_utils import data_abs_path, load_dataset
2526

2627

2728
@pytest.mark.unittest
28-
def test_data_abs_path():
29+
def test_data_abs_path(tmpdir):
2930
data_path = "./"
3031
assert data_abs_path(data_path) == data_path
3132

33+
data_server_dir = Path.home() / "data_server/"
34+
35+
new_create = False
36+
if not data_server_dir.exists():
37+
data_server_dir.mkdir()
38+
new_create = True
39+
data_abs_path("data_server://data_path")
40+
if new_create:
41+
data_server_dir.rmdir()
42+
data_abs_path("data_server://data_path", str(tmpdir))
43+
44+
45+
@pytest.mark.unittest
46+
def test_load_dataset(tmpdir):
47+
try:
48+
load_dataset(str(tmpdir), "train")
49+
except Exception as e:
50+
pass
51+
try:
52+
load_dataset("data_server://data_path", "train")
53+
except Exception as e:
54+
pass
55+
try:
56+
load_dataset(str(tmpdir) + "/test", "train")
57+
except Exception as e:
58+
pass
59+
3260

3361
if __name__ == "__main__":
3462
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
import os
20+
import sys
21+
22+
import pytest
23+
24+
from openrl.supports.opengpu.gpu_info import preserve_decimal
25+
26+
27+
@pytest.mark.unittest
28+
def test_preserve_decimal():
29+
preserve_decimal(1, 2)
30+
preserve_decimal(1.1, 0)
31+
preserve_decimal(1.1, -1)
32+
preserve_decimal(1.1, 4)
33+
preserve_decimal(-1.1, 4)
34+
preserve_decimal(-0.1, 0)
35+
36+
37+
if __name__ == "__main__":
38+
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

tests/test_supports/test_opengpu/test_manager.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
""""""
1818

19+
import argparse
1920
import os
2021
import sys
2122

@@ -24,12 +25,56 @@
2425
from openrl.supports.opengpu.manager import LocalGPUManager
2526

2627

28+
@pytest.fixture(
29+
scope="module",
30+
params=[
31+
# 添加不同的参数组合以进行测试
32+
0,
33+
1,
34+
2,
35+
None,
36+
],
37+
)
38+
def learner_num(request):
39+
return request.param
40+
41+
42+
@pytest.fixture(scope="module", params=[True, False])
43+
def disable_cuda(request):
44+
return request.param
45+
46+
47+
@pytest.fixture(scope="module", params=["all", "single", "error_type"])
48+
def gpu_usage_type(request):
49+
return request.param
50+
51+
52+
@pytest.fixture(
53+
scope="module",
54+
)
55+
def args(learner_num, disable_cuda, gpu_usage_type):
56+
if learner_num is None:
57+
return None
58+
current_dict = {}
59+
current_dict["learner_num"] = learner_num
60+
current_dict["disable_cuda"] = disable_cuda
61+
current_dict["gpu_usage_type"] = gpu_usage_type
62+
63+
return argparse.Namespace(**current_dict)
64+
65+
2766
@pytest.mark.unittest
28-
def test_local_manager():
29-
manager = LocalGPUManager()
67+
def test_local_manager(args):
68+
manager = LocalGPUManager(args)
3069
manager.get_gpu()
31-
manager.get_learner_gpu()
32-
assert isinstance(manager.get_learner_gpus(), list)
70+
try:
71+
manager.get_learner_gpu()
72+
except IndexError as e:
73+
print("Caught an IndexError:", e)
74+
try:
75+
assert isinstance(manager.get_learner_gpus(), list)
76+
except IndexError as e:
77+
print("Caught an IndexError:", e)
3378
manager.get_worker_gpu()
3479
manager.log_info()
3580

0 commit comments

Comments
 (0)