1
+ # Always import asyncio
2
+ import asyncio
1
3
from collections .abc import Awaitable , Callable , Generator
2
4
from functools import wraps
3
5
from typing import Literal , NewType , ParamSpec , Protocol , TypeVar , cast , final
4
- # Always import asyncio
5
- import asyncio
6
+
6
7
7
8
class AsyncLock (Protocol ):
8
9
"""A protocol for an asynchronous lock."""
@@ -15,7 +16,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
15
16
16
17
17
18
# Define context types as literals
18
- AsyncContext = Literal [" asyncio" , " trio" , " unknown" ]
19
+ AsyncContext = Literal [' asyncio' , ' trio' , ' unknown' ]
19
20
20
21
21
22
def _is_anyio_available () -> bool :
@@ -60,10 +61,10 @@ def _is_in_trio_context() -> bool:
60
61
"""
61
62
if not has_trio :
62
63
return False # pragma: no cover
63
-
64
+
64
65
# Import trio here since we already checked it's available
65
66
import trio
66
-
67
+
67
68
try :
68
69
# Will raise RuntimeError if not in trio context
69
70
trio .lowlevel .current_task ()
@@ -80,13 +81,13 @@ def detect_async_context() -> AsyncContext:
80
81
AsyncContext: The current async context type
81
82
"""
82
83
if not has_anyio : # pragma: no cover
83
- return " asyncio"
84
+ return ' asyncio'
84
85
85
86
if _is_in_trio_context ():
86
- return " trio"
87
+ return ' trio'
87
88
88
89
# Default to asyncio
89
- return " asyncio"
90
+ return ' asyncio'
90
91
91
92
92
93
_ValueType = TypeVar ('_ValueType' )
@@ -145,7 +146,9 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
145
146
"""We need just an awaitable to work with."""
146
147
self ._coro = coro
147
148
self ._cache : _ValueType | _Sentinel = _sentinel
148
- self ._lock : AsyncLock | None = None # Will be created lazily based on the backend
149
+ self ._lock : AsyncLock | None = (
150
+ None # Will be created lazily based on the backend
151
+ )
149
152
150
153
def __await__ (self ) -> Generator [None , None , _ValueType ]:
151
154
"""
@@ -195,14 +198,14 @@ def _create_lock(self) -> AsyncLock:
195
198
"""Create the appropriate lock based on the current async context."""
196
199
context = detect_async_context ()
197
200
198
- if context == " trio" and has_anyio :
201
+ if context == ' trio' and has_anyio :
199
202
try :
200
203
import anyio
201
204
except Exception : # pragma: no cover
202
205
# Just continue to asyncio if anyio import fails
203
206
return asyncio .Lock () # pragma: no cover
204
207
return anyio .Lock () # pragma: no cover
205
-
208
+
206
209
# For asyncio or unknown contexts
207
210
return asyncio .Lock ()
208
211
@@ -222,6 +225,8 @@ async def _awaitable(self) -> _ValueType:
222
225
if self ._cache is _sentinel :
223
226
self ._cache = await self ._coro
224
227
return self ._cache # type: ignore
228
+
229
+
225
230
# pragma: no cover
226
231
227
232
@@ -258,4 +263,4 @@ def decorator(
258
263
) -> _AwaitableT :
259
264
return ReAwaitable (coro (* args , ** kwargs )) # type: ignore[return-value]
260
265
261
- return decorator
266
+ return decorator
0 commit comments