Skip to content
This repository was archived by the owner on Mar 13, 2023. It is now read-only.

feat: Add interaction_tree to Client and Extension #640

Merged
merged 1 commit into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion naff/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
from naff.models.discord.file import UPLOADABLE_TYPE
from naff.models.discord.modal import Modal
from naff.models.naff.active_voice_state import ActiveVoiceState
from naff.models.naff.application_commands import ModalCommand
from naff.models.naff.application_commands import ContextMenu, ModalCommand
from naff.models.naff.auto_defer import AutoDefer
from naff.models.naff.hybrid_commands import _prefixed_from_slash, _base_subcommand_generator
from naff.models.naff.listener import Listener
Expand Down Expand Up @@ -370,6 +370,10 @@ def __init__(
"""A dictionary of registered prefixed commands: `{name: command}`"""
self.interactions: Dict["Snowflake_Type", Dict[str, InteractionCommand]] = {}
"""A dictionary of registered application commands: `{cmd_id: command}`"""
self.interaction_tree: Dict[
"Snowflake_Type", Dict[str, InteractionCommand | Dict[str, InteractionCommand]]
] = {}
"""A dictionary of registered application commands in a tree"""
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._interaction_scopes: Dict["Snowflake_Type", "Snowflake_Type"] = {}
Expand Down Expand Up @@ -1103,6 +1107,8 @@ def add_interaction(self, command: InteractionCommand) -> bool:
if command.callback is None:
return False

base, group, sub, *_ = command.resolved_name.split(" ") + [None, None]

for scope in command.scopes:
if scope not in self.interactions:
self.interactions[scope] = {}
Expand All @@ -1115,6 +1121,21 @@ def add_interaction(self, command: InteractionCommand) -> bool:

self.interactions[scope][command.resolved_name] = command

if scope not in self.interaction_tree:
self.interaction_tree[scope] = {}

if group is None or isinstance(command, ContextMenu):
self.interaction_tree[scope][command.resolved_name] = command
elif group is not None:
if base not in self.interaction_tree[scope]:
self.interaction_tree[scope][base] = {}
if sub is None:
self.interaction_tree[scope][base][group] = command
else:
if group not in self.interaction_tree[scope][base]:
self.interaction_tree[scope][base][group] = {}
self.interaction_tree[scope][base][group][sub] = command

return True

def add_hybrid_command(self, command: HybridCommand) -> bool:
Expand Down
27 changes: 24 additions & 3 deletions naff/models/naff/extension.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import asyncio
import inspect
from typing import Awaitable, List, TYPE_CHECKING, Callable, Coroutine, Optional
from typing import Awaitable, Dict, List, TYPE_CHECKING, Callable, Coroutine, Optional

import naff.models.naff as naff
from naff.client.const import logger, MISSING
from naff.client.utils.misc_utils import wrap_partial
from naff.models.naff.tasks import Task
from naff.models.naff import ContextMenu

if TYPE_CHECKING:
from naff.client import Client
from naff.models.naff import AutoDefer, BaseCommand, Listener
from naff.models.discord import Snowflake_Type
from naff.models.naff import AutoDefer, BaseCommand, InteractionCommand, Listener
from naff.models.naff import Context


Expand Down Expand Up @@ -38,6 +40,7 @@ async def some_command(self, context):
extension_checks str: A list of checks to be ran on any command in this extension
extension_prerun List: A list of coroutines to be run before any command in this extension
extension_postrun List: A list of coroutines to be run after any command in this extension
interaction_tree Dict: A dictionary of registered application commands in a tree

"""

Expand All @@ -49,6 +52,7 @@ async def some_command(self, context):
extension_prerun: List
extension_postrun: List
extension_error: Optional[Callable[..., Coroutine]]
interaction_tree: Dict["Snowflake_Type", Dict[str, "InteractionCommand" | Dict[str, "InteractionCommand"]]]
_commands: List
_listeners: List
auto_defer: "AutoDefer"
Expand All @@ -61,6 +65,7 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
new_cls.extension_prerun = []
new_cls.extension_postrun = []
new_cls.extension_error = None
new_cls.interaction_tree = {}
new_cls.auto_defer = MISSING

new_cls.description = kwargs.get("Description", None)
Expand Down Expand Up @@ -89,7 +94,23 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension":
elif isinstance(val, naff.HybridCommand):
bot.add_hybrid_command(val)
elif isinstance(val, naff.InteractionCommand):
bot.add_interaction(val)
if not bot.add_interaction(val):
continue
base, group, sub, *_ = val.resolved_name.split(" ") + [None, None]
for scope in val.scopes:
if scope not in new_cls.interaction_tree:
new_cls.interaction_tree[scope] = {}
if group is None or isinstance(val, ContextMenu):
new_cls.interaction_tree[scope][val.resolved_name] = val
elif group is not None:
if base not in new_cls.interaction_tree[scope]:
new_cls.interaction_tree[scope][base] = {}
if sub is None:
new_cls.interaction_tree[scope][base][group] = val
else:
if group not in new_cls.interaction_tree[scope][base]:
new_cls.interaction_tree[scope][base][group] = {}
new_cls.interaction_tree[scope][base][group][sub] = val
else:
bot.add_prefixed_command(val)

Expand Down