Skip to content
32 changes: 31 additions & 1 deletion courses/serializers/v2/courses.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,35 @@ def get_next_run_id(self, instance) -> int | None:

@extend_schema_field(BaseProgramSerializer(many=True, allow_null=True))
def get_programs(self, instance):
"""
Include appropriate programs.

If the org or contract ID is set, include only programs that match. If
neither is specified, filter programs that have "b2b_only" set.
Comment on lines +119 to +120
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If the org or contract ID is set, include only programs that match. If
neither is specified, filter programs that have "b2b_only" set.
If the org or contract ID is set, include only programs that match. If
neither is specified, filter out programs that have "b2b_only" set.

"""
if self.context.get("include_programs", False):
return BaseProgramSerializer(instance.programs, many=True).data
programs_qs = instance.in_programs

if self.context.get("org_id"):
programs_qs = programs_qs.filter(
program__contract_memberships__contract__organization__pk=self.context.get(
"org_id"
)
)
elif self.context.get("contract_id"):
programs_qs = programs_qs.filter(
program__contract_memberships__contract__pk=self.context.get(
"contract_id"
)
)
else:
programs_qs = programs_qs.filter(program__b2b_only=False)

programs = [
req.program for req in programs_qs.prefetch_related("program").all()
]

return BaseProgramSerializer(programs, many=True).data

return None

Expand Down Expand Up @@ -250,6 +277,9 @@ def get_courseruns(self, instance):
if "contract_id" in self.context:
courseruns = courseruns.filter(b2b_contract_id=self.context["contract_id"])

if "org_id" not in self.context and "contract_id" not in self.context:
courseruns = courseruns.filter(b2b_contract_id=None)

return CourseRunSerializer(courseruns, many=True, read_only=True).data

class Meta:
Expand Down
2 changes: 1 addition & 1 deletion courses/views/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def get_queryset(self):

return (
Course.objects.select_related("page")
.prefetch_related("departments")
.prefetch_related("departments", "in_programs")
.annotate(count_b2b_courseruns=Count("courseruns__b2b_contract__id"))
.annotate(count_courseruns=Count("courseruns"))
.order_by("title")
Expand Down
137 changes: 137 additions & 0 deletions courses/views/v2/views_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from b2b.api import create_contract_run
from b2b.factories import ContractPageFactory, OrganizationPageFactory
from b2b.models import ContractProgramItem
from cms.factories import CoursePageFactory, ProgramPageFactory
from cms.serializers import ProgramPageSerializer
from courses.constants import ENROLL_CHANGE_STATUS_UNENROLLED
Expand Down Expand Up @@ -1622,3 +1623,139 @@ def test_add_verified_program_course_enrollment(
assert resp.json()["enrollment_mode"] == EDX_ENROLLMENT_AUDIT_MODE
else:
assert resp.status_code == status.HTTP_404_NOT_FOUND


@pytest.mark.skip_nplusone_check
@pytest.mark.parametrize(
"with_b2b",
[
True,
False,
],
)
@pytest.mark.parametrize(
"single",
[
True,
False,
],
)
def test_get_courses_b2b_runs(with_b2b, single, user_drf_client):
"""
Test that the courses API returns courses with or without b2b runs.

By default courses should only have runs that aren't B2B runs. There are
other tests that test the result if you've specified an org/etc. so this
doesn't test that.
"""

contract = ContractPageFactory.create() if with_b2b else None

test_course_run = CourseRunFactory.create(b2b_contract=contract)

url = reverse("v2:courses_api-list")
response_raw = user_drf_client.get(
url,
query_params=(
{"readable_id": test_course_run.course.readable_id} if single else {}
),
)
assert response_raw.status_code < 300
response = response_raw.json()["results"]

assert len(response) == 1
returned_course = response[0]

assert returned_course["readable_id"] == test_course_run.course.readable_id

if with_b2b:
assert len(returned_course["courseruns"]) == 0
else:
assert len(returned_course["courseruns"]) == 1
assert (
returned_course["courseruns"][0]["courseware_id"]
== test_course_run.courseware_id
)


@pytest.mark.skip_nplusone_check
@pytest.mark.parametrize(
"with_b2b",
[
True,
False,
],
)
def test_get_courses_b2b_programs(with_b2b, user_drf_client):
"""
Test that the courses API returns courses with or without b2b programs.

By default courses should only list programs that aren't marked as b2b_only.
Again, other tests handle filtering of that list so not testing that here.
"""

program = ProgramFactory.create(b2b_only=with_b2b)

test_course_run = CourseRunFactory.create()
program.add_requirement(test_course_run.course)

url = reverse("v2:courses_api-list")
response_raw = user_drf_client.get(
url,
query_params={"readable_id": test_course_run.course.readable_id},
)
assert response_raw.status_code < 300
response = response_raw.json()["results"]

assert len(response) == 1
returned_course = response[0]

assert returned_course["readable_id"] == test_course_run.course.readable_id

if with_b2b:
assert len(returned_course["programs"]) == 0
else:
assert len(returned_course["programs"]) == 1
assert returned_course["programs"][0]["readable_id"] == program.readable_id


def test_get_courses_with_specified_contract_programs(user, user_drf_client):
"""
Test that specifying a contract when retrieving a course returns only
applicable B2B programs.

This is different than testing for the b2b flag alone - if we have a
contract ID specified, then the programs in the list should only be ones
attached to that contract.
"""

contract = ContractPageFactory.create()
user.b2b_contracts.add(contract)
other_contract = ContractPageFactory.create()

programs = ProgramFactory.create_batch(2)
course_run = CourseRunFactory.create(b2b_contract=contract)
CourseRunFactory.create(b2b_contract=other_contract, course=course_run.course)

for program in programs:
program.add_requirement(course_run.course)
program.save()

ContractProgramItem.objects.create(contract=contract, program=programs[0])

url = reverse("v2:courses_api-list")
response_raw = user_drf_client.get(
url,
query_params={
"readable_id": course_run.course.readable_id,
"contract_id": contract.id,
},
)
assert response_raw.status_code < 300
response = response_raw.json()["results"]

assert len(response[0]["programs"]) > 0

program_ids = [program["id"] for program in response[0]["programs"]]
assert programs[0].id in program_ids
assert programs[1].id not in program_ids