diff --git a/typer/completion.py b/typer/completion.py new file mode 100644 index 0000000..4823137 --- /dev/null +++ b/typer/completion.py @@ -0,0 +1,46 @@ +from typing import Any + +import click +import click_completion +import click_completion.core + +from .params import Option + +click_completion.init() + + +def install_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any: + if not value or ctx.resilient_parsing: + return value + shell, path = click_completion.core.install() + click.echo(f"{shell} completion installed in {path}") + exit(0) + + +def show_callback(ctx: click.Context, param: click.Parameter, value: Any) -> Any: + if not value or ctx.resilient_parsing: + return value + click.echo(click_completion.core.get_code()) + exit(0) + + +# Create a fake command function to extract the completion parameters +def _install_completion_placeholder_function( + install_completion: bool = Option( + None, + "--install-completion", + is_flag=True, + callback=install_callback, + expose_value=False, + help="Install completion for the current shell.", + ), + show_completion: bool = Option( + None, + "--show-completion", + is_flag=True, + callback=show_callback, + expose_value=False, + help="Show completion for the current shell, to copy it or customize the installation.", + ), +) -> Any: + pass diff --git a/typer/main.py b/typer/main.py index 4d970f7..8ae0efa 100644 --- a/typer/main.py +++ b/typer/main.py @@ -25,6 +25,20 @@ from .models import ( TyperInfo, ) +try: + import click_completion + from .completion import _install_completion_placeholder_function +except ImportError: + click_completion = None + + +def get_install_completion_arguments() -> Tuple[click.Parameter, click.Parameter]: + signature = inspect.signature(_install_completion_placeholder_function) + install_param, show_param = signature.parameters.values() + click_install_param, _ = get_click_param(install_param) + click_show_param, _ = get_click_param(show_param) + return click_install_param, click_show_param + class Typer: def __init__( @@ -47,7 +61,9 @@ class Typer: add_help_option: bool = Default(True), hidden: bool = Default(False), deprecated: bool = Default(False), + add_completion: bool = True, ): + self._add_completion = add_completion self.info = TyperInfo( name=name, cls=cls, @@ -204,14 +220,27 @@ def get_group(typer_instance: Typer) -> click.Command: def get_command(typer_instance: Typer) -> click.Command: + if typer_instance._add_completion and click_completion: + click_completion.init() + click_install_param, click_show_param = get_install_completion_arguments() if ( typer_instance.registered_callback or typer_instance.registered_groups or len(typer_instance.registered_commands) > 1 ): - return get_group(typer_instance) + # Create a Group + click_command = get_group(typer_instance) + if typer_instance._add_completion and click_completion: + click_command.params.append(click_install_param) + click_command.params.append(click_show_param) + return click_command elif len(typer_instance.registered_commands) == 1: - return get_command_from_info(typer_instance.registered_commands[0]) + # Create a single Command + click_command = get_command_from_info(typer_instance.registered_commands[0]) + if typer_instance._add_completion and click_completion: + click_command.params.append(click_install_param) + click_command.params.append(click_show_param) + return click_command assert False, "Could not get a command for this Typer instance" @@ -419,7 +448,7 @@ def get_click_type( min_ = int(parameter_info.min) if parameter_info.max is not None: max_ = int(parameter_info.max) - return click.IntRange(min=min_, max=max_, clamp=parameter_info.clamp,) + return click.IntRange(min=min_, max=max_, clamp=parameter_info.clamp) else: return click.INT elif annotation == float: