Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/format.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Check Format

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
types: [ opened, synchronize, reopened, ready_for_review ]

jobs:
check-format:
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: '3.9'

- name: Install Poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
virtualenvs-path: .venv
installer-parallel: true

- name: Load cached venv
id: cached-poetry-deps
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock', '**/pyproject.toml') }}

- name: Install dependencies
if: steps.cached-poetry-deps.outputs.cache-hit != 'true'
run: make setup-dev

- name: Check formatting (black, isort)
run: make check-format
5 changes: 4 additions & 1 deletion check.mk
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ CODE_PATHS = grasp tests tasks
# Define paths for JSON files
JSON_PATHS = $(shell find grasp -name "*.json")

# Minimum acceptable pylint score (0.0 - 10.0). CI will fail if below this.
PYLINT_FAIL_UNDER ?= 8.0

########################################################################################################################
# LINT
########################################################################################################################
Expand All @@ -32,7 +35,7 @@ lint-flake8: ## Run flake8 in a controlled environment
.PHONY: lint-pylint-local
lint-pylint-local: ## Analyze the code with pylint using poetry
@echo "Analyzing code with pylint"
poetry run pylint --jobs 0 $(CODE_PATHS)
poetry run pylint --jobs 0 --fail-under=$(PYLINT_FAIL_UNDER) $(CODE_PATHS)

.PHONY: lint-pylint
lint-pylint: ## Run pylint in a controlled environment
Expand Down
64 changes: 26 additions & 38 deletions grasp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@
"""

import logging
from typing import Union, Dict, Any
from typing import Any, Dict, Union

from .workflow import Workflow, create_graph
from .configuration import ConfigLoader, load_config
from .exceptions import (
GraSPError,
ValidationError,
ExecutionError,
ConfigurationError,
NodeError,
DataError,
ExecutionError,
GraSPError,
ModelError,
NodeError,
TimeoutError,
ValidationError,
)
from .models import ModelConfigBuilder
from .workflow import Workflow, create_graph

try:
from .core.base_task_executor import BaseTaskExecutor, DefaultTaskExecutor
from .core.judge_task_executor import JudgeQualityTaskExecutor
from .core.resumable_execution import ResumableExecutionManager
from .core.dataset.dataset_processor import DatasetProcessor
from .core.graph.graph_config import GraphConfig
from .core.graph.grasp_state import GraspState
from .core.graph.grasp_message import GraspMessage
from .core.graph.grasp_state import GraspState
from .core.judge_task_executor import JudgeQualityTaskExecutor
from .core.resumable_execution import ResumableExecutionManager

CORE_AVAILABLE = True
except ImportError as e:
Expand All @@ -39,11 +39,11 @@
try:
from .core.dataset.dataset_config import (
DataSourceConfig,
OutputConfig,
DataSourceType,
OutputConfig,
OutputType,
TransformConfig,
ShardConfig,
TransformConfig,
)
from .core.dataset.file_handler import FileHandler
from .core.dataset.huggingface_handler import HuggingFaceHandler
Expand All @@ -54,9 +54,9 @@

# Node modules
try:
from .core.graph.nodes.base_node import BaseNode, NodeType, NodeState
from .core.graph.nodes.llm_node import LLMNode as CoreLLMNode
from .core.graph.nodes.agent_node import AgentNode as CoreAgentNode
from .core.graph.nodes.base_node import BaseNode, NodeState, NodeType
from .core.graph.nodes.llm_node import LLMNode as CoreLLMNode
from .core.graph.nodes.multi_llm_node import MultiLLMNode as CoreMultiLLMNode
from .core.graph.nodes.weighted_sampler_node import (
WeightedSamplerNode as CoreWeightedSamplerNode,
Expand All @@ -69,10 +69,10 @@
# Model factory modules
try:
from .core.models.model_factory import ModelFactory
from .core.models.structured_output.schemas_factory import SimpleResponse
from .core.models.structured_output.structured_output_config import (
StructuredOutputConfig,
)
from .core.models.structured_output.schemas_factory import SimpleResponse

MODELS_AVAILABLE = True
except ImportError:
Expand All @@ -81,12 +81,12 @@
# Utility modules
try:
from . import utils
from .utils import constants
from .logger.logger_config import (
logger,
set_external_logger,
reset_to_internal_logger,
set_external_logger,
)
from .utils import constants

UTILS_AVAILABLE = True
except ImportError:
Expand All @@ -95,12 +95,12 @@
# Import node builders
try:
from .nodes import (
LLMNodeBuilder,
AgentNodeBuilder,
MultiLLMNodeBuilder,
LambdaNodeBuilder,
WeightedSamplerNodeBuilder,
LLMNodeBuilder,
MultiLLMNodeBuilder,
SubgraphNodeBuilder,
WeightedSamplerNodeBuilder,
)

NODE_BUILDERS_AVAILABLE = True
Expand All @@ -110,10 +110,10 @@
# Import data utilities
try:
from .data import (
DataSource,
DataSink,
DataSourceFactory,
DataSinkFactory,
DataSource,
DataSourceFactory,
from_file,
from_huggingface,
to_file,
Expand Down Expand Up @@ -161,12 +161,7 @@ def quick_multi_llm(
models: Dict[str, str], prompt: str, data_source: str, output: str = "output.json"
):
"""Quick multi-LLM workflow creation."""
return (
Workflow("quick_multi_llm")
.source(data_source)
.multi_llm(models, prompt)
.sink(output)
)
return Workflow("quick_multi_llm").source(data_source).multi_llm(models, prompt).sink(output)


def execute_task(task_name: str, **kwargs):
Expand Down Expand Up @@ -203,17 +198,14 @@ def create_chat_workflow(name: str, conversation_type: str = "multiturn") -> Wor
return workflow


def create_structured_schema(
fields: Dict[str, str], name: str = "CustomSchema"
) -> Dict[str, Any]:
def create_structured_schema(fields: Dict[str, str], name: str = "CustomSchema") -> Dict[str, Any]:
"""Create structured output schema configuration."""
return {
"enabled": True,
"schema": {
"name": name,
"fields": {
field_name: {"type": field_type}
for field_name, field_type in fields.items()
field_name: {"type": field_type} for field_name, field_type in fields.items()
},
},
}
Expand All @@ -224,9 +216,7 @@ def pydantic_schema(model_class: str) -> Dict[str, Any]:
return {"enabled": True, "schema": model_class}


def create_processor_config(
processor: Union[str, callable], **params
) -> Dict[str, Any]:
def create_processor_config(processor: Union[str, callable], **params) -> Dict[str, Any]:
"""Create processor configuration."""
if callable(processor):
processor_path = f"{processor.__module__}.{processor.__name__}"
Expand All @@ -240,9 +230,7 @@ def create_processor_config(
return config


def create_transformation_config(
transform: Union[str, callable], **params
) -> Dict[str, Any]:
def create_transformation_config(transform: Union[str, callable], **params) -> Dict[str, Any]:
"""Create data transformation configuration."""
if callable(transform):
transform_path = f"{transform.__module__}.{transform.__name__}"
Expand Down
17 changes: 7 additions & 10 deletions grasp/configuration/loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import yaml
from typing import Union, Any
from pathlib import Path
from typing import Any, Union

import yaml

try:
from grasp.utils import utils
from grasp.core.graph.graph_config import GraphConfig
from grasp.core.dataset.dataset_config import DataSourceConfig, OutputConfig
from grasp.core.graph.graph_config import GraphConfig
from grasp.utils import utils

UTILS_AVAILABLE = True
except ImportError:
Expand All @@ -24,15 +25,11 @@ def load(self, config_path: Union[str, Path, dict[str, Any]]) -> dict[str, Any]:
config_path = Path(config_path)
if not config_path.exists():
if UTILS_AVAILABLE:
task_config_path = utils.get_file_in_task_dir(
config_path.stem, "graph_config.yaml"
)
task_config_path = utils.get_file_in_task_dir(config_path.stem, "graph_config.yaml")
if os.path.exists(task_config_path):
config_path = Path(task_config_path)
else:
raise FileNotFoundError(
f"Configuration file not found: {config_path}"
)
raise FileNotFoundError(f"Configuration file not found: {config_path}")
else:
raise FileNotFoundError(f"Configuration file not found: {config_path}")

Expand Down
Loading