Skip to content
This repository was archived by the owner on May 6, 2024. It is now read-only.

Commit 5e7afa8

Browse files
bors[bot]dcramer
andauthored
Merge #326
326: ref: Various fixes and type annotations for mypy r=dcramer a=dcramer bors r+ Co-authored-by: David Cramer <dcramer@gmail.com>
2 parents 3672305 + cf3e605 commit 5e7afa8

File tree

16 files changed

+101
-82
lines changed

16 files changed

+101
-82
lines changed

setup.cfg

+6
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,9 @@ ignore_missing_imports = True
9494

9595
[mypy-watchdog.*]
9696
ignore_missing_imports = True
97+
98+
[mypy-cached_property.*]
99+
ignore_missing_imports = True
100+
101+
[mypy-asyncpg.*]
102+
ignore_missing_imports = True

zeus/api/schemas/testcase.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from marshmallow import Schema, fields, pre_dump
55
from sqlalchemy import and_
6-
from typing import List, Mapping
6+
from typing import Dict, List, Optional
77
from uuid import UUID
88

99
from zeus.config import db
@@ -18,7 +18,9 @@
1818
from .job import JobSchema
1919

2020

21-
def find_failure_origins(build: Build, test_failures: List[str]) -> Mapping[str, UUID]:
21+
def find_failure_origins(
22+
build: Build, test_failures: List[str]
23+
) -> Dict[str, Optional[UUID]]:
2224
"""
2325
Attempt to find originating causes of failures.
2426
@@ -127,7 +129,7 @@ def find_failure_origins(build: Build, test_failures: List[str]) -> Mapping[str,
127129
for test_hash, build_id in queryset:
128130
previous_test_failures[build_id].add(test_hash)
129131

130-
failures_at_build = dict()
132+
failures_at_build: Dict[str, Optional[UUID]] = {}
131133
searching = set(t for t in test_failures)
132134
# last_checked_run = build.id
133135
last_checked_run = None

zeus/auth.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from flask import current_app, g, request, session
66
from itsdangerous import BadSignature, JSONWebSignatureSerializer
77
from sqlalchemy.orm import joinedload
8-
from typing import Mapping, Optional
8+
from typing import Any, Dict, Mapping, Optional
99
from urllib.parse import urlparse, urljoin
1010
from uuid import UUID
1111

@@ -27,9 +27,9 @@
2727

2828

2929
class Tenant(object):
30-
access = {}
30+
access: Dict[UUID, Optional[Permission]] = {}
3131

32-
def __init__(self, access: Optional[Mapping[UUID, Optional[Permission]]] = None):
32+
def __init__(self, access: Optional[Dict[UUID, Optional[Permission]]] = None):
3333
if access is not None:
3434
self.access = access
3535

@@ -54,7 +54,7 @@ def has_permission(self, repository_id: UUID, permission: Permission = None):
5454
return permission in access
5555

5656
@classmethod
57-
def from_user(cls, user: User):
57+
def from_user(cls, user: Optional[User]):
5858
if not user:
5959
return cls()
6060

@@ -70,7 +70,7 @@ def from_repository(
7070
if not repository:
7171
return cls()
7272

73-
return RepositoryTenant(access={repository.id: permission})
73+
return RepositoryTenant(repository.id, permission)
7474

7575
@classmethod
7676
def from_api_token(cls, token: ApiToken):
@@ -82,9 +82,7 @@ def from_api_token(cls, token: ApiToken):
8282

8383
class ApiTokenTenant(Tenant):
8484
def __init__(
85-
self,
86-
token_id: str,
87-
access: Optional[Mapping[UUID, Optional[Permission]]] = None,
85+
self, token_id: str, access: Optional[Dict[UUID, Optional[Permission]]] = None
8886
):
8987
self.token_id = token_id
9088
if access is not None:
@@ -94,7 +92,7 @@ def __repr__(self):
9492
return "<{} token_id={}>".format(type(self).__name__, self.token_id)
9593

9694
@cached_property
97-
def access(self) -> Mapping[UUID, Permission]:
95+
def access(self) -> Dict[UUID, Permission]:
9896
if not self.token_id:
9997
return {}
10098

@@ -108,9 +106,7 @@ def access(self) -> Mapping[UUID, Permission]:
108106

109107
class UserTenant(Tenant):
110108
def __init__(
111-
self,
112-
user_id: UUID,
113-
access: Optional[Mapping[UUID, Optional[Permission]]] = None,
109+
self, user_id: UUID, access: Optional[Dict[UUID, Optional[Permission]]] = None
114110
):
115111
self.user_id = user_id
116112
if access is not None:
@@ -120,7 +116,7 @@ def __repr__(self):
120116
return "<{} user_id={}>".format(type(self).__name__, self.user_id)
121117

122118
@cached_property
123-
def access(self) -> Mapping[UUID, Permission]:
119+
def access(self) -> Dict[UUID, Permission]:
124120
if not self.user_id:
125121
return {}
126122

@@ -142,7 +138,7 @@ def __repr__(self):
142138
)
143139

144140
@cached_property
145-
def access(self) -> Mapping[UUID, Permission]:
141+
def access(self) -> Dict[UUID, Optional[Permission]]:
146142
if not self.repository_id:
147143
return {}
148144

@@ -156,6 +152,7 @@ def get_tenant_from_headers(headers: Mapping) -> Optional[Tenant]:
156152
header = headers.get("Authorization", "")
157153
if header:
158154
return get_tenant_from_bearer_header(header)
155+
return None
159156

160157

161158
def get_tenant_from_request() -> Tenant:
@@ -171,6 +168,8 @@ def get_tenant_from_bearer_header(header: str) -> Optional[Tenant]:
171168
return None
172169

173170
match = _bearer_regexp.match(header)
171+
if not match:
172+
return None
174173
token = match.group(2)
175174
if not token.startswith("zeus-"):
176175
# Assuming this is a legacy token
@@ -314,13 +313,15 @@ def get_current_tenant() -> Tenant:
314313

315314
def generate_token(tenant: Tenant) -> bytes:
316315
s = JSONWebSignatureSerializer(current_app.secret_key, salt="auth")
317-
payload = {"access": {str(k): v for k, v in tenant.access.items()}}
316+
payload: Dict[str, Any] = {
317+
"access": {str(k): int(v) if v else None for k, v in tenant.access.items()}
318+
}
318319
if getattr(tenant, "user_id", None):
319320
payload["uid"] = str(tenant.user_id)
320321
return s.dumps(payload)
321322

322323

323-
def parse_token(token: str) -> Optional[str]:
324+
def parse_token(token: str) -> Optional[Any]:
324325
s = JSONWebSignatureSerializer(current_app.secret_key, salt="auth")
325326
try:
326327
return s.loads(token)
@@ -330,10 +331,12 @@ def parse_token(token: str) -> Optional[str]:
330331

331332

332333
def get_tenant_from_signed_token(token: str) -> Tenant:
333-
payload = parse_token(token)
334+
payload: Optional[Dict[str, Any]] = parse_token(token)
334335
if not payload:
335336
return Tenant()
336-
access = {UUID(k): v for k, v in payload["access"].items()}
337+
access = {
338+
UUID(k): Permission(v) if v else None for k, v in payload["access"].items()
339+
}
337340
if "uid" in payload:
338341
return UserTenant(user_id=UUID(payload["uid"]), access=access)
339342
return Tenant(access=access)
@@ -354,7 +357,7 @@ def is_safe_url(target: str) -> bool:
354357
)
355358

356359

357-
def get_redirect_target(clear=True, session=session) -> str:
360+
def get_redirect_target(clear=True, session=session) -> Optional[str]:
358361
if clear:
359362
session_target = session.pop("next", None)
360363
else:
@@ -366,6 +369,7 @@ def get_redirect_target(clear=True, session=session) -> str:
366369

367370
if is_safe_url(target):
368371
return target
372+
return None
369373

370374

371375
def bind_redirect_target(target: str = None, session=session):

zeus/config.py

-5
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@
4040
metrics = Metrics()
4141

4242

43-
from flask_sqlalchemy.model import DefaultMeta
44-
45-
db.Model: DefaultMeta = db.Model
46-
47-
4843
def with_health_check(app):
4944
def middleware(environ, start_response):
5045
path_info = environ.get("PATH_INFO", "")

zeus/db/types/enum.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
__all__ = ["Enum", "IntEnum", "StrEnum"]
22

3+
from enum import Enum as EnumType
4+
from typing import Optional, Type
5+
36
from sqlalchemy.types import TypeDecorator, INT, STRINGTYPE
47

58

69
class Enum(TypeDecorator):
710
impl = INT
811

9-
def __init__(self, enum=None, *args, **kwargs):
12+
def __init__(self, enum: Optional[Type[EnumType]] = None, *args, **kwargs):
1013
self.enum = enum
1114
super(Enum, self).__init__(*args, **kwargs)
1215

zeus/db/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def try_create(model, where: dict, defaults: dict = None) -> Optional[Any]:
1818
db.session.add(instance)
1919
except IntegrityError as exc:
2020
if "duplicate" not in str(exc):
21-
return
21+
return None
2222
raise
2323
return instance
2424

zeus/exceptions.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, scope, identity):
4949
def get_upgrade_url(self) -> Optional[str]:
5050
if self.identity.provider == "github":
5151
return "/auth/github"
52+
return None
5253

5354

5455
class UnknownRepositoryBackend(Exception):

zeus/models/hook.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ class Hook(RepositoryBoundMixin, StandardAttributes, db.Model):
3333
def generate_token(cls) -> bytes:
3434
return token_bytes(64)
3535

36-
def get_signature(self) -> bytes:
36+
def get_signature(self) -> str:
3737
return hmac.new(
3838
key=self.token, msg=self.repository_id.bytes, digestmod=sha256
3939
).hexdigest()
4040

41-
def is_valid_signature(self, signature: bytes) -> bool:
41+
def is_valid_signature(self, signature: str) -> bool:
4242
return compare_digest(self.get_signature(), signature)
4343

4444
def get_provider(self):

zeus/notifications/email.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,17 @@ def build_message(build: Build, force=False) -> Message:
6565
current_app.logger.info("mail.missing-author", extra={"build_id": build.id})
6666
return
6767

68-
emails = find_linked_emails(build)
68+
emails: List[Tuple[UUID, str]] = find_linked_emails(build)
6969
if not emails and not force:
7070
current_app.logger.info("mail.no-linked-accounts", extra={"build_id": build.id})
7171
return
7272

7373
elif not emails:
7474
current_user = auth.get_current_user()
75-
emails = [[current_user.id, current_user.email]]
75+
if current_user:
76+
emails = [(current_user.id, current_user.email)]
77+
elif not force:
78+
return
7679

7780
# filter it down to the users that have notifications enabled
7881
user_options = dict(

zeus/storage/mock.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from io import BytesIO
2-
from typing import Mapping
2+
from typing import Dict
33

44
from .base import FileStorage
55

6-
_cache: Mapping[str, bytes] = {}
6+
_cache: Dict[str, bytes] = {}
77

88

99
class FileStorageCache(FileStorage):

zeus/utils/builds.py

+19-19
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from functools import reduce
55
from itertools import groupby
66
from operator import and_, or_
7-
from typing import Any, List, Mapping, Set, Tuple
7+
from typing import List, Optional, Set, Tuple
88
from sqlalchemy.orm import joinedload, subqueryload_all
99
from uuid import UUID
1010

@@ -16,20 +16,20 @@
1616
class MetaBuild:
1717
original: List[Build] = dataclasses.field(default_factory=list)
1818
ref: str = ""
19-
revision_sha: str = None
19+
revision_sha: Optional[str] = None
2020
label: str = ""
2121
stats: dict = dataclasses.field(default_factory=dict)
2222
result: Result = Result.unknown
2323
status: Status = Status.unknown
2424
authors: List[Author] = dataclasses.field(default_factory=list)
25-
date_created: datetime = None
26-
date_started: datetime = None
27-
date_finished: datetime = None
25+
date_created: Optional[datetime] = None
26+
date_started: Optional[datetime] = None
27+
date_finished: Optional[datetime] = None
2828

29-
revision: Revision = None
29+
revision: Optional[Revision] = None
3030

3131

32-
def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> Build:
32+
def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> MetaBuild:
3333
# Store the original build so we can retrieve its ID or number later, or
3434
# show a list of all builds in the UI
3535
target.original.append(build)
@@ -96,10 +96,8 @@ def merge_builds(target: MetaBuild, build: Build, with_relations=True) -> Build:
9696

9797

9898
def merge_build_group(
99-
build_group: Tuple[Any, List[Build]],
100-
required_hook_ids: List[UUID] = None,
101-
with_relations=True,
102-
) -> Build:
99+
build_group: List[Build], required_hook_ids: Set[str] = None, with_relations=True
100+
) -> MetaBuild:
103101
# XXX(dcramer): required_hook_ids is still dirty here, but its our simplest way
104102
# to get it into place
105103
grouped_builds = groupby(
@@ -111,8 +109,8 @@ def merge_build_group(
111109

112110
build = MetaBuild()
113111
build.original = []
114-
if set(required_hook_ids or ()).difference(
115-
set(str(b.hook_id) for b in build_group)
112+
if frozenset(required_hook_ids or ()).difference(
113+
frozenset(str(b.hook_id) for b in build_group)
116114
):
117115
build.result = Result.failed
118116

@@ -125,12 +123,12 @@ def merge_build_group(
125123

126124
def fetch_builds_for_revisions(
127125
revisions: List[Revision], with_relations=True
128-
) -> Mapping[str, Build]:
126+
) -> List[Tuple[Tuple[UUID, str], MetaBuild]]:
129127
# we query extra builds here, but its a lot easier than trying to get
130128
# sqlalchemy to do a ``select (subquery)`` clause and maintain tenant
131129
# constraints
132130
if not revisions:
133-
return {}
131+
return []
134132

135133
lookups = []
136134
for revision in revisions:
@@ -156,21 +154,23 @@ def fetch_builds_for_revisions(
156154
build_groups = groupby(
157155
builds, lambda build: (build.repository_id, build.revision_sha)
158156
)
159-
required_hook_ids: Set[UUID] = set()
157+
required_hook_ids: Set[str] = set()
160158
for build in builds:
161159
required_hook_ids.update(build.data.get("required_hook_ids") or ())
162160
return [
163161
(
164162
ident,
165163
merge_build_group(
166-
list(group), required_hook_ids, with_relations=with_relations
164+
list(build_group), required_hook_ids, with_relations=with_relations
167165
),
168166
)
169-
for ident, group in build_groups
167+
for ident, build_group in build_groups
170168
]
171169

172170

173-
def fetch_build_for_revision(revision: Revision, with_relations=True) -> Build:
171+
def fetch_build_for_revision(
172+
revision: Revision, with_relations=True
173+
) -> Optional[MetaBuild]:
174174
builds = fetch_builds_for_revisions([revision], with_relations=with_relations)
175175
if len(builds) < 1:
176176
return None

0 commit comments

Comments
 (0)