1
+ # Always import asyncio
2
+ import asyncio
1
3
from collections .abc import Awaitable , Callable , Generator
2
4
from functools import wraps
3
5
from typing import Literal , NewType , ParamSpec , Protocol , TypeVar , cast , final
4
- # Always import asyncio
5
- import asyncio
6
+
6
7
7
8
class AsyncLock (Protocol ):
8
9
"""A protocol for an asynchronous lock."""
@@ -15,7 +16,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
15
16
16
17
17
18
# Define context types as literals
18
- AsyncContext = Literal [" asyncio" , " trio" , " unknown" ]
19
+ AsyncContext = Literal [' asyncio' , ' trio' , ' unknown' ]
19
20
20
21
21
22
# Check for anyio and trio availability
@@ -42,10 +43,10 @@ def _is_in_trio_context() -> bool:
42
43
"""
43
44
if not has_trio :
44
45
return False
45
-
46
+
46
47
# Import trio here since we already checked it's available
47
48
import trio
48
-
49
+
49
50
try :
50
51
# Will raise RuntimeError if not in trio context
51
52
trio .lowlevel .current_task ()
@@ -61,13 +62,13 @@ def detect_async_context() -> AsyncContext:
61
62
AsyncContext: The current async context type
62
63
"""
63
64
if not has_anyio : # pragma: no cover
64
- return " asyncio"
65
+ return ' asyncio'
65
66
66
67
if _is_in_trio_context ():
67
- return " trio"
68
+ return ' trio'
68
69
69
70
# Default to asyncio
70
- return " asyncio"
71
+ return ' asyncio'
71
72
72
73
73
74
_ValueType = TypeVar ('_ValueType' )
@@ -126,7 +127,9 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
126
127
"""We need just an awaitable to work with."""
127
128
self ._coro = coro
128
129
self ._cache : _ValueType | _Sentinel = _sentinel
129
- self ._lock : AsyncLock | None = None # Will be created lazily based on the backend
130
+ self ._lock : AsyncLock | None = (
131
+ None # Will be created lazily based on the backend
132
+ )
130
133
131
134
def __await__ (self ) -> Generator [None , None , _ValueType ]:
132
135
"""
@@ -176,8 +179,9 @@ def _create_lock(self) -> AsyncLock:
176
179
"""Create the appropriate lock based on the current async context."""
177
180
context = detect_async_context ()
178
181
179
- if context == " trio" and has_anyio :
182
+ if context == ' trio' and has_anyio :
180
183
import anyio
184
+
181
185
return anyio .Lock ()
182
186
183
187
# For asyncio or unknown contexts
@@ -228,4 +232,4 @@ def decorator(
228
232
) -> _AwaitableT :
229
233
return ReAwaitable (coro (* args , ** kwargs )) # type: ignore[return-value]
230
234
231
- return decorator
235
+ return decorator
0 commit comments