Add support for forward references (#93)

*  Add support for forward references

*  Add forward references test

* 🐛 Fix coverage in completion

* 🐛 Fix testing with Pytest and Pytest-sugar

* 📌 Pin Pytest
This commit is contained in:
Sebastián Ramírez 2020-04-26 15:42:13 +02:00 committed by GitHub
parent deda63ab16
commit a93e6b2f86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 79 additions and 17 deletions

View file

@ -38,7 +38,7 @@ Documentation = "https://typer.tiangolo.com/"
[tool.flit.metadata.requires-extra]
test = [
"shellingham",
"pytest >=4.4.0",
"pytest >=4.4.0,< 5.4",
"pytest-cov",
"coverage",
"pytest-xdist",

View file

@ -199,3 +199,26 @@ def test_autocompletion_too_many_parameters():
with pytest.raises(click.ClickException) as exc_info:
runner.invoke(app, ["--name", "Camila"])
assert exc_info.value.message == "Invalid autocompletion callback parameters: val2"
def test_forward_references():
app = typer.Typer()
@app.command()
def main(arg1, arg2: int, arg3: "int", arg4: bool = False, arg5: "bool" = False):
typer.echo(f"arg1: {type(arg1)} {arg1}")
typer.echo(f"arg2: {type(arg2)} {arg2}")
typer.echo(f"arg3: {type(arg3)} {arg3}")
typer.echo(f"arg4: {type(arg4)} {arg4}")
typer.echo(f"arg5: {type(arg5)} {arg5}")
result = runner.invoke(app, ["Hello", "2", "invalid"])
assert (
"Error: Invalid value for 'ARG3': invalid is not a valid integer"
in result.stdout
)
result = runner.invoke(app, ["Hello", "2", "3", "--arg4", "--arg5"])
assert (
"arg1: <class 'str'> Hello\narg2: <class 'int'> 2\narg3: <class 'int'> 3\narg4: <class 'bool'> True\narg5: <class 'bool'> True\n"
in result.stdout
)

View file

@ -1,4 +1,3 @@
import inspect
import os
import re
import subprocess
@ -10,7 +9,9 @@ from typing import Any, Optional, Tuple
import click
import click._bashcomplete
from .models import ParamMeta
from .params import Option
from .utils import get_params_from_function
try:
import shellingham
@ -21,14 +22,16 @@ except ImportError: # pragma: nocover
_click_patched = False
def get_completion_inspect_parameters() -> Tuple[inspect.Parameter, inspect.Parameter]:
def get_completion_inspect_parameters() -> Tuple[ParamMeta, ParamMeta]:
completion_init()
test_disable_detection = os.getenv("_TYPER_COMPLETE_TEST_DISABLE_SHELL_DETECTION")
if shellingham and not test_disable_detection:
signature = inspect.signature(_install_completion_placeholder_function)
parameters = get_params_from_function(_install_completion_placeholder_function)
else:
signature = inspect.signature(_install_completion_no_auto_placeholder_function)
install_param, show_param = signature.parameters.values()
parameters = get_params_from_function(
_install_completion_no_auto_placeholder_function
)
install_param, show_param = parameters.values()
return install_param, show_param
@ -204,7 +207,7 @@ def install_bash(*, prog_name: str, complete_var: str, shell: str) -> Path:
rc_content = rc_path.read_text()
completion_init_lines = [f"source {completion_path}"]
for line in completion_init_lines:
if line not in rc_content:
if line not in rc_content: # pragma: nocover
rc_content += f"\n{line}"
rc_content += "\n"
rc_path.write_text(rc_content)
@ -231,7 +234,7 @@ def install_zsh(*, prog_name: str, complete_var: str, shell: str) -> Path:
"fpath+=~/.zfunc",
]
for line in completion_init_lines:
if line not in zshrc_content:
if line not in zshrc_content: # pragma: nocover
zshrc_content += f"\n{line}"
zshrc_content += "\n"
zshrc_path.write_text(zshrc_content)

View file

@ -23,9 +23,11 @@ from .models import (
NoneType,
OptionInfo,
ParameterInfo,
ParamMeta,
Required,
TyperInfo,
)
from .utils import get_params_from_function
def get_install_completion_arguments() -> Tuple[click.Parameter, click.Parameter]:
@ -393,8 +395,8 @@ def get_params_convertors_ctx_param_name_from_function(
convertors = {}
context_param_name = None
if callback:
signature = inspect.signature(callback)
for param_name, param in signature.parameters.items():
parameters = get_params_from_function(callback)
for param_name, param in parameters.items():
if lenient_issubclass(param.annotation, click.Context):
context_param_name = param_name
continue
@ -476,9 +478,9 @@ def get_callback(
) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
use_params: Dict[str, Any] = {}
for param_name, param_sig in signature.parameters.items():
for param_name in parameters:
use_params[param_name] = None
for param in params:
use_params[param.name] = param.default
@ -591,7 +593,7 @@ def lenient_issubclass(
def get_click_param(
param: inspect.Parameter,
param: ParamMeta,
) -> Tuple[Union[click.Argument, click.Option], Any]:
# First, find out what will be:
# * ParamInfo (ArgumentInfo or OptionInfo)
@ -744,12 +746,12 @@ def get_param_callback(
) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
ctx_name = None
click_param_name = None
value_name = None
untyped_names: List[str] = []
for param_name, param_sig in signature.parameters.items():
for param_name, param_sig in parameters.items():
if lenient_issubclass(param_sig.annotation, click.Context):
ctx_name = param_name
elif lenient_issubclass(param_sig.annotation, click.Parameter):
@ -792,11 +794,11 @@ def get_param_callback(
def get_param_completion(callback: Optional[Callable] = None) -> Optional[Callable]:
if not callback:
return None
signature = inspect.signature(callback)
parameters = get_params_from_function(callback)
ctx_name = None
args_name = None
incomplete_name = None
unassigned_params = [param for param in signature.parameters.values()]
unassigned_params = [param for param in parameters.values()]
for param_sig in unassigned_params[:]:
origin = getattr(param_sig.annotation, "__origin__", None)
if lenient_issubclass(param_sig.annotation, click.Context):

View file

@ -1,3 +1,4 @@
import inspect
import io
from typing import (
TYPE_CHECKING,
@ -388,3 +389,18 @@ class ArgumentInfo(ParameterInfo):
allow_dash=allow_dash,
path_type=path_type,
)
class ParamMeta:
empty = inspect.Parameter.empty
def __init__(
self,
*,
name: str,
default: Any = inspect.Parameter.empty,
annotation: Any = inspect.Parameter.empty,
) -> None:
self.name = name
self.default = default
self.annotation = annotation

18
typer/utils.py Normal file
View file

@ -0,0 +1,18 @@
import inspect
from typing import Callable, Dict, get_type_hints
from .models import ParamMeta
def get_params_from_function(func: Callable) -> Dict[str, ParamMeta]:
signature = inspect.signature(func)
type_hints = get_type_hints(func)
params = {}
for param in signature.parameters.values():
annotation = param.annotation
if param.name in type_hints:
annotation = type_hints[param.name]
params[param.name] = ParamMeta(
name=param.name, default=param.default, annotation=annotation
)
return params