Skip to content

stubgen: infer types for class attributes #18978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,9 @@ def __init__(
analyzed: bool = False,
export_less: bool = False,
include_docstrings: bool = False,
known_modules: list[str] | None = None,
) -> None:
super().__init__(_all_, include_private, export_less, include_docstrings)
super().__init__(_all_, include_private, export_less, include_docstrings, known_modules)
self._decorators: list[str] = []
# Stack of defined variables (per scope).
self._vars: list[list[str]] = [[]]
Expand Down Expand Up @@ -1233,7 +1234,10 @@ def get_init(
return None
self._vars[-1].append(lvalue)
if annotation is not None:
typename = self.print_annotation(annotation)
if isinstance(annotation, UnboundType):
typename = self.print_annotation(annotation)
else:
typename = self.print_annotation(annotation, self.known_modules)
if (
isinstance(annotation, UnboundType)
and not annotation.args
Expand Down Expand Up @@ -1460,7 +1464,14 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
and isinstance(lvalue.expr, NameExpr)
and lvalue.expr.name == "self"
):
self.results.append((lvalue.name, o.rvalue, o.unanalyzed_type))
# lvalue.node might be populated with an inferred type
if isinstance(lvalue.node, Var) and (
lvalue.node.is_ready or not isinstance(get_proper_type(lvalue.node.type), AnyType)
):
annotation = lvalue.node.type
else:
annotation = o.unanalyzed_type
self.results.append((lvalue.name, o.rvalue, annotation))


def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression, Type | None]]:
Expand Down Expand Up @@ -1652,7 +1663,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions:
options.follow_imports = "skip"
options.incremental = False
options.ignore_errors = True
options.semantic_analysis_only = True
options.semantic_analysis_only = False
options.python_version = stubgen_options.pyversion
options.show_traceback = True
options.transform_source = remove_misplaced_type_comments
Expand Down Expand Up @@ -1729,7 +1740,7 @@ def generate_stub_for_py_module(
) -> None:
"""Use analysed (or just parsed) AST to generate type stub for single file.

If directory for target doesn't exist it will created. Existing stub
If directory for target doesn't exist it will be created. Existing stub
will be overwritten.
"""
if inspect:
Expand All @@ -1748,6 +1759,7 @@ def generate_stub_for_py_module(
else:
gen = ASTStubGenerator(
mod.runtime_all,
known_modules=all_modules,
include_private=include_private,
analyzed=not parse_only,
export_less=export_less,
Expand Down
10 changes: 4 additions & 6 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def generate_stub_for_c_module(

gen = InspectionStubGenerator(
module_name,
known_modules,
doc_dir,
known_modules=known_modules,
doc_dir=doc_dir,
include_private=include_private,
export_less=export_less,
include_docstrings=include_docstrings,
Expand Down Expand Up @@ -240,9 +240,8 @@ def __init__(
else:
self.module = module
self.is_c_module = is_c_module(self.module)
self.known_modules = known_modules
self.resort_members = self.is_c_module
super().__init__(_all_, include_private, export_less, include_docstrings)
super().__init__(_all_, include_private, export_less, include_docstrings, known_modules)
self.module_name = module_name
if self.is_c_module:
# Add additional implicit imports.
Expand Down Expand Up @@ -393,10 +392,9 @@ def strip_or_import(self, type_name: str) -> str:
Arguments:
typ: name of the type
"""
local_modules = ["builtins", self.module_name]
parsed_type = parse_type_comment(type_name, 0, 0, None)[1]
assert parsed_type is not None, type_name
return self.print_annotation(parsed_type, self.known_modules, local_modules)
return self.print_annotation(parsed_type, self.known_modules)

def get_obj_module(self, obj: object) -> str | None:
"""Return module name of the object."""
Expand Down
104 changes: 87 additions & 17 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
Instance,
NoneType,
Type,
TypedDictType,
TypeList,
TypeStrVisitor,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
)
Expand Down Expand Up @@ -251,6 +258,23 @@ def __init__(
self.known_modules = known_modules
self.local_modules = local_modules or ["builtins"]

def track_imports(self, s: str) -> str | None:
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
for module_name in self.local_modules + sorted(self.known_modules, reverse=True):
if s.startswith(module_name + "."):
if module_name in self.local_modules:
s = s[len(module_name) + 1 :]
arg_module = module_name
break
else:
arg_module = s[: s.rindex(".")]
if arg_module not in self.local_modules:
self.stubgen.import_tracker.add_import(arg_module, require=True)
return s
return None

def visit_any(self, t: AnyType) -> str:
s = super().visit_any(t)
self.stubgen.import_tracker.require_name(s)
Expand All @@ -267,19 +291,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
return self.stubgen.add_name("_typeshed.Incomplete")
if fullname in TYPING_BUILTIN_REPLACEMENTS:
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
for module_name in self.local_modules + sorted(self.known_modules, reverse=True):
if s.startswith(module_name + "."):
if module_name in self.local_modules:
s = s[len(module_name) + 1 :]
arg_module = module_name
break
else:
arg_module = s[: s.rindex(".")]
if arg_module not in self.local_modules:
self.stubgen.import_tracker.add_import(arg_module, require=True)

if new_s := self.track_imports(s):
s = new_s
elif s == "NoneType":
# when called without analysis all types are unbound, so this won't hit
# visit_none_type().
Expand All @@ -292,6 +306,9 @@ def visit_unbound_type(self, t: UnboundType) -> str:
s += "[()]"
return s

def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str:
return f"{name!r}: {typ}"

def visit_none_type(self, t: NoneType) -> str:
return "None"

Expand Down Expand Up @@ -322,6 +339,55 @@ def args_str(self, args: Iterable[Type]) -> str:
res.append(arg_str)
return ", ".join(res)

def visit_type_var(self, t: TypeVarType) -> str:
return t.name

def visit_uninhabited_type(self, t: UninhabitedType) -> str:
return self.stubgen.add_name("typing.Any")

def visit_erased_type(self, t: ErasedType) -> str:
return self.stubgen.add_name("typing.Any")

def visit_deleted_type(self, t: DeletedType) -> str:
return self.stubgen.add_name("typing.Any")

def visit_instance(self, t: Instance) -> str:
if t.last_known_value and not t.args:
# Instances with a literal fallback should never be generic. If they are,
# something went wrong so we fall back to showing the full Instance repr.
s = f"{t.last_known_value.accept(self)}"
else:
s = t.type.fullname or t.type.name or self.stubgen.add_name("_typeshed.Incomplete")

s = self.track_imports(s) or s

if t.args:
if t.type.fullname == "builtins.tuple":
assert len(t.args) == 1
s += f"[{self.list_str(t.args)}, ...]"
else:
s += f"[{self.list_str(t.args)}]"
elif t.type.has_type_var_tuple_type and len(t.type.type_vars) == 1:
s += "[()]"

return s

def visit_callable_type(self, t: CallableType) -> str:
from mypy.suggestions import is_tricky_callable

if is_tricky_callable(t):
arg_str = "..."
else:
# Note: for default arguments, we just assume that they
# are required. This isn't right, but neither is the
# other thing, and I suspect this will produce more better
# results than falling back to `...`
args = [typ.accept(self) for typ in t.arg_types]
arg_str = f"[{', '.join(args)}]"

callable = self.stubgen.add_name("typing.Callable")
return f"{callable}[{arg_str}, {t.ret_type.accept(self)}]"


class ClassInfo:
def __init__(
Expand Down Expand Up @@ -454,11 +520,11 @@ class ImportTracker:

def __init__(self) -> None:
# module_for['foo'] has the module name where 'foo' was imported from, or None if
# 'foo' is a module imported directly;
# 'foo' is a module imported directly;
# direct_imports['foo'] is the module path used when the name 'foo' was added to the
# namespace.
# namespace.
# reverse_alias['foo'] is the name that 'foo' had originally when imported with an
# alias; examples
# alias; examples
# 'from pkg import mod' ==> module_for['mod'] == 'pkg'
# 'from pkg import mod as m' ==> module_for['m'] == 'pkg'
# ==> reverse_alias['m'] == 'mod'
Expand Down Expand Up @@ -618,7 +684,9 @@ def __init__(
include_private: bool = False,
export_less: bool = False,
include_docstrings: bool = False,
known_modules: list[str] | None = None,
) -> None:
self.known_modules = known_modules or []
# Best known value of __all__.
self._all_ = _all_
self._include_private = include_private
Expand Down Expand Up @@ -839,7 +907,9 @@ def print_annotation(
known_modules: list[str] | None = None,
local_modules: list[str] | None = None,
) -> str:
printer = AnnotationPrinter(self, known_modules, local_modules)
printer = AnnotationPrinter(
self, known_modules, local_modules or ["builtins", self.module_name]
)
return t.accept(printer)

def is_not_in_all(self, name: str) -> bool:
Expand Down
19 changes: 11 additions & 8 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3462,18 +3462,21 @@ def visit_tuple_type(self, t: TupleType, /) -> str:
return f"{tuple_name}[{s}, fallback={t.partial_fallback.accept(self)}]"
return f"{tuple_name}[{s}]"

def typeddict_item_str(self, t: TypedDictType, name: str, typ: str) -> str:
modifier = ""
if name not in t.required_keys:
modifier += "?"
if name in t.readonly_keys:
modifier += "="
return f"{name!r}{modifier}: {typ}"

def visit_typeddict_type(self, t: TypedDictType, /) -> str:
def item_str(name: str, typ: str) -> str:
modifier = ""
if name not in t.required_keys:
modifier += "?"
if name in t.readonly_keys:
modifier += "="
return f"{name!r}{modifier}: {typ}"

s = (
"{"
+ ", ".join(item_str(name, typ.accept(self)) for name, typ in t.items.items())
+ ", ".join(
self.typeddict_item_str(t, name, typ.accept(self)) for name, typ in t.items.items()
)
+ "}"
)
prefix = ""
Expand Down
Loading