-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathconfig.py
More file actions
204 lines (159 loc) · 5.78 KB
/
config.py
File metadata and controls
204 lines (159 loc) · 5.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from pydantic import BaseModel, model_validator
import tomli
from typing import List, Dict, Optional, Any
from utils.logger import logger
from pathlib import Path
import sys
import os
import re
import aiohttp
import asyncio
from dotenv import load_dotenv
load_dotenv(override=True)
# Allowed image generation models
ALLOWED_MODELS = {"flux", "klein", "klein-large", "zimage"}
class BotConfig(BaseModel):
    """Bot identity and Discord-facing settings loaded from config.toml."""

    # Prefix that triggers text commands (e.g. "!").
    command_prefix: str
    bot_id: str
    avatar_url: str
    # Mapping of command names to strings — presumably display text; confirm against config.toml.
    commands: Dict[str, str]
    # Mapping of emoji names to strings — presumably emoji IDs/markup; confirm against callers.
    emojis: Dict[str, str]
class APIConfig(BaseModel):
    """Image-generation API connection settings."""

    api_key: str
    # URL returning the list of available models (see initialize_models*).
    models_list_endpoint: str
    image_gen_endpoint: str
    # How often the model list should be refreshed, in minutes.
    models_refresh_interval_minutes: int
    # Upper bound for API request duration, in seconds.
    max_timeout_seconds: int
class ImageGenerationDefaults(BaseModel):
    """Default parameters applied to image-generation requests."""

    width: int
    height: int
    safe: bool
    cached: bool
    nologo: bool
    enhance: bool
    private: bool
class ImageGenerationValidation(BaseModel):
    """Limits used to validate user-supplied generation parameters."""

    min_width: int
    min_height: int
    max_prompt_length: int
    # Only field with a default — optional in config.toml.
    max_enhanced_prompt_length: int = 80
class CommandCooldown(BaseModel):
    """Rate-limit settings for a single command."""

    # `rate` uses per `seconds` window.
    rate: int
    seconds: int
    # Optional additional per-minute / per-day caps.
    per_minute: Optional[int] = None
    per_day: Optional[int] = None
class CommandConfig(BaseModel):
    """Per-command settings (one entry per command in config.toml)."""

    cooldown: CommandCooldown
    default_width: int
    default_height: int
    # Optional per-command overrides; None means use the global values.
    timeout_seconds: Optional[int] = None
    max_prompt_length: Optional[int] = None
class UIColors(BaseModel):
    """Color values for UI states — presumably hex color strings; confirm against config.toml."""

    success: str
    error: str
    warning: str
class UIConfig(BaseModel):
    """Links, colors, and message templates for the bot's user interface."""

    bot_invite_url: str
    support_server_url: str
    github_repo_url: str
    api_provider_url: str
    bot_creator_avatar: str
    colors: UIColors
    # Mapping of error keys to user-facing message strings.
    error_messages: Dict[str, str]
class ResourcesConfig(BaseModel):
    """Static media resources."""

    # GIF URLs shown while a generation request is in flight — TODO confirm usage.
    waiting_gifs: List[str]
class ImageGenerationConfig(BaseModel):
    """Image-generation settings: request defaults plus validation limits."""

    referer: str
    # Model name returned when the live model list cannot be fetched.
    fallback_model: str
    defaults: ImageGenerationDefaults
    validation: ImageGenerationValidation
class Config(BaseModel):
    """Top-level application configuration aggregating every sub-section.

    Instances are built from the parsed config.toml dict (see load_config /
    load_config_async); validation failures raise pydantic errors.
    """

    bot: BotConfig
    api: APIConfig
    image_generation: ImageGenerationConfig
    # One CommandConfig per command name; must contain the required commands
    # checked by validate_structure below.
    commands: Dict[str, CommandConfig]
    ui: UIConfig
    resources: ResourcesConfig
    # Filled in after load via initialize_models(); empty until then.
    MODELS: List[str] = []  # Initialize with empty list as default

    @model_validator(mode="after")
    def validate_structure(self):
        """Reject configs missing any of the commands the bot requires."""
        required_commands = {"pollinate", "multi_pollinate", "random"}
        if not all(cmd in self.commands for cmd in required_commands):
            missing = required_commands - self.commands.keys()
            # Log before raising so the failure is visible in bot logs too.
            logger.error(f"Missing required commands: {missing}")
            raise ValueError(f"Missing required commands: {missing}")
        return self
def resolve_env_variables(data: Any) -> Any:
    """Recursively resolve ``${VAR}`` environment-variable placeholders.

    Dicts and lists are walked recursively; in strings, every ``${VAR}``
    occurrence is replaced with ``os.getenv("VAR")``. Placeholders whose
    variable is unset are kept verbatim (with a warning). Any other type is
    returned unchanged.

    Args:
        data: Arbitrary config value (dict / list / str / scalar).

    Returns:
        The same structure with placeholders substituted.
    """
    if isinstance(data, dict):
        return {key: resolve_env_variables(value) for key, value in data.items()}
    if isinstance(data, list):
        return [resolve_env_variables(item) for item in data]
    if not isinstance(data, str):
        return data

    def _substitute(match: "re.Match[str]") -> str:
        name = match.group(1)
        env_value = os.getenv(name)
        if env_value is None:
            logger.warning(
                f"Environment variable '{name}' not found, keeping original value"
            )
            return match.group(0)  # keep the literal ${NAME} text
        return env_value

    # Single-pass substitution. The previous findall()+str.replace() approach
    # re-scanned already-substituted text, so an env value containing
    # "${OTHER}" could be expanded a second time (placeholder injection).
    # re.sub with a callback never re-examines replacement text.
    return re.sub(r"\$\{([^}]+)\}", _substitute, data)
async def load_config_async(path: str = "config.toml") -> Config:
    """Load and validate config from TOML file asynchronously.

    File reading and TOML parsing run in a worker thread so the event loop
    is never blocked by disk I/O.

    Args:
        path: Path to the TOML config file.

    Returns:
        A validated Config instance.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")

    def _read_config() -> dict:
        # tomli requires the file opened in binary mode.
        with open(config_path, "rb") as f:
            return tomli.load(f)

    # asyncio.get_event_loop() inside a coroutine is deprecated since 3.10;
    # asyncio.to_thread() submits to the running loop's default executor.
    config_data = await asyncio.to_thread(_read_config)

    # Resolve environment variables before validation.
    config_data = resolve_env_variables(config_data)
    return Config(**config_data)
def load_config(path: str = "config.toml") -> Config:
    """Load and validate config from TOML file (synchronous fallback for startup)"""
    config_file = Path(path)
    if not config_file.exists():
        raise FileNotFoundError(f"Config file not found at {config_file}")
    with config_file.open("rb") as fh:  # tomli requires binary mode
        raw_data = tomli.load(fh)
    # Substitute ${VAR} placeholders, then validate via pydantic.
    return Config(**resolve_env_variables(raw_data))
async def initialize_models_async(config_instance: Config) -> List[str]:
    """Fetch the available image-model names from the API (async).

    Args:
        config_instance: Loaded Config providing the endpoint and fallback.

    Returns:
        The list of model names on success, otherwise a one-element list
        containing the configured fallback model. Never raises.
    """
    try:
        # Bound the request so a stalled API cannot hang startup forever.
        timeout = aiohttp.ClientTimeout(
            total=config_instance.api.max_timeout_seconds
        )
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(
                config_instance.api.models_list_endpoint
            ) as response:
                if response.ok:
                    models = await response.json()
                    # Entries may be plain names or {"name": ...} objects.
                    return [
                        model['name'] if isinstance(model, dict) else model
                        for model in models
                    ]
                # Previously a non-ok response fell through silently; log it.
                logger.error(
                    f"Error pre-initializing models: HTTP {response.status}"
                )
    except Exception as e:
        logger.error(f"Error pre-initializing models: {e}")
    return [config_instance.image_generation.fallback_model]
def initialize_models(config_instance: Config) -> List[str]:
    """Fetch the available image-model names from the API (sync, import-time).

    Synchronous twin of initialize_models_async(), used while the event loop
    is not yet running.

    Args:
        config_instance: Loaded Config providing the endpoint and fallback.

    Returns:
        The list of model names on success, otherwise a one-element list
        containing the configured fallback model. Never raises.
    """
    import requests

    try:
        # requests has NO default timeout; without one a stalled API would
        # block module import indefinitely.
        response = requests.get(
            config_instance.api.models_list_endpoint,
            timeout=config_instance.api.max_timeout_seconds,
        )
        if response.ok:
            models = response.json()
            # Entries may be plain names or {"name": ...} objects.
            return [
                model['name'] if isinstance(model, dict) else model
                for model in models
            ]
        logger.error(f"Error pre-initializing models: HTTP {response.status_code}")
    except Exception as e:
        # Use logger (not print-to-stderr) for consistency with the async twin.
        logger.error(f"Error pre-initializing models: {e}")
    return [config_instance.image_generation.fallback_model]
# Load config on import
try:
    config: Config = load_config()
    # Pre-initialize models list (will be replaced with async version during bot startup)
    config.MODELS = initialize_models(config)
except Exception as e:
    # Fail fast: any import of this module without a valid config aborts startup.
    raise RuntimeError(f"Failed to load config: {e}") from e