From 95a5233d9eaa4a5545f0e148c68335c830bd7a4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Fri, 8 Jul 2022 17:20:32 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Allow=20configuring=20pretty=20erro?= =?UTF-8?q?rs=20when=20creating=20the=20Typer=20instance=20(#416)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../type_error_no_rich_short_disable.py | 16 +++++ .../assets/type_error_rich_pretty_disable.py | 12 ++++ tests/assets/type_error_rich_short_disable.py | 12 ++++ tests/test_tracebacks.py | 61 +++++++++++++++- typer/main.py | 70 ++++++++++++++----- typer/models.py | 13 ++++ 6 files changed, 166 insertions(+), 18 deletions(-) create mode 100644 tests/assets/type_error_no_rich_short_disable.py create mode 100644 tests/assets/type_error_rich_pretty_disable.py create mode 100644 tests/assets/type_error_rich_short_disable.py diff --git a/tests/assets/type_error_no_rich_short_disable.py b/tests/assets/type_error_no_rich_short_disable.py new file mode 100644 index 0000000..4985fe1 --- /dev/null +++ b/tests/assets/type_error_no_rich_short_disable.py @@ -0,0 +1,16 @@ +import typer +import typer.main + +typer.main.rich = None + + +app = typer.Typer(pretty_errors_short=False) + + +@app.command() +def main(name: str = "morty"): + print(name + 3) + + +if __name__ == "__main__": + app() diff --git a/tests/assets/type_error_rich_pretty_disable.py b/tests/assets/type_error_rich_pretty_disable.py new file mode 100644 index 0000000..dc0f0d6 --- /dev/null +++ b/tests/assets/type_error_rich_pretty_disable.py @@ -0,0 +1,12 @@ +import typer + +app = typer.Typer(pretty_errors_enable=False) + + +@app.command() +def main(name: str = "morty"): + print(name + 3) + + +if __name__ == "__main__": + app() diff --git a/tests/assets/type_error_rich_short_disable.py b/tests/assets/type_error_rich_short_disable.py new file mode 100644 index 0000000..32c233d --- /dev/null +++ b/tests/assets/type_error_rich_short_disable.py @@ -0,0 +1,12 @@ +import typer + +app = typer.Typer(pretty_errors_short=False) + + +@app.command() +def main(name: str = "morty"): + print(name + 3) + + +if __name__ == "__main__": + app() diff --git a/tests/test_tracebacks.py b/tests/test_tracebacks.py index 2f692ae..137639c 100644 --- a/tests/test_tracebacks.py +++ b/tests/test_tracebacks.py @@ -12,7 +12,28 @@ def test_traceback_rich(): ) assert "return get_command(self)(*args, **kwargs)" not in result.stderr - assert "typer.run(main)" in result.stderr + assert "typer.run(main)" not in result.stderr + assert "print(name + 3)" in result.stderr + + # TODO: when deprecating Python 3.6, remove second option + assert ( + 'TypeError: can only concatenate str (not "int") to str' in result.stderr + or "TypeError: must be str, not int" in result.stderr + ) + assert "name = 'morty'" in result.stderr + + +def test_traceback_rich_pretty_short_disable(): + file_path = Path(__file__).parent / "assets/type_error_rich_short_disable.py" + result = subprocess.run( + ["coverage", "run", str(file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert "return get_command(self)(*args, **kwargs)" not in result.stderr + + assert "app()" in result.stderr assert "print(name + 3)" in result.stderr # TODO: when deprecating Python 3.6, remove second option @@ -42,6 +63,25 @@ def test_traceback_no_rich(): ) +def test_traceback_no_rich_short_disable(): + file_path = Path(__file__).parent / "assets/type_error_no_rich_short_disable.py" + result = subprocess.run( + ["coverage", "run", str(file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert "return get_command(self)(*args, **kwargs)" not in result.stderr + + assert "app()" in result.stderr + assert "print(name + 3)" in result.stderr + # TODO: when deprecating Python 3.6, remove second option + assert ( + 'TypeError: can only concatenate str (not "int") to str' in result.stderr + or "TypeError: must be str, not int" in result.stderr + ) + + def test_unmodified_traceback(): file_path = Path(__file__).parent / "assets/type_error_normal_traceback.py" result = subprocess.run( @@ -62,3 +102,22 @@ def test_unmodified_traceback(): 'TypeError: can only concatenate str (not "int") to str' in result.stderr or "TypeError: must be str, not int" in result.stderr ) + + +def test_rich_pretty_errors_disable(): + file_path = Path(__file__).parent / "assets/type_error_rich_pretty_disable.py" + result = subprocess.run( + ["coverage", "run", str(file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + ) + assert "return get_command(self)(*args, **kwargs)" in result.stderr + + assert "app()" in result.stderr + assert "print(name + 3)" in result.stderr + # TODO: when deprecating Python 3.6, remove second option + assert ( + 'TypeError: can only concatenate str (not "int") to str' in result.stderr + or "TypeError: must be str, not int" in result.stderr + ) diff --git a/typer/main.py b/typer/main.py index 1514f7d..d6cc47d 100644 --- a/typer/main.py +++ b/typer/main.py @@ -22,6 +22,7 @@ from .models import ( CommandInfo, Default, DefaultPlaceholder, + DeveloperExceptionConfig, FileBinaryRead, FileBinaryWrite, FileText, @@ -52,7 +53,10 @@ _typer_developer_exception_attr_name = "__typer_developer_exception__" def except_hook( exc_type: Type[BaseException], exc_value: BaseException, tb: TracebackType ) -> None: - if not getattr(exc_value, _typer_developer_exception_attr_name, None): + exception_config: Union[DeveloperExceptionConfig, None] = getattr( + exc_value, _typer_developer_exception_attr_name, None + ) + if not exception_config or not exception_config.pretty_errors_enable: _original_except_hook(exc_type, exc_value, tb) return typer_path = os.path.dirname(__file__) @@ -64,7 +68,7 @@ def except_hook( type(exc), exc, exc.__traceback__, - show_locals=True, + show_locals=exception_config.pretty_errors_show_locals, suppress=supress_internal_dir_names, ) console_stderr.print(rich_tb) @@ -75,15 +79,16 @@ def except_hook( if any( [frame.filename.startswith(path) for path in supress_internal_dir_names] ): - # Hide the line for internal libraries, Typer and Click - stack.append( - traceback.FrameSummary( - filename=frame.filename, - lineno=frame.lineno, - name=frame.name, - line="", + if not exception_config.pretty_errors_short: + # Hide the line for internal libraries, Typer and Click + stack.append( + traceback.FrameSummary( + filename=frame.filename, + lineno=frame.lineno, + name=frame.name, + line="", + ) ) - ) else: stack.append(frame) # Type ignore ref: https://github.com/python/typeshed/pull/8244 @@ -123,8 +128,14 @@ class Typer: hidden: bool = Default(False), deprecated: bool = Default(False), add_completion: bool = True, + pretty_errors_enable: bool = True, + pretty_errors_show_locals: bool = True, + pretty_errors_short: bool = True, ): self._add_completion = add_completion + self.pretty_errors_enable = pretty_errors_enable + self.pretty_errors_show_locals = pretty_errors_show_locals + self.pretty_errors_short = pretty_errors_short self.info = TyperInfo( name=name, cls=cls, @@ -285,12 +296,23 @@ class Typer: # but that means the last error shown is the custom exception, not the # actual error. This trick improves developer experience by showing the # actual error last. - setattr(e, _typer_developer_exception_attr_name, True) + setattr( + e, + _typer_developer_exception_attr_name, + DeveloperExceptionConfig( + pretty_errors_enable=self.pretty_errors_enable, + pretty_errors_show_locals=self.pretty_errors_show_locals, + pretty_errors_short=self.pretty_errors_short, + ), + ) raise e def get_group(typer_instance: Typer) -> click.Command: - group = get_group_from_info(TyperInfo(typer_instance)) + group = get_group_from_info( + TyperInfo(typer_instance), + pretty_errors_short=typer_instance.pretty_errors_short, + ) return group @@ -318,7 +340,9 @@ def get_command(typer_instance: Typer) -> click.Command: ): single_command.context_settings = typer_instance.info.context_settings - click_command = get_command_from_info(single_command) + click_command = get_command_from_info( + single_command, pretty_errors_short=typer_instance.pretty_errors_short + ) if typer_instance._add_completion: click_command.params.append(click_install_param) click_command.params.append(click_show_param) @@ -422,17 +446,23 @@ def solve_typer_info_defaults(typer_info: TyperInfo) -> TyperInfo: return TyperInfo(**values) -def get_group_from_info(group_info: TyperInfo) -> click.Command: +def get_group_from_info( + group_info: TyperInfo, *, pretty_errors_short: bool +) -> click.Command: assert ( group_info.typer_instance ), "A Typer instance is needed to generate a Click Group" commands: Dict[str, click.Command] = {} for command_info in group_info.typer_instance.registered_commands: - command = get_command_from_info(command_info=command_info) + command = get_command_from_info( + command_info=command_info, pretty_errors_short=pretty_errors_short + ) if command.name: commands[command.name] = command for sub_group_info in group_info.typer_instance.registered_groups: - sub_group = get_group_from_info(sub_group_info) + sub_group = get_group_from_info( + sub_group_info, pretty_errors_short=pretty_errors_short + ) if sub_group.name: commands[sub_group.name] = sub_group solved_info = solve_typer_info_defaults(group_info) @@ -456,6 +486,7 @@ def get_group_from_info(group_info: TyperInfo) -> click.Command: params=params, convertors=convertors, context_param_name=context_param_name, + pretty_errors_short=pretty_errors_short, ), params=params, # type: ignore help=solved_info.help, @@ -492,7 +523,9 @@ def get_params_convertors_ctx_param_name_from_function( return params, convertors, context_param_name -def get_command_from_info(command_info: CommandInfo) -> click.Command: +def get_command_from_info( + command_info: CommandInfo, *, pretty_errors_short: bool +) -> click.Command: assert command_info.callback, "A command must have a callback function" name = command_info.name or get_command_name(command_info.callback.__name__) use_help = command_info.help @@ -514,6 +547,7 @@ def get_command_from_info(command_info: CommandInfo) -> click.Command: params=params, convertors=convertors, context_param_name=context_param_name, + pretty_errors_short=pretty_errors_short, ), params=params, # type: ignore help=use_help, @@ -585,6 +619,7 @@ def get_callback( params: Sequence[click.Parameter] = [], convertors: Dict[str, Callable[[str], Any]] = {}, context_param_name: Optional[str] = None, + pretty_errors_short: bool, ) -> Optional[Callable[..., Any]]: if not callback: return None @@ -597,6 +632,7 @@ def get_callback( use_params[param.name] = param.default def wrapper(**kwargs: Any) -> Any: + _rich_traceback_guard = pretty_errors_short # noqa: F841 for k, v in kwargs.items(): if k in convertors: use_params[k] = convertors[k](v) diff --git a/typer/models.py b/typer/models.py index e1de3a4..96b701f 100644 --- a/typer/models.py +++ b/typer/models.py @@ -454,3 +454,16 @@ class ParamMeta: self.name = name self.default = default self.annotation = annotation + + +class DeveloperExceptionConfig: + def __init__( + self, + *, + pretty_errors_enable: bool = True, + pretty_errors_show_locals: bool = True, + pretty_errors_short: bool = True, + ) -> None: + self.pretty_errors_enable = pretty_errors_enable + self.pretty_errors_show_locals = pretty_errors_show_locals + self.pretty_errors_short = pretty_errors_short