Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ LocalAI
models/*
test-models/
test-dir/
tests/e2e-aio/backends
tests/e2e-aio/models

release/

Expand Down
5 changes: 3 additions & 2 deletions backend/go/whisper/gowhisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
segments := []*pb.TranscriptSegment{}
text := ""
for i := range int(segsLen) {
s := CppGetSegmentStart(i)
t := CppGetSegmentEnd(i)
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
s := CppGetSegmentStart(i) * (10000000)
t := CppGetSegmentEnd(i) * (10000000)
txt := strings.Clone(CppGetSegmentText(i))
tokens := make([]int32, CppNTokens(i))

Expand Down
7 changes: 4 additions & 3 deletions backend/python/faster-whisper/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def LoadModel(self, request, context):
device = "mps"
try:
print("Preparing models, please wait", file=sys.stderr)
self.model = WhisperModel(request.Model, device=device, compute_type="float16")
self.model = WhisperModel(request.Model, device=device, compute_type="default")
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
Expand All @@ -55,11 +55,12 @@ def AudioTranscription(self, request, context):
id = 0
for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=segment.start, end=segment.end, text=segment.text))
resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start)*1e9, end=int(segment.end)*1e9, text=segment.text))
text += segment.text
id += 1
id += 1
except Exception as err:
print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
raise err

return backend_pb2.TranscriptResult(segments=resultSegments, text=text)

Expand Down
3 changes: 1 addition & 2 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)

func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {

func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
}
Expand Down
58 changes: 47 additions & 11 deletions core/cli/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,42 @@ package cli

import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"

"github.com/mudler/LocalAI/core/backend"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/format"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/system"
"github.com/mudler/xlog"
)

type TranscriptCMD struct {
Filename string `arg:""`
Filename string `arg:"" name:"file" help:"Audio file to transcribe" type:"path"`

Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"c" help:"Translate the transcription to english"`
Diarize bool `short:"d" help:"Mark speaker turns"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"c" help:"Translate the transcription to English"`
Diarize bool `short:"d" help:"Mark speaker turns"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
ResponseFormat schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, json_verbose)"`
PrettyPrint bool `help:"Used with response_format json or json_verbose for pretty printing"`
}

func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
systemState, err := system.GetSystemState(
system.WithBackendPath(t.BackendsPath),
system.WithModelPath(t.ModelsPath),
)
if err != nil {
Expand All @@ -40,6 +50,11 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {

cl := config.NewModelConfigLoader(t.ModelsPath)
ml := model.NewModelLoader(systemState)

if err := gallery.RegisterBackends(systemState, ml); err != nil {
xlog.Error("error registering external backends", "error", err)
}

if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
return err
}
Expand All @@ -62,8 +77,29 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
if err != nil {
return err
}
for _, segment := range tr.Segments {
fmt.Println(segment.Start.String(), "-", segment.Text)

switch t.ResponseFormat {
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
case schema.TranscriptionResponseFormatJson:
tr.Segments = nil
fallthrough
case schema.TranscriptionResponseFormatJsonVerbose:
var mtr []byte
var err error
if t.PrettyPrint {
mtr, err = json.MarshalIndent(tr, "", " ")
} else {
mtr, err = json.Marshal(tr)
}
if err != nil {
return err
}
fmt.Println(string(mtr))
default:
for _, segment := range tr.Segments {
fmt.Println(segment.Start.String(), "-", strings.TrimSpace(segment.Text))
}
}
return nil
}
17 changes: 15 additions & 2 deletions core/http/endpoints/openai/transcription.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package openai

import (
"errors"
"io"
"net/http"
"os"
Expand All @@ -12,6 +13,7 @@ import (
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/format"
model "github.com/mudler/LocalAI/pkg/model"

"github.com/mudler/xlog"
Expand All @@ -38,6 +40,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app

diarize := c.FormValue("diarize") != "false"
prompt := c.FormValue("prompt")
responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))

// retrieve the file data from the request
file, err := c.FormFile("file")
Expand Down Expand Up @@ -76,7 +79,17 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
}

xlog.Debug("Transcribed", "transcription", tr)
// TODO: handle different outputs here
return c.JSON(http.StatusOK, tr)

switch responseFormat {
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt:
return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat))
case schema.TranscriptionResponseFormatJson:
tr.Segments = nil
fallthrough
case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility
return c.JSON(http.StatusOK, tr)
default:
return errors.New("invalid response_format")
}
}
}
11 changes: 11 additions & 0 deletions core/schema/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ type ImageGenerationResponseFormat string

type ChatCompletionResponseFormatType string

// TranscriptionResponseFormatType identifies the serialization format
// requested for an audio transcription response (e.g. "srt", "vtt", "json").
type TranscriptionResponseFormatType string

// Supported transcription response formats.
const (
	TranscriptionResponseFormatText        TranscriptionResponseFormatType = "txt"
	TranscriptionResponseFormatSrt         TranscriptionResponseFormatType = "srt"
	TranscriptionResponseFormatVtt         TranscriptionResponseFormatType = "vtt"
	TranscriptionResponseFormatLrc         TranscriptionResponseFormatType = "lrc"
	TranscriptionResponseFormatJson        TranscriptionResponseFormatType = "json"
	TranscriptionResponseFormatJsonVerbose TranscriptionResponseFormatType = "json_verbose"
)

type ChatCompletionResponseFormat struct {
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
}
Expand Down
2 changes: 1 addition & 1 deletion core/schema/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ type TranscriptionSegment struct {
}

// TranscriptionResult is the aggregate output of an audio transcription:
// the full transcript text plus optional per-segment details.
type TranscriptionResult struct {
// Segments holds the individual transcript segments; omitted from the
// JSON encoding when empty (callers clear it for the plain "json"
// response format, keeping it only for "json_verbose").
Segments []TranscriptionSegment `json:"segments,omitempty"`
// Text is the complete transcribed text.
Text string `json:"text"`
}
4 changes: 0 additions & 4 deletions core/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ import (
"github.com/mudler/xlog"
)

const (
YAML_EXTENSION = ".yaml"
)

// InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
Expand Down
Loading
Loading