Skip to content

Commit 5a24623

Browse files
committed
handle edge cases for python3
1 parent fddd274 commit 5a24623

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

src/lpg/interfaces/lang/python3.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,24 @@
44
from .base import BaseLanguageInterface
55

66
FUNCTION_SIGNATURE_PATTERN = re.compile(
7-
r"^def (?P<name>\w+)\((?P<params>[\w\s,=]*)\) -> (?P<returnType>\w+):$",
7+
r"^(class Solution:\n)?\s*def (?P<name>\w+)\((?P<params>[^)]*)\) -> (?P<returnType>[^:]+):$",
88
flags=re.MULTILINE,
99
)
1010

11+
# LeetCode uses Typing types, not Python 3.9+ types
12+
TYPING_IMPORT_TEMPLATE = "from typing import *\n\n"
13+
1114
TEST_FILE_TEMPLATE = """\
15+
from solution import Solution
16+
17+
1218
if __name__ == "__main__":
1319
{params_setup}
14-
result = {name}({params_call})
20+
result = Solution().{name}({params_call})
1521
print("result:", result)
1622
"""
1723

24+
1825
class Python3LanguageInterface(BaseLanguageInterface):
1926
"""Implementation of the Python 3 language project template interface."""
2027

@@ -24,16 +31,25 @@ def write_project_files(self, template: str):
2431
"""Creates the project template for Python 3."""
2532

2633
with open("solution.py", "w", encoding="utf-8") as file:
27-
file.write(template + "\n")
28-
29-
params = self.groups["params"].split(", ") if self.groups["params"] else []
34+
file.write(f"{TYPING_IMPORT_TEMPLATE}\n{template}pass\n")
35+
36+
params = (
37+
[
38+
param
39+
for param in self.groups["params"].split(", ")
40+
if param and param != "self"
41+
]
42+
if self.groups["params"]
43+
else []
44+
)
3045
self.groups["params_setup"] = "\n ".join(
31-
f"{param.split('=')[0]} = 0" for param in params if param
46+
param if "=" in param else f"{param} = None" for param in params
47+
)
48+
self.groups["params_call"] = ", ".join(
49+
param.split("=")[0].split(":")[0].strip() for param in params
3250
)
33-
self.groups["params_call"] = ", ".join(param.split('=')[0] for param in params)
3451

3552
formatted = TEST_FILE_TEMPLATE.format(**self.groups)
3653

3754
with open("test.py", "w", encoding="utf-8") as file:
38-
file.write(formatted)
39-
55+
file.write(f"{TYPING_IMPORT_TEMPLATE}{formatted}")

0 commit comments

Comments
 (0)