✨ Allow configuring pretty errors when creating the Typer instance (#416)
This commit is contained in:
parent
850060776c
commit
95a5233d9e
6 changed files with 166 additions and 18 deletions
16
tests/assets/type_error_no_rich_short_disable.py
Normal file
16
tests/assets/type_error_no_rich_short_disable.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import typer
|
||||
import typer.main
|
||||
|
||||
typer.main.rich = None
|
||||
|
||||
|
||||
app = typer.Typer(pretty_errors_short=False)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(name: str = "morty"):
|
||||
print(name + 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
12
tests/assets/type_error_rich_pretty_disable.py
Normal file
12
tests/assets/type_error_rich_pretty_disable.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import typer
|
||||
|
||||
app = typer.Typer(pretty_errors_enable=False)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(name: str = "morty"):
|
||||
print(name + 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
12
tests/assets/type_error_rich_short_disable.py
Normal file
12
tests/assets/type_error_rich_short_disable.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
import typer
|
||||
|
||||
app = typer.Typer(pretty_errors_short=False)
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(name: str = "morty"):
|
||||
print(name + 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
|
@ -12,7 +12,28 @@ def test_traceback_rich():
|
|||
)
|
||||
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
|
||||
|
||||
assert "typer.run(main)" in result.stderr
|
||||
assert "typer.run(main)" not in result.stderr
|
||||
assert "print(name + 3)" in result.stderr
|
||||
|
||||
# TODO: when deprecating Python 3.6, remove second option
|
||||
assert (
|
||||
'TypeError: can only concatenate str (not "int") to str' in result.stderr
|
||||
or "TypeError: must be str, not int" in result.stderr
|
||||
)
|
||||
assert "name = 'morty'" in result.stderr
|
||||
|
||||
|
||||
def test_traceback_rich_pretty_short_disable():
|
||||
file_path = Path(__file__).parent / "assets/type_error_rich_short_disable.py"
|
||||
result = subprocess.run(
|
||||
["coverage", "run", str(file_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
)
|
||||
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
|
||||
|
||||
assert "app()" in result.stderr
|
||||
assert "print(name + 3)" in result.stderr
|
||||
|
||||
# TODO: when deprecating Python 3.6, remove second option
|
||||
|
@ -42,6 +63,25 @@ def test_traceback_no_rich():
|
|||
)
|
||||
|
||||
|
||||
def test_traceback_no_rich_short_disable():
|
||||
file_path = Path(__file__).parent / "assets/type_error_no_rich_short_disable.py"
|
||||
result = subprocess.run(
|
||||
["coverage", "run", str(file_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
)
|
||||
assert "return get_command(self)(*args, **kwargs)" not in result.stderr
|
||||
|
||||
assert "app()" in result.stderr
|
||||
assert "print(name + 3)" in result.stderr
|
||||
# TODO: when deprecating Python 3.6, remove second option
|
||||
assert (
|
||||
'TypeError: can only concatenate str (not "int") to str' in result.stderr
|
||||
or "TypeError: must be str, not int" in result.stderr
|
||||
)
|
||||
|
||||
|
||||
def test_unmodified_traceback():
|
||||
file_path = Path(__file__).parent / "assets/type_error_normal_traceback.py"
|
||||
result = subprocess.run(
|
||||
|
@ -62,3 +102,22 @@ def test_unmodified_traceback():
|
|||
'TypeError: can only concatenate str (not "int") to str' in result.stderr
|
||||
or "TypeError: must be str, not int" in result.stderr
|
||||
)
|
||||
|
||||
|
||||
def test_rich_pretty_errors_disable():
|
||||
file_path = Path(__file__).parent / "assets/type_error_rich_pretty_disable.py"
|
||||
result = subprocess.run(
|
||||
["coverage", "run", str(file_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding="utf-8",
|
||||
)
|
||||
assert "return get_command(self)(*args, **kwargs)" in result.stderr
|
||||
|
||||
assert "app()" in result.stderr
|
||||
assert "print(name + 3)" in result.stderr
|
||||
# TODO: when deprecating Python 3.6, remove second option
|
||||
assert (
|
||||
'TypeError: can only concatenate str (not "int") to str' in result.stderr
|
||||
or "TypeError: must be str, not int" in result.stderr
|
||||
)
|
||||
|
|
|
@ -22,6 +22,7 @@ from .models import (
|
|||
CommandInfo,
|
||||
Default,
|
||||
DefaultPlaceholder,
|
||||
DeveloperExceptionConfig,
|
||||
FileBinaryRead,
|
||||
FileBinaryWrite,
|
||||
FileText,
|
||||
|
@ -52,7 +53,10 @@ _typer_developer_exception_attr_name = "__typer_developer_exception__"
|
|||
def except_hook(
|
||||
exc_type: Type[BaseException], exc_value: BaseException, tb: TracebackType
|
||||
) -> None:
|
||||
if not getattr(exc_value, _typer_developer_exception_attr_name, None):
|
||||
exception_config: Union[DeveloperExceptionConfig, None] = getattr(
|
||||
exc_value, _typer_developer_exception_attr_name, None
|
||||
)
|
||||
if not exception_config or not exception_config.pretty_errors_enable:
|
||||
_original_except_hook(exc_type, exc_value, tb)
|
||||
return
|
||||
typer_path = os.path.dirname(__file__)
|
||||
|
@ -64,7 +68,7 @@ def except_hook(
|
|||
type(exc),
|
||||
exc,
|
||||
exc.__traceback__,
|
||||
show_locals=True,
|
||||
show_locals=exception_config.pretty_errors_show_locals,
|
||||
suppress=supress_internal_dir_names,
|
||||
)
|
||||
console_stderr.print(rich_tb)
|
||||
|
@ -75,15 +79,16 @@ def except_hook(
|
|||
if any(
|
||||
[frame.filename.startswith(path) for path in supress_internal_dir_names]
|
||||
):
|
||||
# Hide the line for internal libraries, Typer and Click
|
||||
stack.append(
|
||||
traceback.FrameSummary(
|
||||
filename=frame.filename,
|
||||
lineno=frame.lineno,
|
||||
name=frame.name,
|
||||
line="",
|
||||
if not exception_config.pretty_errors_short:
|
||||
# Hide the line for internal libraries, Typer and Click
|
||||
stack.append(
|
||||
traceback.FrameSummary(
|
||||
filename=frame.filename,
|
||||
lineno=frame.lineno,
|
||||
name=frame.name,
|
||||
line="",
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
stack.append(frame)
|
||||
# Type ignore ref: https://github.com/python/typeshed/pull/8244
|
||||
|
@ -123,8 +128,14 @@ class Typer:
|
|||
hidden: bool = Default(False),
|
||||
deprecated: bool = Default(False),
|
||||
add_completion: bool = True,
|
||||
pretty_errors_enable: bool = True,
|
||||
pretty_errors_show_locals: bool = True,
|
||||
pretty_errors_short: bool = True,
|
||||
):
|
||||
self._add_completion = add_completion
|
||||
self.pretty_errors_enable = pretty_errors_enable
|
||||
self.pretty_errors_show_locals = pretty_errors_show_locals
|
||||
self.pretty_errors_short = pretty_errors_short
|
||||
self.info = TyperInfo(
|
||||
name=name,
|
||||
cls=cls,
|
||||
|
@ -285,12 +296,23 @@ class Typer:
|
|||
# but that means the last error shown is the custom exception, not the
|
||||
# actual error. This trick improves developer experience by showing the
|
||||
# actual error last.
|
||||
setattr(e, _typer_developer_exception_attr_name, True)
|
||||
setattr(
|
||||
e,
|
||||
_typer_developer_exception_attr_name,
|
||||
DeveloperExceptionConfig(
|
||||
pretty_errors_enable=self.pretty_errors_enable,
|
||||
pretty_errors_show_locals=self.pretty_errors_show_locals,
|
||||
pretty_errors_short=self.pretty_errors_short,
|
||||
),
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def get_group(typer_instance: Typer) -> click.Command:
|
||||
group = get_group_from_info(TyperInfo(typer_instance))
|
||||
group = get_group_from_info(
|
||||
TyperInfo(typer_instance),
|
||||
pretty_errors_short=typer_instance.pretty_errors_short,
|
||||
)
|
||||
return group
|
||||
|
||||
|
||||
|
@ -318,7 +340,9 @@ def get_command(typer_instance: Typer) -> click.Command:
|
|||
):
|
||||
single_command.context_settings = typer_instance.info.context_settings
|
||||
|
||||
click_command = get_command_from_info(single_command)
|
||||
click_command = get_command_from_info(
|
||||
single_command, pretty_errors_short=typer_instance.pretty_errors_short
|
||||
)
|
||||
if typer_instance._add_completion:
|
||||
click_command.params.append(click_install_param)
|
||||
click_command.params.append(click_show_param)
|
||||
|
@ -422,17 +446,23 @@ def solve_typer_info_defaults(typer_info: TyperInfo) -> TyperInfo:
|
|||
return TyperInfo(**values)
|
||||
|
||||
|
||||
def get_group_from_info(group_info: TyperInfo) -> click.Command:
|
||||
def get_group_from_info(
|
||||
group_info: TyperInfo, *, pretty_errors_short: bool
|
||||
) -> click.Command:
|
||||
assert (
|
||||
group_info.typer_instance
|
||||
), "A Typer instance is needed to generate a Click Group"
|
||||
commands: Dict[str, click.Command] = {}
|
||||
for command_info in group_info.typer_instance.registered_commands:
|
||||
command = get_command_from_info(command_info=command_info)
|
||||
command = get_command_from_info(
|
||||
command_info=command_info, pretty_errors_short=pretty_errors_short
|
||||
)
|
||||
if command.name:
|
||||
commands[command.name] = command
|
||||
for sub_group_info in group_info.typer_instance.registered_groups:
|
||||
sub_group = get_group_from_info(sub_group_info)
|
||||
sub_group = get_group_from_info(
|
||||
sub_group_info, pretty_errors_short=pretty_errors_short
|
||||
)
|
||||
if sub_group.name:
|
||||
commands[sub_group.name] = sub_group
|
||||
solved_info = solve_typer_info_defaults(group_info)
|
||||
|
@ -456,6 +486,7 @@ def get_group_from_info(group_info: TyperInfo) -> click.Command:
|
|||
params=params,
|
||||
convertors=convertors,
|
||||
context_param_name=context_param_name,
|
||||
pretty_errors_short=pretty_errors_short,
|
||||
),
|
||||
params=params, # type: ignore
|
||||
help=solved_info.help,
|
||||
|
@ -492,7 +523,9 @@ def get_params_convertors_ctx_param_name_from_function(
|
|||
return params, convertors, context_param_name
|
||||
|
||||
|
||||
def get_command_from_info(command_info: CommandInfo) -> click.Command:
|
||||
def get_command_from_info(
|
||||
command_info: CommandInfo, *, pretty_errors_short: bool
|
||||
) -> click.Command:
|
||||
assert command_info.callback, "A command must have a callback function"
|
||||
name = command_info.name or get_command_name(command_info.callback.__name__)
|
||||
use_help = command_info.help
|
||||
|
@ -514,6 +547,7 @@ def get_command_from_info(command_info: CommandInfo) -> click.Command:
|
|||
params=params,
|
||||
convertors=convertors,
|
||||
context_param_name=context_param_name,
|
||||
pretty_errors_short=pretty_errors_short,
|
||||
),
|
||||
params=params, # type: ignore
|
||||
help=use_help,
|
||||
|
@ -585,6 +619,7 @@ def get_callback(
|
|||
params: Sequence[click.Parameter] = [],
|
||||
convertors: Dict[str, Callable[[str], Any]] = {},
|
||||
context_param_name: Optional[str] = None,
|
||||
pretty_errors_short: bool,
|
||||
) -> Optional[Callable[..., Any]]:
|
||||
if not callback:
|
||||
return None
|
||||
|
@ -597,6 +632,7 @@ def get_callback(
|
|||
use_params[param.name] = param.default
|
||||
|
||||
def wrapper(**kwargs: Any) -> Any:
|
||||
_rich_traceback_guard = pretty_errors_short # noqa: F841
|
||||
for k, v in kwargs.items():
|
||||
if k in convertors:
|
||||
use_params[k] = convertors[k](v)
|
||||
|
|
|
@ -454,3 +454,16 @@ class ParamMeta:
|
|||
self.name = name
|
||||
self.default = default
|
||||
self.annotation = annotation
|
||||
|
||||
|
||||
class DeveloperExceptionConfig:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pretty_errors_enable: bool = True,
|
||||
pretty_errors_show_locals: bool = True,
|
||||
pretty_errors_short: bool = True,
|
||||
) -> None:
|
||||
self.pretty_errors_enable = pretty_errors_enable
|
||||
self.pretty_errors_show_locals = pretty_errors_show_locals
|
||||
self.pretty_errors_short = pretty_errors_short
|
||||
|
|
Loading…
Reference in a new issue