Skip to content

Commit 8b54470

Browse files
tushar-composioDevanshusisodiyaangrybaybladeellipsis-dev[bot]
authored
fix: Validate actions available to an entity in check_connected_account (#1095)
Co-authored-by: Devanshusisodiya <[email protected]> Co-authored-by: Viraj <[email protected]> Co-authored-by: Devanshu Kumar Singh <[email protected]> Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
1 parent f9da75f commit 8b54470

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

python/composio/client/collections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
Trigger,
3131
TriggerType,
3232
)
33-
from composio.constants import PUSHER_CLUSTER, PUSHER_KEY
33+
from composio.constants import DEFAULT_ENTITY_ID, PUSHER_CLUSTER, PUSHER_KEY
3434
from composio.exceptions import (
3535
ErrorFetchingResource,
3636
InvalidParams,
@@ -92,7 +92,7 @@ class ConnectedAccountModel(BaseModel):
9292
connectionParams: AuthConnectionParamsModel
9393

9494
clientUniqueUserId: t.Optional[str] = None
95-
entityId: t.Optional[str] = None
95+
entityId: str = DEFAULT_ENTITY_ID
9696

9797
# Override arbitrary model config.
9898
model_config: ConfigDict = ConfigDict( # type: ignore

python/composio/tools/toolset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,7 +1668,9 @@ def _validate_connection_ids(
16681668
return valid
16691669
raise InvalidConnectedAccount(f"Invalid connected accounts found: {invalid}")
16701670

1671-
def check_connected_account(self, action: ActionType) -> None:
1671+
def check_connected_account(
1672+
self, action: ActionType, entity_id: t.Optional[str] = None
1673+
) -> None:
16721674
"""Check if connected account is required and if required it exists or not."""
16731675
action = Action(action)
16741676
if action.no_auth or action.is_runtime:
@@ -1686,6 +1688,7 @@ def check_connected_account(self, action: ActionType) -> None:
16861688
if action.app not in [
16871689
connection.appUniqueId.upper() # Normalize app names/ids coming from API
16881690
for connection in self._connected_accounts
1691+
if entity_id is None or connection.clientUniqueUserId == entity_id
16891692
]:
16901693
raise ConnectedAccountNotFoundError(
16911694
f"No connected account found for app `{action.app}`; "
@@ -1799,7 +1802,7 @@ def _execute_remote(
17991802
action=action
18001803
)
18011804
if auth is None:
1802-
self.check_connected_account(action=action)
1805+
self.check_connected_account(action=action, entity_id=entity_id)
18031806

18041807
output = self.client.get_entity( # pylint: disable=protected-access
18051808
id=entity_id

python/tests/test_tools/test_toolset.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from pydantic import BaseModel, Field
1313

1414
from composio import Action, App, Trigger
15-
from composio.exceptions import ApiKeyNotProvidedError, ComposioSDKError
15+
from composio.exceptions import (
16+
ApiKeyNotProvidedError,
17+
ComposioSDKError,
18+
ConnectedAccountNotFoundError,
19+
)
1620
from composio.tools.base.abs import action_registry, tool_registry
1721
from composio.tools.base.runtime import action as custom_action
1822
from composio.tools.local.filetool.tool import Filetool, FindFile
@@ -289,6 +293,22 @@ def postprocess(response: dict) -> dict:
289293
assert postprocessor_called
290294

291295

296+
def test_entity_id_validation_in_check_connected_accounts() -> None:
297+
"""Test whether check_connected_account raises error with invalid entity_id"""
298+
toolset = ComposioToolSet()
299+
with pytest.raises(
300+
ConnectedAccountNotFoundError,
301+
match=(
302+
"No connected account found for app `GMAIL`; "
303+
"Run `composio add gmail` to fix this"
304+
),
305+
):
306+
toolset.check_connected_account(
307+
action=Action.GMAIL_FETCH_EMAILS,
308+
entity_id="some_very_random_obviously_wrong_entity_id",
309+
)
310+
311+
292312
def test_check_connected_accounts_flag() -> None:
293313
"""Test the `check_connected_accounts` flag on `get_tools()`."""
294314

0 commit comments

Comments
 (0)