Allow configuring pretty errors when creating the Typer instance (#416)

This commit is contained in:
Sebastián Ramírez 2022-07-08 17:20:32 +02:00 committed by GitHub
parent 850060776c
commit 95a5233d9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 18 deletions

View file

@ -0,0 +1,16 @@
import typer
import typer.main
typer.main.rich = None
app = typer.Typer(pretty_errors_short=False)
@app.command()
def main(name: str = "morty"):
print(name + 3)
if __name__ == "__main__":
app()

View file

@ -0,0 +1,12 @@
import typer
app = typer.Typer(pretty_errors_enable=False)
@app.command()
def main(name: str = "morty"):
print(name + 3)
if __name__ == "__main__":
app()

View file

@ -0,0 +1,12 @@
import typer
app = typer.Typer(pretty_errors_short=False)
@app.command()
def main(name: str = "morty"):
print(name + 3)
if __name__ == "__main__":
app()

View file

@ -12,7 +12,28 @@ def test_traceback_rich():
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
assert "typer.run(main)" in result.stderr
assert "typer.run(main)" not in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)
assert "name = 'morty'" in result.stderr
def test_traceback_rich_pretty_short_disable():
file_path = Path(__file__).parent / "assets/type_error_rich_short_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
@ -42,6 +63,25 @@ def test_traceback_no_rich():
)
def test_traceback_no_rich_short_disable():
file_path = Path(__file__).parent / "assets/type_error_no_rich_short_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)
def test_unmodified_traceback():
file_path = Path(__file__).parent / "assets/type_error_normal_traceback.py"
result = subprocess.run(
@ -62,3 +102,22 @@ def test_unmodified_traceback():
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)
def test_rich_pretty_errors_disable():
file_path = Path(__file__).parent / "assets/type_error_rich_pretty_disable.py"
result = subprocess.run(
["coverage", "run", str(file_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
assert "return get_command(self)(*args, **kwargs)" in result.stderr
assert "app()" in result.stderr
assert "print(name + 3)" in result.stderr
# TODO: when deprecating Python 3.6, remove second option
assert (
'TypeError: can only concatenate str (not "int") to str' in result.stderr
or "TypeError: must be str, not int" in result.stderr
)

View file

@ -22,6 +22,7 @@ from .models import (
CommandInfo,
Default,
DefaultPlaceholder,
DeveloperExceptionConfig,
FileBinaryRead,
FileBinaryWrite,
FileText,
@ -52,7 +53,10 @@ _typer_developer_exception_attr_name = "__typer_developer_exception__"
def except_hook(
exc_type: Type[BaseException], exc_value: BaseException, tb: TracebackType
) -> None:
if not getattr(exc_value, _typer_developer_exception_attr_name, None):
exception_config: Union[DeveloperExceptionConfig, None] = getattr(
exc_value, _typer_developer_exception_attr_name, None
)
if not exception_config or not exception_config.pretty_errors_enable:
_original_except_hook(exc_type, exc_value, tb)
return
typer_path = os.path.dirname(__file__)
@ -64,7 +68,7 @@ def except_hook(
type(exc),
exc,
exc.__traceback__,
show_locals=True,
show_locals=exception_config.pretty_errors_show_locals,
suppress=supress_internal_dir_names,
)
console_stderr.print(rich_tb)
@ -75,15 +79,16 @@ def except_hook(
if any(
[frame.filename.startswith(path) for path in supress_internal_dir_names]
):
# Hide the line for internal libraries, Typer and Click
stack.append(
traceback.FrameSummary(
filename=frame.filename,
lineno=frame.lineno,
name=frame.name,
line="",
if not exception_config.pretty_errors_short:
# Hide the line for internal libraries, Typer and Click
stack.append(
traceback.FrameSummary(
filename=frame.filename,
lineno=frame.lineno,
name=frame.name,
line="",
)
)
)
else:
stack.append(frame)
# Type ignore ref: https://github.com/python/typeshed/pull/8244
@ -123,8 +128,14 @@ class Typer:
hidden: bool = Default(False),
deprecated: bool = Default(False),
add_completion: bool = True,
pretty_errors_enable: bool = True,
pretty_errors_show_locals: bool = True,
pretty_errors_short: bool = True,
):
self._add_completion = add_completion
self.pretty_errors_enable = pretty_errors_enable
self.pretty_errors_show_locals = pretty_errors_show_locals
self.pretty_errors_short = pretty_errors_short
self.info = TyperInfo(
name=name,
cls=cls,
@ -285,12 +296,23 @@ class Typer:
# but that means the last error shown is the custom exception, not the
# actual error. This trick improves developer experience by showing the
# actual error last.
setattr(e, _typer_developer_exception_attr_name, True)
setattr(
e,
_typer_developer_exception_attr_name,
DeveloperExceptionConfig(
pretty_errors_enable=self.pretty_errors_enable,
pretty_errors_show_locals=self.pretty_errors_show_locals,
pretty_errors_short=self.pretty_errors_short,
),
)
raise e
def get_group(typer_instance: Typer) -> click.Command:
group = get_group_from_info(TyperInfo(typer_instance))
group = get_group_from_info(
TyperInfo(typer_instance),
pretty_errors_short=typer_instance.pretty_errors_short,
)
return group
@ -318,7 +340,9 @@ def get_command(typer_instance: Typer) -> click.Command:
):
single_command.context_settings = typer_instance.info.context_settings
click_command = get_command_from_info(single_command)
click_command = get_command_from_info(
single_command, pretty_errors_short=typer_instance.pretty_errors_short
)
if typer_instance._add_completion:
click_command.params.append(click_install_param)
click_command.params.append(click_show_param)
@ -422,17 +446,23 @@ def solve_typer_info_defaults(typer_info: TyperInfo) -> TyperInfo:
return TyperInfo(**values)
def get_group_from_info(group_info: TyperInfo) -> click.Command:
def get_group_from_info(
group_info: TyperInfo, *, pretty_errors_short: bool
) -> click.Command:
assert (
group_info.typer_instance
), "A Typer instance is needed to generate a Click Group"
commands: Dict[str, click.Command] = {}
for command_info in group_info.typer_instance.registered_commands:
command = get_command_from_info(command_info=command_info)
command = get_command_from_info(
command_info=command_info, pretty_errors_short=pretty_errors_short
)
if command.name:
commands[command.name] = command
for sub_group_info in group_info.typer_instance.registered_groups:
sub_group = get_group_from_info(sub_group_info)
sub_group = get_group_from_info(
sub_group_info, pretty_errors_short=pretty_errors_short
)
if sub_group.name:
commands[sub_group.name] = sub_group
solved_info = solve_typer_info_defaults(group_info)
@ -456,6 +486,7 @@ def get_group_from_info(group_info: TyperInfo) -> click.Command:
params=params,
convertors=convertors,
context_param_name=context_param_name,
pretty_errors_short=pretty_errors_short,
),
params=params, # type: ignore
help=solved_info.help,
@ -492,7 +523,9 @@ def get_params_convertors_ctx_param_name_from_function(
return params, convertors, context_param_name
def get_command_from_info(command_info: CommandInfo) -> click.Command:
def get_command_from_info(
command_info: CommandInfo, *, pretty_errors_short: bool
) -> click.Command:
assert command_info.callback, "A command must have a callback function"
name = command_info.name or get_command_name(command_info.callback.__name__)
use_help = command_info.help
@ -514,6 +547,7 @@ def get_command_from_info(command_info: CommandInfo) -> click.Command:
params=params,
convertors=convertors,
context_param_name=context_param_name,
pretty_errors_short=pretty_errors_short,
),
params=params, # type: ignore
help=use_help,
@ -585,6 +619,7 @@ def get_callback(
params: Sequence[click.Parameter] = [],
convertors: Dict[str, Callable[[str], Any]] = {},
context_param_name: Optional[str] = None,
pretty_errors_short: bool,
) -> Optional[Callable[..., Any]]:
if not callback:
return None
@ -597,6 +632,7 @@ def get_callback(
use_params[param.name] = param.default
def wrapper(**kwargs: Any) -> Any:
_rich_traceback_guard = pretty_errors_short # noqa: F841
for k, v in kwargs.items():
if k in convertors:
use_params[k] = convertors[k](v)

View file

@ -454,3 +454,16 @@ class ParamMeta:
self.name = name
self.default = default
self.annotation = annotation
class DeveloperExceptionConfig:
def __init__(
self,
*,
pretty_errors_enable: bool = True,
pretty_errors_show_locals: bool = True,
pretty_errors_short: bool = True,
) -> None:
self.pretty_errors_enable = pretty_errors_enable
self.pretty_errors_show_locals = pretty_errors_show_locals
self.pretty_errors_short = pretty_errors_short