diff --git a/src/sentry/api/endpoints/prompts_activity.py b/src/sentry/api/endpoints/prompts_activity.py index 29f21e5f4f2974..793fe71ccd9220 100644 --- a/src/sentry/api/endpoints/prompts_activity.py +++ b/src/sentry/api/endpoints/prompts_activity.py @@ -3,7 +3,6 @@ from django.db import IntegrityError, router, transaction from django.db.models import Q -from django.http import HttpResponse from django.utils import timezone from rest_framework import serializers from rest_framework.request import Request @@ -42,7 +41,7 @@ class PromptsActivityEndpoint(OrganizationEndpoint): "PUT": ApiPublishStatus.UNKNOWN, } - def get(self, request: Request, **kwargs) -> Response: + def get(self, request: Request, organization: Organization, **kwargs) -> Response: """Return feature prompt status if dismissed or in snoozed period""" if not request.user.is_authenticated: @@ -58,14 +57,26 @@ def get(self, request: Request, **kwargs) -> Response: return Response({"detail": "Invalid feature name " + feature}, status=400) required_fields = prompt_config.required_fields(feature) - for field in required_fields: - if field not in request.GET: - return Response({"detail": 'Missing required field "%s"' % field}, status=400) - filters = {k: request.GET.get(k) for k in required_fields} + filters: dict[str, Any] = {} + + # project_id must be provided and belong to the organization + if "project_id" in required_fields: + project_id = request.GET.get("project_id") + if not project_id: + return Response({"detail": 'Missing required field "project_id"'}, status=400) + if not Project.objects.filter( + id=project_id, organization_id=organization.id + ).exists(): + return Response({"detail": "Project not found"}, status=404) + filters["project_id"] = project_id + condition = Q(feature=feature, **filters) conditions = condition if conditions is None else (conditions | condition) - result_qs = PromptsActivity.objects.filter(conditions, user_id=request.user.id) + # Always scope by organization from URL - passed directly to filter() to prevent override + result_qs = PromptsActivity.objects.filter( + conditions, user_id=request.user.id, organization_id=organization.id + ) featuredata = {k.feature: k.data for k in result_qs} if len(features) == 1: result = result_qs.first() @@ -74,7 +85,7 @@ def get(self, request: Request, **kwargs) -> Response: else: return Response({"features": featuredata}) - def put(self, request: Request, **kwargs): + def put(self, request: Request, organization: Organization, **kwargs) -> Response: serializer = PromptsActivitySerializer(data=request.data) if not serializer.is_valid(): return Response(serializer.errors, status=400) @@ -89,26 +100,26 @@ def put(self, request: Request, **kwargs): if any(elem is None for elem in fields.values()): return Response({"detail": "Missing required field"}, status=400) - # if project_id or organization_id in required fields make sure they exist - # if NOT in required fields, insert dummy value so dups aren't recorded + # Validate organization_id is present and matches URL organization + if "organization_id" not in required_fields or str(fields["organization_id"]) != str( + organization.id + ): + return Response({"detail": "Organization missing or mismatched"}, status=400) + # Override with URL organization to prevent IDOR + fields["organization_id"] = organization.id + + # Validate project_id if required, otherwise use dummy value to prevent duplicates if "project_id" in required_fields: - if not Project.objects.filter( - id=fields["project_id"], organization_id=request.organization.id - ).exists(): + project_id = fields["project_id"] + if not project_id: + return Response({"detail": "Invalid project_id"}, status=400) + if not Project.objects.filter(id=project_id, organization_id=organization.id).exists(): return Response( {"detail": "Project does not belong to this organization"}, status=400 ) else: fields["project_id"] = 0 - if "organization_id" in required_fields and str(fields["organization_id"]) == str( - request.organization.id - ): - if not Organization.objects.filter(id=fields["organization_id"]).exists(): - return Response({"detail": "Organization no longer exists"}, status=400) - else: - return Response({"detail": "Organization missing or mismatched"}, status=400) - data: dict[str, Any] = {} now = calendar.timegm(timezone.now().utctimetuple()) if status == "snoozed": @@ -126,4 +137,4 @@ def put(self, request: Request, **kwargs): ) except IntegrityError: pass - return HttpResponse(status=201) + return Response(status=201) diff --git a/tests/sentry/api/endpoints/test_prompts_activity.py b/tests/sentry/api/endpoints/test_prompts_activity.py index 21c0fe0b49a67c..bf788ef9c26380 100644 --- a/tests/sentry/api/endpoints/test_prompts_activity.py +++ b/tests/sentry/api/endpoints/test_prompts_activity.py @@ -75,7 +75,6 @@ def test_batched_invalid_feature(self) -> None: def test_invalid_project(self) -> None: # Invalid project id data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -98,7 +97,6 @@ def test_invalid_project(self) -> None: def test_dismiss(self) -> None: data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -135,7 +133,6 @@ def test_dismiss_str_id(self) -> None: assert resp.status_code == 201, resp.content data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -147,7 +144,6 @@ def test_dismiss_str_id(self) -> None: def test_snooze(self) -> None: data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -173,7 +169,6 @@ def test_snooze(self) -> None: def test_visible(self) -> None: data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -199,7 +194,6 @@ def test_visible(self) -> None: def test_visible_after_dismiss(self) -> None: data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": "releases", } @@ -235,7 +229,6 @@ def test_visible_after_dismiss(self) -> None: def test_batched(self) -> None: data = { - "organization_id": self.org.id, "project_id": self.project.id, "feature": ["releases", "alert_stream"], } @@ -290,3 +283,48 @@ def test_project_from_different_organization(self) -> None: assert resp.status_code == 400 assert resp.data["detail"] == "Project does not belong to this organization" + + def test_idor_get_project_from_different_org(self) -> None: + """Regression test: GET cannot access projects from other organizations (IDOR).""" + other_org = self.create_organization() + other_project = self.create_project(organization=other_org) + + resp = self.client.get( + self.path, + { + "project_id": str(other_project.id), + "feature": "releases", + }, + ) + + # Should return 404 to prevent ID enumeration + assert resp.status_code == 404 + assert resp.data["detail"] == "Project not found" + + def test_get_empty_project_id(self) -> None: + """Test that empty string project_id returns 400 instead of 500.""" + resp = self.client.get( + self.path, + { + "project_id": "", + "feature": "releases", + }, + ) + + assert resp.status_code == 400 + assert resp.data["detail"] == 'Missing required field "project_id"' + + def test_put_empty_project_id(self) -> None: + """Test that empty string project_id in PUT returns 400 instead of 500.""" + resp = self.client.put( + self.path, + { + "organization_id": self.org.id, + "project_id": "", + "feature": "releases", + "status": "dismissed", + }, + ) + + assert resp.status_code == 400 + assert resp.data["detail"] == "Invalid project_id"