Skip to content

Commit cca1079

Browse files
authored
Get invitations (#1962)
2 parents 688aea7 + 6380eda commit cca1079

File tree

3 files changed

+346
-7
lines changed

3 files changed

+346
-7
lines changed

libs/labelbox/src/labelbox/schema/invite.py

Lines changed: 141 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from typing import TYPE_CHECKING
12
from dataclasses import dataclass
23

34
from labelbox.orm.db_object import DbObject
45
from labelbox.orm.model import Field
56
from labelbox.schema.role import ProjectRole, format_role
67

8+
from labelbox.pagination import PaginatedCollection
9+
10+
if TYPE_CHECKING:
11+
from labelbox import Client
12+
713

814
@dataclass
915
class InviteLimit:
@@ -31,10 +37,138 @@ def __init__(self, client, invite_response):
3137
project_roles = invite_response.pop("projectInvites", [])
3238
super().__init__(client, invite_response)
3339

34-
self.project_roles = [
35-
ProjectRole(
36-
project=client.get_project(r["projectId"]),
37-
role=client.get_roles()[format_role(r["projectRoleName"])],
38-
)
39-
for r in project_roles
40-
]
40+
self.project_roles = []
41+
42+
# If a project is deleted then it doesn't show up in the invite
43+
for pr in project_roles:
44+
try:
45+
project = client.get_project(pr["projectId"])
46+
if project: # Check if project exists
47+
self.project_roles.append(
48+
ProjectRole(
49+
project=project,
50+
role=client.get_roles()[
51+
format_role(pr["projectRoleName"])
52+
],
53+
)
54+
)
55+
except Exception:
56+
# Skip this project role if the project is no longer available
57+
continue
58+
59+
def cancel(self) -> bool:
60+
"""
61+
Cancels this invite.
62+
63+
This will prevent the invited user from accepting the invitation.
64+
65+
Returns:
66+
bool: True if the invite was successfully canceled, False otherwise.
67+
"""
68+
69+
# Case of a newly invited user
70+
if self.uid == "invited":
71+
return False
72+
73+
query_str = """
74+
mutation CancelInvitePyApi($where: WhereUniqueIdInput!) {
75+
cancelInvite(where: $where) {
76+
id
77+
}
78+
}"""
79+
result = self.client.execute(
80+
query_str, {"where": {"id": self.uid}}, experimental=True
81+
)
82+
return (
83+
result is not None
84+
and "cancelInvite" in result
85+
and result.get("cancelInvite") is not None
86+
)
87+
88+
@staticmethod
89+
def get_project_invites(
90+
client: "Client", project_id: str
91+
) -> PaginatedCollection:
92+
"""
93+
Retrieves all invites for a specific project.
94+
95+
Args:
96+
client (Client): The Labelbox client instance.
97+
project_id (str): The ID of the project to get invites for.
98+
99+
Returns:
100+
PaginatedCollection: A collection of Invite objects for the specified project.
101+
"""
102+
query = """query GetProjectInvitationsPyApi(
103+
$from: ID
104+
$first: PageSize
105+
$projectId: ID!
106+
) {
107+
project(where: { id: $projectId }) {
108+
id
109+
invites(from: $from, first: $first) {
110+
nodes {
111+
id
112+
createdAt
113+
organizationRoleName
114+
inviteeEmail
115+
projectInvites {
116+
id
117+
projectRoleName
118+
projectId
119+
}
120+
}
121+
nextCursor
122+
}
123+
}
124+
}"""
125+
126+
invites = PaginatedCollection(
127+
client,
128+
query,
129+
{"projectId": project_id, "search": ""},
130+
["project", "invites", "nodes"],
131+
Invite,
132+
cursor_path=["project", "invites", "nextCursor"],
133+
)
134+
return invites
135+
136+
@staticmethod
137+
def get_invites(client: "Client") -> PaginatedCollection:
138+
"""
139+
Retrieves all invites for the organization.
140+
141+
Args:
142+
client (Client): The Labelbox client instance.
143+
144+
Returns:
145+
PaginatedCollection: A collection of Invite objects for the organization.
146+
"""
147+
query_str = """query GetOrgInvitationsPyApi($from: ID, $first: PageSize) {
148+
organization {
149+
id
150+
invites(from: $from, first: $first) {
151+
nodes {
152+
id
153+
createdAt
154+
organizationRoleName
155+
inviteeEmail
156+
projectInvites {
157+
id
158+
projectRoleName
159+
projectId
160+
}
161+
}
162+
nextCursor
163+
}
164+
}
165+
}"""
166+
invites = PaginatedCollection(
167+
client,
168+
query_str,
169+
{},
170+
["organization", "invites", "nodes"],
171+
Invite,
172+
cursor_path=["organization", "invites", "nextCursor"],
173+
)
174+
return invites

libs/labelbox/src/labelbox/schema/organization.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from labelbox.orm.model import Field, Relationship
88
from labelbox.schema.invite import InviteLimit
99
from labelbox.schema.resource_tag import ResourceTag
10+
from labelbox.pagination import PaginatedCollection
1011

1112
if TYPE_CHECKING:
1213
from labelbox import (
@@ -243,3 +244,24 @@ def get_default_iam_integration(self) -> Optional["IAMIntegration"]:
243244
return (
244245
None if not len(default_integration) else default_integration.pop()
245246
)
247+
248+
def get_invites(self) -> PaginatedCollection:
249+
"""
250+
Retrieves all invites for this organization.
251+
252+
Returns:
253+
PaginatedCollection: A collection of Invite objects for the organization.
254+
"""
255+
return Entity.Invite.get_invites(self.client)
256+
257+
def get_project_invites(self, project_id: str) -> PaginatedCollection:
258+
"""
259+
Retrieves all invites for a specific project in this organization.
260+
261+
Args:
262+
project_id (str): The ID of the project to get invites for.
263+
264+
Returns:
265+
PaginatedCollection: A collection of Invite objects for the specified project.
266+
"""
267+
return Entity.Invite.get_project_invites(self.client, project_id)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import pytest
2+
from faker import Faker
3+
from labelbox.schema.media_type import MediaType
4+
from labelbox import ProjectRole
5+
6+
faker = Faker()
7+
8+
9+
@pytest.fixture
10+
def dummy_email():
11+
"""Generate a random dummy email for testing"""
12+
return f"none+{faker.uuid4()}@labelbox.com"
13+
14+
15+
@pytest.fixture(scope="module")
16+
def test_project(client):
17+
"""Create a temporary project for testing"""
18+
project = client.create_project(
19+
name=f"test-project-{faker.uuid4()}", media_type=MediaType.Image
20+
)
21+
yield project
22+
23+
# Occurs after the test is finished based on scope
24+
project.delete()
25+
26+
27+
@pytest.fixture
28+
def org_invite(client, dummy_email):
29+
"""Create an organization-level invite"""
30+
role = client.get_roles()["LABELER"]
31+
organization = client.get_organization()
32+
invite = organization.invite_user(dummy_email, role)
33+
34+
yield invite
35+
36+
if invite.uid:
37+
invite.cancel()
38+
39+
40+
@pytest.fixture
41+
def project_invite(client, test_project, dummy_email):
42+
"""Create a project-level invite"""
43+
roles = client.get_roles()
44+
project_role = ProjectRole(project=test_project, role=roles["LABELER"])
45+
organization = client.get_organization()
46+
47+
invite = organization.invite_user(
48+
dummy_email, roles["NONE"], project_roles=[project_role]
49+
)
50+
51+
yield invite
52+
53+
# Cleanup: Use invite.cancel() instead of organization.cancel_invite()
54+
if invite.uid:
55+
invite.cancel()
56+
57+
58+
def test_get_organization_invites(client, org_invite):
59+
"""Test retrieving all organization invites"""
60+
61+
organization = client.get_organization()
62+
invites = organization.get_invites()
63+
invite_list = [invite for invite in invites]
64+
assert len(invite_list) > 0
65+
66+
# Verify our test invite is in the list
67+
invite_emails = [invite.email for invite in invite_list]
68+
assert org_invite.email in invite_emails
69+
70+
71+
def test_get_project_invites(client, test_project, project_invite):
72+
"""Test retrieving project-specific invites"""
73+
74+
organization = client.get_organization()
75+
project_invites = organization.get_project_invites(test_project.uid)
76+
invite_list = [invite for invite in project_invites]
77+
assert len(invite_list) > 0
78+
79+
# Verify our test invite is in the list
80+
invite_emails = [invite.email for invite in invite_list]
81+
assert project_invite.email in invite_emails
82+
83+
# Verify project role assignment
84+
found_invite = next(
85+
invite for invite in invite_list if invite.email == project_invite.email
86+
)
87+
assert len(found_invite.project_roles) == 1
88+
assert found_invite.project_roles[0].project.uid == test_project.uid
89+
90+
91+
def test_cancel_invite(client, dummy_email):
92+
"""Test canceling an invite"""
93+
# Create a new invite
94+
role = client.get_roles()["LABELER"]
95+
organization = client.get_organization()
96+
organization.invite_user(dummy_email, role)
97+
98+
# Find the actual invite by email
99+
invites = organization.get_invites()
100+
found_invite = next(
101+
(invite for invite in invites if invite.email == dummy_email), None
102+
)
103+
assert found_invite is not None, f"Invite for {dummy_email} not found"
104+
105+
# Cancel the invite using the found invite object
106+
result = found_invite.cancel()
107+
assert result is True
108+
109+
# Verify the invite is no longer in the organization's invites
110+
invites = organization.get_invites()
111+
invite_emails = [i.email for i in invites]
112+
assert dummy_email not in invite_emails
113+
114+
115+
def test_cancel_project_invite(client, test_project, dummy_email):
116+
"""Test canceling a project invite"""
117+
# Create a project invite
118+
roles = client.get_roles()
119+
project_role = ProjectRole(project=test_project, role=roles["LABELER"])
120+
organization = client.get_organization()
121+
122+
organization.invite_user(
123+
dummy_email, roles["NONE"], project_roles=[project_role]
124+
)
125+
126+
# Find the actual invite by email
127+
invites = organization.get_invites()
128+
found_invite = next(
129+
(invite for invite in invites if invite.email == dummy_email), None
130+
)
131+
assert found_invite is not None, f"Invite for {dummy_email} not found"
132+
133+
# Cancel the invite using the found invite object
134+
result = found_invite.cancel()
135+
assert result is True
136+
137+
# Verify the invite is no longer in the project's invites
138+
project_invites = organization.get_project_invites(test_project.uid)
139+
invite_emails = [i.email for i in project_invites]
140+
assert dummy_email not in invite_emails
141+
142+
143+
def test_project_invite_after_project_deletion(client, dummy_email):
144+
"""Test that project invites are properly filtered when a project is deleted"""
145+
# Create two test projects
146+
project1 = client.create_project(
147+
name=f"test-project1-{faker.uuid4()}", media_type=MediaType.Image
148+
)
149+
project2 = client.create_project(
150+
name=f"test-project2-{faker.uuid4()}", media_type=MediaType.Image
151+
)
152+
153+
# Create project roles
154+
roles = client.get_roles()
155+
project_role1 = ProjectRole(project=project1, role=roles["LABELER"])
156+
project_role2 = ProjectRole(project=project2, role=roles["LABELER"])
157+
158+
# Invite user to both projects
159+
organization = client.get_organization()
160+
organization.invite_user(
161+
dummy_email, roles["NONE"], project_roles=[project_role1, project_role2]
162+
)
163+
164+
# Delete one project
165+
project1.delete()
166+
167+
# Find the invite and verify project roles
168+
invites = organization.get_invites()
169+
found_invite = next(
170+
(invite for invite in invites if invite.email == dummy_email), None
171+
)
172+
assert found_invite is not None, f"Invite for {dummy_email} not found"
173+
174+
# Verify only one project role remains
175+
assert (
176+
len(found_invite.project_roles) == 1
177+
), "Expected only one project role"
178+
assert found_invite.project_roles[0].project.uid == project2.uid
179+
180+
# Cleanup
181+
project2.delete()
182+
if found_invite.uid:
183+
found_invite.cancel()

0 commit comments

Comments
 (0)