diff --git a/pyproject.toml b/pyproject.toml index ec70681..0811f40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ Documentation = "https://typer.tiangolo.com/" [tool.flit.metadata.requires-extra] test = [ "shellingham", - "pytest >=4.4.0", + "pytest >=4.4.0,< 5.4", "pytest-cov", "coverage", "pytest-xdist", diff --git a/tests/test_others.py b/tests/test_others.py index 93f3e81..897941a 100644 --- a/tests/test_others.py +++ b/tests/test_others.py @@ -199,3 +199,26 @@ def test_autocompletion_too_many_parameters(): with pytest.raises(click.ClickException) as exc_info: runner.invoke(app, ["--name", "Camila"]) assert exc_info.value.message == "Invalid autocompletion callback parameters: val2" + + +def test_forward_references(): + app = typer.Typer() + + @app.command() + def main(arg1, arg2: int, arg3: "int", arg4: bool = False, arg5: "bool" = False): + typer.echo(f"arg1: {type(arg1)} {arg1}") + typer.echo(f"arg2: {type(arg2)} {arg2}") + typer.echo(f"arg3: {type(arg3)} {arg3}") + typer.echo(f"arg4: {type(arg4)} {arg4}") + typer.echo(f"arg5: {type(arg5)} {arg5}") + + result = runner.invoke(app, ["Hello", "2", "invalid"]) + assert ( + "Error: Invalid value for 'ARG3': invalid is not a valid integer" + in result.stdout + ) + result = runner.invoke(app, ["Hello", "2", "3", "--arg4", "--arg5"]) + assert ( + "arg1: Hello\narg2: 2\narg3: 3\narg4: True\narg5: True\n" + in result.stdout + ) diff --git a/typer/completion.py b/typer/completion.py index 43f1e15..4a0b012 100644 --- a/typer/completion.py +++ b/typer/completion.py @@ -1,4 +1,3 @@ -import inspect import os import re import subprocess @@ -10,7 +9,9 @@ from typing import Any, Optional, Tuple import click import click._bashcomplete +from .models import ParamMeta from .params import Option +from .utils import get_params_from_function try: import shellingham @@ -21,14 +22,16 @@ except ImportError: # pragma: nocover _click_patched = False -def get_completion_inspect_parameters() -> Tuple[inspect.Parameter, inspect.Parameter]: +def get_completion_inspect_parameters() -> Tuple[ParamMeta, ParamMeta]: completion_init() test_disable_detection = os.getenv("_TYPER_COMPLETE_TEST_DISABLE_SHELL_DETECTION") if shellingham and not test_disable_detection: - signature = inspect.signature(_install_completion_placeholder_function) + parameters = get_params_from_function(_install_completion_placeholder_function) else: - signature = inspect.signature(_install_completion_no_auto_placeholder_function) - install_param, show_param = signature.parameters.values() + parameters = get_params_from_function( + _install_completion_no_auto_placeholder_function + ) + install_param, show_param = parameters.values() return install_param, show_param @@ -204,7 +207,7 @@ def install_bash(*, prog_name: str, complete_var: str, shell: str) -> Path: rc_content = rc_path.read_text() completion_init_lines = [f"source {completion_path}"] for line in completion_init_lines: - if line not in rc_content: + if line not in rc_content: # pragma: nocover rc_content += f"\n{line}" rc_content += "\n" rc_path.write_text(rc_content) @@ -231,7 +234,7 @@ def install_zsh(*, prog_name: str, complete_var: str, shell: str) -> Path: "fpath+=~/.zfunc", ] for line in completion_init_lines: - if line not in zshrc_content: + if line not in zshrc_content: # pragma: nocover zshrc_content += f"\n{line}" zshrc_content += "\n" zshrc_path.write_text(zshrc_content) diff --git a/typer/main.py b/typer/main.py index 807ae8e..5e5c90d 100644 --- a/typer/main.py +++ b/typer/main.py @@ -23,9 +23,11 @@ from .models import ( NoneType, OptionInfo, ParameterInfo, + ParamMeta, Required, TyperInfo, ) +from .utils import get_params_from_function def get_install_completion_arguments() -> Tuple[click.Parameter, click.Parameter]: @@ -393,8 +395,8 @@ def get_params_convertors_ctx_param_name_from_function( convertors = {} context_param_name = None if callback: - signature = inspect.signature(callback) - for param_name, param in signature.parameters.items(): + parameters = get_params_from_function(callback) + for param_name, param in parameters.items(): if lenient_issubclass(param.annotation, click.Context): context_param_name = param_name continue @@ -476,9 +478,9 @@ def get_callback( ) -> Optional[Callable]: if not callback: return None - signature = inspect.signature(callback) + parameters = get_params_from_function(callback) use_params: Dict[str, Any] = {} - for param_name, param_sig in signature.parameters.items(): + for param_name in parameters: use_params[param_name] = None for param in params: use_params[param.name] = param.default @@ -591,7 +593,7 @@ def lenient_issubclass( def get_click_param( - param: inspect.Parameter, + param: ParamMeta, ) -> Tuple[Union[click.Argument, click.Option], Any]: # First, find out what will be: # * ParamInfo (ArgumentInfo or OptionInfo) @@ -744,12 +746,12 @@ def get_param_callback( ) -> Optional[Callable]: if not callback: return None - signature = inspect.signature(callback) + parameters = get_params_from_function(callback) ctx_name = None click_param_name = None value_name = None untyped_names: List[str] = [] - for param_name, param_sig in signature.parameters.items(): + for param_name, param_sig in parameters.items(): if lenient_issubclass(param_sig.annotation, click.Context): ctx_name = param_name elif lenient_issubclass(param_sig.annotation, click.Parameter): @@ -792,11 +794,11 @@ def get_param_callback( def get_param_completion(callback: Optional[Callable] = None) -> Optional[Callable]: if not callback: return None - signature = inspect.signature(callback) + parameters = get_params_from_function(callback) ctx_name = None args_name = None incomplete_name = None - unassigned_params = [param for param in signature.parameters.values()] + unassigned_params = [param for param in parameters.values()] for param_sig in unassigned_params[:]: origin = getattr(param_sig.annotation, "__origin__", None) if lenient_issubclass(param_sig.annotation, click.Context): diff --git a/typer/models.py b/typer/models.py index 824979e..a705b71 100644 --- a/typer/models.py +++ b/typer/models.py @@ -1,3 +1,4 @@ +import inspect import io from typing import ( TYPE_CHECKING, @@ -388,3 +389,18 @@ class ArgumentInfo(ParameterInfo): allow_dash=allow_dash, path_type=path_type, ) + + +class ParamMeta: + empty = inspect.Parameter.empty + + def __init__( + self, + *, + name: str, + default: Any = inspect.Parameter.empty, + annotation: Any = inspect.Parameter.empty, + ) -> None: + self.name = name + self.default = default + self.annotation = annotation diff --git a/typer/utils.py b/typer/utils.py new file mode 100644 index 0000000..d5d3b1e --- /dev/null +++ b/typer/utils.py @@ -0,0 +1,18 @@ +import inspect +from typing import Callable, Dict, get_type_hints + +from .models import ParamMeta + + +def get_params_from_function(func: Callable) -> Dict[str, ParamMeta]: + signature = inspect.signature(func) + type_hints = get_type_hints(func) + params = {} + for param in signature.parameters.values(): + annotation = param.annotation + if param.name in type_hints: + annotation = type_hints[param.name] + params[param.name] = ParamMeta( + name=param.name, default=param.default, annotation=annotation + ) + return params