Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ruff.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
line-length = 100 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
line-length = 120 # ideally I want this to be less than 100 but don't wanna test and change files with longer lines
target-version = "py313"
lint.select = [
"E", # pycodestyle errors
Expand Down
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@CLAUDE.md
27 changes: 3 additions & 24 deletions src/kernelbot/api/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]
user_name = user_json.get("username")

if not user_id or not user_name:
raise HTTPException(
status_code=500, detail="Failed to retrieve user ID or username from Discord."
)
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from Discord.")

return user_id, user_name

Expand Down Expand Up @@ -135,16 +133,12 @@ async def _handle_github_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
user_name = user_json.get("login") # GitHub uses 'login' for username

if not user_id or not user_name:
raise HTTPException(
status_code=500, detail="Failed to retrieve user ID or username from GitHub."
)
raise HTTPException(status_code=500, detail="Failed to retrieve user ID or username from GitHub.")

return user_id, user_name


async def _run_submission(
submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend
):
async def _run_submission(submission: SubmissionRequest, mode: SubmissionMode, backend: KernelBackend):
try:
req = prepare_submission(submission, backend)
except Exception as e:
Expand Down Expand Up @@ -225,21 +219,6 @@ async def to_submit_info(

try:
with db_context as db:
# Per-user rate limit: max 1 submission per hour on Modal B200 for leaderboard 730
if gpu_type == "B200":
lb_id = db.get_leaderboard_id(leaderboard_name)
if lb_id == 730:
last_submission_time = db.check_user_rate_limit(user_id)
if last_submission_time:
raise HTTPException(
status_code=429,
detail=(
f"Rate limit exceeded. You can submit once per hour. "
f"Last submission: {last_submission_time.isoformat()}. "
f"Consider using the NVIDIA runner instead of Modal for faster iteration."
),
)

leaderboard_item = db.get_leaderboard(leaderboard_name)
gpus = leaderboard_item.get("gpu_types", [])
if gpu_type not in gpus:
Expand Down
82 changes: 53 additions & 29 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

app = FastAPI()


def json_serializer(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
Expand Down Expand Up @@ -185,9 +186,7 @@ def require_admin(
@app.get("/auth/init")
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
if provider not in ["discord", "github"]:
raise HTTPException(
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
)
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")

"""
Initialize authentication flow for the specified provider.
Expand Down Expand Up @@ -230,9 +229,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
"""

if auth_provider not in ["discord", "github"]:
raise HTTPException(
status_code=400, detail="Invalid provider, must be 'discord' or 'github'"
)
raise HTTPException(status_code=400, detail="Invalid provider, must be 'discord' or 'github'")

if not code or not state:
raise HTTPException(status_code=400, detail="Missing authorization code or state")
Expand All @@ -252,8 +249,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
if not api_base_url:
raise HTTPException(
status_code=500,
detail="Redirect URI base not configured."
"Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
detail="Redirect URI base not configured. Set HEROKU_APP_DEFAULT_DOMAIN_NAME or POPCORN_API_URL.",
)
redirect_uri_base = api_base_url.rstrip("/")
redirect_uri = f"https://{redirect_uri_base}/auth/cli/{auth_provider}"
Expand All @@ -275,7 +271,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e

if not user_id or not user_name:
raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",)
raise HTTPException(
status_code=500,
detail="Failed to retrieve user ID or username from provider.",
)

try:
with db_context as db:
Expand All @@ -297,6 +296,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
"is_reset": is_reset,
}


async def _stream_submission_response(
submission_request: SubmissionRequest,
submission_mode_enum: SubmissionMode,
Expand All @@ -315,18 +315,18 @@ async def _stream_submission_response(

while not task.done():
elapsed_time = time.time() - start_time
yield f"event: status\ndata: {json.dumps({'status': 'processing',
'elapsed_time': round(elapsed_time, 2)},
default=json_serializer)}\n\n"
yield f"event: status\ndata: {
json.dumps({'status': 'processing', 'elapsed_time': round(elapsed_time, 2)}, default=json_serializer)
}\n\n"

try:
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
yield f"event: error\ndata: {json.dumps(
{'status': 'error', 'detail': 'Submission cancelled'},
default=json_serializer)}\n\n"
yield f"event: error\ndata: {
json.dumps({'status': 'error', 'detail': 'Submission cancelled'}, default=json_serializer)
}\n\n"
return

result, reports = await task
Expand Down Expand Up @@ -360,6 +360,7 @@ async def _stream_submission_response(
except asyncio.CancelledError:
pass


@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
async def run_submission( # noqa: C901
leaderboard_name: str,
Expand Down Expand Up @@ -398,27 +399,28 @@ async def run_submission( # noqa: C901
)
return StreamingResponse(generator, media_type="text/event-stream")


async def enqueue_background_job(
req: ProcessedSubmissionRequest,
mode: SubmissionMode,
backend: KernelBackend,
manager: BackgroundSubmissionManager,
):

# pre-create the submission for api returns
with backend.db as db:
sub_id = db.create_submission(
leaderboard=req.leaderboard,
file_name=req.file_name,
code=req.code,
user_id=req.user_id,
time=datetime.datetime.now(),
time=datetime.datetime.now(datetime.timezone.utc),
user_name=req.user_name,
)
job_id = db.upsert_submission_job_status(sub_id, "initial", None)
# put submission request in queue
await manager.enqueue(req, mode, sub_id)
return sub_id,job_id
return sub_id, job_id


@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}")
async def run_submission_async(
Expand All @@ -445,15 +447,13 @@ async def run_submission_async(
JSONResponse: A JSON response containing job_id and and submission_id for the client to poll for status.
"""
try:

await simple_rate_limit()
logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}")


# throw error if submission request is invalid
try:
submission_request, submission_mode_enum = await to_submit_info(
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
)

req = prepare_submission(submission_request, backend_instance)
Expand All @@ -466,13 +466,13 @@ async def run_submission_async(
raise HTTPException(status_code=400, detail="Invalid GPU type")

# put submission request to background manager to run in background
sub_id,job_status_id = await enqueue_background_job(
sub_id, job_status_id = await enqueue_background_job(
req, submission_mode_enum, backend_instance, background_submission_manager
)

return JSONResponse(
status_code=202,
content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
content={"details": {"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
)
# Preserve FastAPI HTTPException as-is
except HTTPException:
Expand Down Expand Up @@ -536,8 +536,7 @@ async def create_dev_leaderboard(
# GPUs must be specified in task.yml
if not definition.gpus:
raise HTTPException(
status_code=400,
detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
status_code=400, detail="No gpus specified in task.yml. Add 'gpus:' field with list of GPU types."
)

with db_context as db:
Expand Down Expand Up @@ -629,7 +628,7 @@ async def admin_update_problems(
branch=branch,
force=force,
creator_id=0, # API-created
forum_id=-1, # No Discord forum
forum_id=-1, # No Discord forum
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
Expand All @@ -643,6 +642,33 @@ async def admin_update_problems(
}


@app.get("/leaderboard/rate-limits/{leaderboard_name}")
async def get_leaderboard_rate_limits(leaderboard_name: str, db_context=Depends(get_db)) -> dict:
with db_context as db:
rate_limits = db.get_leaderboard_rate_limits(leaderboard_name)
return {"status": "ok", "rate_limits": rate_limits}


@app.post("/leaderboard/rate-limits/{leaderboard_name}/{gpu_type}")
async def set_leaderboard_gpu_rate_limit(
leaderboard_name: str,
gpu_type: str,
rate_limit_seconds: int,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
if rate_limit_seconds <= 0:
rate_limit_seconds = None
with db_context as db:
db.set_leaderboard_gpu_rate_limit(leaderboard_name, gpu_type, rate_limit_seconds)
return {
"status": "ok",
"leaderboard_name": leaderboard_name,
"gpu_type": gpu_type,
"rate_limit_seconds": rate_limit_seconds,
}


@app.get("/leaderboards")
async def get_leaderboards(db_context=Depends(get_db)):
"""An endpoint that returns all leaderboards.
Expand Down Expand Up @@ -692,9 +718,7 @@ async def get_submissions(
try:
with db_context as db:
# Add validation for leaderboard and GPU? Might be redundant if DB handles it.
return db.get_leaderboard_submissions(
leaderboard_name, gpu_name, limit=limit, offset=offset
)
return db.get_leaderboard_submissions(leaderboard_name, gpu_name, limit=limit, offset=offset)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching submissions: {e}") from e

Expand Down
Loading
Loading