Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion cmd/scenario/draft_generation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -42,6 +43,13 @@ func main() {
baseURL := envFirst("http://127.0.0.1:8081/v1", "LLM_BASE_URL", "LLAMACPP_BASE_URL")
model := envFirst("gemma4:31b", "DRAFT_LLM_MODEL", "LLM_MODEL", "LLAMACPP_MODEL")
verifyModel := envFirst("gemma4:latest", "VERIFY_LLM_MODEL", "LLM_MODEL", "LLAMACPP_MODEL")
failureContext := failureAttemptContext{
LLMBaseURL: baseURL,
LLMModel: model,
VerifyModel: verifyModel,
PersonaID: brief.PersonaID,
OutputFormatID: brief.OutputFormatID,
}
minStyleScore := envFloat("SCENARIO_MIN_STYLE_SCORE", 80)
minDraftRunes := envInt("SCENARIO_MIN_DRAFT_RUNES", 2400)
maxAttempts := envInt("DRAFT_MAX_ATTEMPTS", 2)
Expand Down Expand Up @@ -87,13 +95,24 @@ func main() {
}
elapsed := time.Since(started)
cancel()
metrics := attemptRuntimeMetrics{
ElapsedSeconds: elapsed.Seconds(),
TimeoutSeconds: timeout.Seconds(),
Streaming: streamDraft,
FirstChunkMs: finalFirstChunkMs(firstChunk),
Chunks: chunkCount,
}
if err != nil {
fatalf("generate draft attempt %d: %v", attempt, err)
attempts := generationAttemptsFromError(err)
artifacts := writeRawAttemptArtifacts(outputDir, attempt, attempts)
failurePath := writeFailureAttempt(outputDir, attempt, err, metrics, failureContext, artifacts)
fatalf("generate draft attempt %d: %v (failure=%s)", attempt, err, failurePath)
}
finalElapsed = elapsed
finalFirstChunk = firstChunk
finalChunks = chunkCount
finalAttempt = attempt
writeRawAttemptArtifacts(outputDir, attempt, result.Attempts)
writeFile(filepath.Join(outputDir, fmt.Sprintf("draft_attempt_%d.md", attempt)), result.Draft.Markdown()+"\n")
writeJSON(filepath.Join(outputDir, fmt.Sprintf("evaluation_attempt_%d.json", attempt)), result.Evaluation)
writeJSON(filepath.Join(outputDir, fmt.Sprintf("verification_attempt_%d.json", attempt)), result.Verification)
Expand Down Expand Up @@ -139,6 +158,114 @@ func main() {
}
}

// attemptRuntimeMetrics holds the timing and streaming figures recorded for a
// single draft generation attempt. It is embedded in failureAttemptReport and
// serialized as JSON.
type attemptRuntimeMetrics struct {
	// ElapsedSeconds is the elapsed time of the attempt, in seconds.
	ElapsedSeconds float64 `json:"elapsed_seconds"`
	// TimeoutSeconds is the timeout value associated with the attempt, in seconds.
	TimeoutSeconds float64 `json:"timeout_seconds"`
	// Streaming records the streaming flag that was in effect for the attempt.
	Streaming bool `json:"streaming"`
	// FirstChunkMs is the first-chunk latency in milliseconds; a zero value is
	// omitted from the JSON output.
	FirstChunkMs int64 `json:"first_chunk_ms,omitempty"`
	// Chunks is the chunk count for the attempt; a zero value is omitted from
	// the JSON output.
	Chunks int `json:"chunks,omitempty"`
}

// rawAttemptArtifact describes one raw model output that was written to disk
// by writeRawAttemptArtifacts.
type rawAttemptArtifact struct {
	// GenerationAttempt is the generation index used in the artifact file name.
	GenerationAttempt int `json:"generation_attempt"`
	// Kind is the attempt kind exactly as reported by the generation attempt
	// (not the sanitized form used in the file name).
	Kind string `json:"kind"`
	// Path is the location of the file holding the raw output.
	Path string `json:"path"`
	// ValidationError is the validation error recorded for this generation
	// attempt; omitted from the JSON output when empty.
	ValidationError string `json:"validation_error,omitempty"`
}

// failureAttemptContext captures the run configuration that is stored next to
// a failed attempt so the failure can be understood later.
type failureAttemptContext struct {
	// LLMBaseURL is the base URL of the LLM endpoint in use.
	LLMBaseURL string `json:"llm_base_url"`
	// LLMModel is the model name used for draft generation.
	LLMModel string `json:"llm_model"`
	// VerifyModel is the model name used for verification.
	VerifyModel string `json:"verify_model"`
	// PersonaID is the persona identifier taken from the brief.
	PersonaID string `json:"persona_id"`
	// OutputFormatID is the output format identifier taken from the brief.
	OutputFormatID string `json:"output_format_id"`
}

// failureAttemptReport is the JSON document written by writeFailureAttempt for
// a failed draft generation attempt.
type failureAttemptReport struct {
	// Attempt is the scenario attempt number that failed.
	Attempt int `json:"attempt"`
	// Error is the full message of the error that ended the attempt.
	Error string `json:"error"`
	// ValidationError is the message of the underlying validation error when
	// one could be extracted; omitted from the JSON output when empty.
	ValidationError string `json:"validation_error,omitempty"`
	// RuntimeMetrics holds the timing and streaming figures for the attempt.
	RuntimeMetrics attemptRuntimeMetrics `json:"runtime_metrics"`
	// Context holds the run configuration in effect for the attempt.
	Context failureAttemptContext `json:"context"`
	// RawOutputs lists the raw output files saved for the attempt.
	RawOutputs []rawAttemptArtifact `json:"raw_outputs"`
}

// generationAttemptsFromError returns the generation attempts carried by err.
// The attempts are available only when a *draftapp.UnusableDraftError is found
// in err's chain (via errors.As); for any other error the result is nil.
func generationAttemptsFromError(err error) []draftapp.GenerationAttempt {
	var draftErr *draftapp.UnusableDraftError
	if !errors.As(err, &draftErr) {
		return nil
	}
	return draftErr.Attempts
}

// writeRawAttemptArtifacts saves the raw output of each generation attempt
// under outputDir and returns one rawAttemptArtifact per file written.
//
// Attempts whose raw output is blank (empty or whitespace only) are skipped.
// Files are named raw_attempt_<scenarioAttempt>_generation_<index>_<kind>.txt,
// where <kind> is the sanitized attempt kind ("generation" when nothing
// survives sanitizing) and <index> is the attempt's own Index, or, when that
// is not positive, one more than the number of artifacts written so far.
// Each file ends with exactly one trailing newline.
func writeRawAttemptArtifacts(outputDir string, scenarioAttempt int, attempts []draftapp.GenerationAttempt) []rawAttemptArtifact {
	written := make([]rawAttemptArtifact, 0, len(attempts))
	for i := range attempts {
		current := attempts[i]
		if strings.TrimSpace(current.RawOutput) == "" {
			continue
		}

		// The file name uses a sanitized kind; the returned metadata keeps
		// the kind as reported.
		label := sanitizeArtifactPart(current.Kind)
		if label == "" {
			label = "generation"
		}

		generation := current.Index
		if generation <= 0 {
			generation = len(written) + 1
		}

		fileName := fmt.Sprintf("raw_attempt_%d_generation_%d_%s.txt", scenarioAttempt, generation, label)
		target := filepath.Join(outputDir, fileName)
		body := strings.TrimRight(current.RawOutput, "\n") + "\n"
		writeFile(target, body)

		written = append(written, rawAttemptArtifact{
			GenerationAttempt: generation,
			Kind:              current.Kind,
			Path:              target,
			ValidationError:   current.ValidationError,
		})
	}
	return written
}

// writeFailureAttempt writes a failureAttemptReport for the given attempt to
// <outputDir>/failure_attempt_<attempt>.json and returns that path.
//
// The report combines err's message, the validation error extracted from err
// (if any), the supplied runtime metrics and context, and the list of raw
// output artifacts. err must be non-nil because its message is read directly.
func writeFailureAttempt(outputDir string, attempt int, err error, metrics attemptRuntimeMetrics, context failureAttemptContext, artifacts []rawAttemptArtifact) string {
	var report failureAttemptReport
	report.Attempt = attempt
	report.Error = err.Error()
	report.ValidationError = validationErrorFromGenerateError(err)
	report.RuntimeMetrics = metrics
	report.Context = context
	report.RawOutputs = artifacts

	fileName := fmt.Sprintf("failure_attempt_%d.json", attempt)
	reportPath := filepath.Join(outputDir, fileName)
	writeJSON(reportPath, report)
	return reportPath
}

// validationErrorFromGenerateError returns the message of the error wrapped
// inside a *draftapp.UnusableDraftError found in err's chain. It returns the
// empty string when no such error is present or when its Err field is nil.
func validationErrorFromGenerateError(err error) string {
	var draftErr *draftapp.UnusableDraftError
	if !errors.As(err, &draftErr) {
		return ""
	}
	if draftErr.Err == nil {
		return ""
	}
	return draftErr.Err.Error()
}

// finalFirstChunkMs converts firstChunk to whole milliseconds. Zero and
// negative durations yield 0.
func finalFirstChunkMs(firstChunk time.Duration) int64 {
	if firstChunk > 0 {
		return firstChunk.Milliseconds()
	}
	return 0
}

// sanitizeArtifactPart turns value into a fragment that is safe to embed in an
// artifact file name. The value is lower-cased and stripped of surrounding
// whitespace; every rune other than a-z, 0-9, '-' or '_' is replaced with '_';
// finally, leading and trailing underscores are removed. The result may be
// empty.
func sanitizeArtifactPart(value string) string {
	normalized := strings.ToLower(value)
	normalized = strings.TrimSpace(normalized)

	replaced := strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z',
			'0' <= r && r <= '9',
			r == '-',
			r == '_':
			return r
		default:
			return '_'
		}
	}, normalized)

	return strings.Trim(replaced, "_")
}

func readJSON(path string, out any) {
encoded, err := os.ReadFile(path)
if err != nil {
Expand Down
93 changes: 93 additions & 0 deletions cmd/scenario/draft_generation/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package main

import (
"encoding/json"
"errors"
"os"
"path/filepath"
"testing"

draftapp "github.com/teradakousuke/note_maker/internal/application/draft"
)

func TestWriteFailureAttemptPreservesRawOutputsAndRuntimeMetrics(t *testing.T) {
outputDir := t.TempDir()
generateErr := &draftapp.UnusableDraftError{
FormatID: "zenn_article",
Err: errors.New("zenn article must use :::message, not Qiita :::note"),
Attempts: []draftapp.GenerationAttempt{
{
Index: 1,
Kind: "initial",
RawOutput: "---\ntitle: \"T\"\nemoji: \"📝\"\ntype: \"tech\"\ntopics: [\"go\"]\npublished: false\n---\n\n:::note info\nwrong\n:::",
ValidationError: "zenn article must use :::message, not Qiita :::note",
},
{
Index: 2,
Kind: "format_repair",
RawOutput: "---\ntitle: \"T\"\nemoji: \"📝\"\ntype: \"tech\"\ntopics: [\"go\"]\npublished: false\n---\n\n:::note warn\nstill wrong\n:::",
ValidationError: "zenn article must use :::message, not Qiita :::note",
},
},
}
metrics := attemptRuntimeMetrics{
ElapsedSeconds: 1.25,
TimeoutSeconds: 30,
Streaming: true,
FirstChunkMs: 120,
Chunks: 3,
}
context := failureAttemptContext{
LLMBaseURL: "http://evo-x2.tailb30e58.ts.net/v1",
LLMModel: "gemma4:31b",
VerifyModel: "gemma4:latest",
PersonaID: "cloudia",
OutputFormatID: "zenn_article",
}

artifacts := writeRawAttemptArtifacts(outputDir, 2, generationAttemptsFromError(generateErr))
failurePath := writeFailureAttempt(outputDir, 2, generateErr, metrics, context, artifacts)

if len(artifacts) != 2 {
t.Fatalf("artifacts = %#v, want 2", artifacts)
}
for _, artifact := range artifacts {
content, err := os.ReadFile(artifact.Path)
if err != nil {
t.Fatalf("read raw artifact %s: %v", artifact.Path, err)
}
if len(content) == 0 {
t.Fatalf("raw artifact %s was empty", artifact.Path)
}
}
if _, err := os.Stat(filepath.Join(outputDir, "raw_attempt_2_generation_1_initial.txt")); err != nil {
t.Fatalf("missing initial raw artifact: %v", err)
}
if _, err := os.Stat(filepath.Join(outputDir, "raw_attempt_2_generation_2_format_repair.txt")); err != nil {
t.Fatalf("missing repair raw artifact: %v", err)
}

encoded, err := os.ReadFile(failurePath)
if err != nil {
t.Fatalf("read failure artifact: %v", err)
}
var report failureAttemptReport
if err := json.Unmarshal(encoded, &report); err != nil {
t.Fatalf("decode failure artifact: %v", err)
}
if report.Attempt != 2 {
t.Fatalf("attempt = %d, want 2", report.Attempt)
}
if report.ValidationError != "zenn article must use :::message, not Qiita :::note" {
t.Fatalf("validation error = %q", report.ValidationError)
}
if report.RuntimeMetrics.ElapsedSeconds != 1.25 || report.RuntimeMetrics.FirstChunkMs != 120 || report.RuntimeMetrics.Chunks != 3 {
t.Fatalf("runtime metrics not preserved: %#v", report.RuntimeMetrics)
}
if report.Context.LLMBaseURL != context.LLMBaseURL || report.Context.LLMModel != context.LLMModel || report.Context.OutputFormatID != context.OutputFormatID {
t.Fatalf("runtime context not preserved: %#v", report.Context)
}
if len(report.RawOutputs) != 2 || report.RawOutputs[0].ValidationError == "" {
t.Fatalf("raw output metadata not preserved: %#v", report.RawOutputs)
}
}
Loading