-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
149 lines (125 loc) Β· 4.76 KB
/
server.py
File metadata and controls
149 lines (125 loc) Β· 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from model import Net
import torch
import torchaudio
import time
import numpy as np
import json
import os
from utils import glob_audio_files
import gradio as gr
# -----------------------
# Model loading functions
# -----------------------
def load_model(checkpoint_path, config_path):
    """Build a Net from a JSON config and restore its checkpoint weights.

    Reads model hyperparameters from config['model_params'], loads the
    checkpoint's 'model' state dict onto CPU, and switches to eval mode.

    Returns:
        (model, sample_rate): the ready-to-infer model and the sample
        rate it was trained at (config['data']['sr']).
    """
    with open(config_path) as config_file:
        cfg = json.load(config_file)
    net = Net(**cfg['model_params'])
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(checkpoint['model'])
    net.eval()
    return net, cfg['data']['sr']
def load_audio(audio_path, sample_rate):
    """Load an audio file as a mono 1-D tensor at `sample_rate`.

    Args:
        audio_path: path to any format torchaudio can decode.
        sample_rate: target sample rate in Hz.

    Returns:
        1-D float waveform tensor resampled to `sample_rate`.
    """
    audio, sr = torchaudio.load(audio_path)
    audio = audio.mean(0, keepdim=False)  # downmix channels to mono
    # Only resample when the rates actually differ: building and applying
    # a Resample transform when sr == sample_rate is pure wasted work.
    if sr != sample_rate:
        audio = torchaudio.transforms.Resample(sr, sample_rate)(audio)
    return audio
def save_audio(audio, audio_path, sample_rate):
    """Write the waveform tensor `audio` to `audio_path` at `sample_rate` Hz."""
    torchaudio.save(audio_path, audio, sample_rate)
# -----------------------
# Inference functions
# -----------------------
def infer(model, audio):
    """Run one non-streaming forward pass over a 1-D waveform.

    Adds batch and channel dimensions before calling the model, then
    strips the batch dimension from the result.
    """
    batched = audio.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
    result = model(batched)
    return result.squeeze(0)
def infer_stream(model, audio, chunk_factor, sr):
    """Run chunked (streaming) inference and measure timing.

    Splits `audio` into fixed-size chunks, feeds them through the model
    while threading its streaming buffers between calls, and returns:
      - the converted waveform trimmed back to the input length,
      - the real-time factor (chunk duration / avg per-chunk compute time),
      - an end-to-end latency estimate in milliseconds.

    Args:
        model: streaming Net exposing L, dec_chunk_size, lookahead,
            init_buffers(), and optionally convnet_pre.
        audio: 1-D torch.Tensor waveform.
        chunk_factor: int multiplier controlling the chunk length.
        sr: sample rate in Hz, used only for the timing metrics.
    """
    L = model.L
    # One chunk covers dec_chunk_size * L samples, scaled by chunk_factor.
    chunk_len = model.dec_chunk_size * L * chunk_factor
    original_len = len(audio)
    # Zero-pad on the right so the waveform divides evenly into chunks.
    if len(audio) % chunk_len != 0:
        pad_len = chunk_len - (len(audio) % chunk_len)
        audio = torch.nn.functional.pad(audio, (0, pad_len))
    # Shift the signal left by L samples (appending L zeros at the end) —
    # presumably compensating for the model's lookahead delay; confirm
    # against the Net streaming contract.
    audio = torch.cat((audio[L:], torch.zeros(L)))
    audio_chunks = torch.split(audio, chunk_len)
    # Add lookahead context: prepend 2*L samples of left context to each
    # chunk — zeros for the first chunk, the previous chunk's tail after.
    new_audio_chunks = []
    for i, a in enumerate(audio_chunks):
        front_ctx = torch.zeros(L * 2) if i == 0 else audio_chunks[i - 1][-L * 2:]
        new_audio_chunks.append(torch.cat([front_ctx, a]))
    audio_chunks = new_audio_chunks
    outputs = []
    times = []  # wall-clock seconds per chunk forward pass
    with torch.inference_mode():
        # Fresh streaming state: batch size 1, CPU-only buffers.
        enc_buf, dec_buf, out_buf = model.init_buffers(1, torch.device('cpu'))
        convnet_pre_ctx = model.convnet_pre.init_ctx_buf(1, torch.device('cpu')) if hasattr(model, 'convnet_pre') else None
        for chunk in audio_chunks:
            start = time.time()
            # The model returns updated buffers; feed them back in on the
            # next chunk to carry streaming state forward.
            output, enc_buf, dec_buf, out_buf, convnet_pre_ctx = model(
                chunk.unsqueeze(0).unsqueeze(0),
                enc_buf, dec_buf, out_buf,
                convnet_pre_ctx,
                pad=(not model.lookahead)
            )
            outputs.append(output)
            times.append(time.time() - start)
    outputs = torch.cat(outputs, dim=2)  # concatenate along the time axis
    avg_time = np.mean(times)
    # RTF > 1 means inference runs faster than real time.
    rtf = (chunk_len / sr) / avg_time
    # Latency model: algorithmic delay of (2*L + chunk_len) samples plus
    # average compute time, converted to milliseconds.
    e2e_latency = ((2 * L + chunk_len) / sr + avg_time) * 1000
    # Trim the padding back off and drop the batch dimension.
    outputs = outputs[:, :, :original_len].squeeze(0)
    return outputs, rtf, e2e_latency
def do_infer(model, audio, chunk_factor, sr, stream=True):
    """Dispatch to streaming or offline inference with gradients disabled.

    Returns:
        (outputs, rtf, e2e_latency). In offline mode (stream=False) the
        timing metrics are not measured and both are None.
    """
    with torch.no_grad():
        if not stream:
            return infer(model, audio), None, None
        return infer_stream(model, audio, chunk_factor, sr)
# -----------------------
# Gradio app function
# -----------------------
def convert_voice(mic_audio):
    """Gradio callback: convert microphone audio and report timing stats.

    Args:
        mic_audio: the (sample_rate, numpy samples) pair emitted by a
            Gradio numpy audio component.

    Returns:
        ((sr, converted samples), rounded RTF, rounded latency in ms).
    """
    sr, samples = mic_audio
    waveform = torch.from_numpy(samples.astype(np.float32))
    # Peak-normalize input to [-1, 1]; the guard avoids 0/0 on silence.
    in_peak = waveform.abs().max()
    if in_peak > 0:
        waveform = waveform / in_peak
    outputs, rtf, e2e_latency = do_infer(model, waveform, chunk_factor=30, sr=sr, stream=True)
    converted = outputs.squeeze().numpy()
    # Peak-normalize the output as well before handing it back to Gradio.
    out_peak = np.max(np.abs(converted))
    if out_peak > 0:
        converted = converted / out_peak
    converted = converted.astype(np.float32)
    # Audio tuple plus the two timing readouts for the Number components.
    return (sr, converted), round(rtf, 2), round(e2e_latency, 2)
# -----------------------
# Load model once
# -----------------------
# Module-level load so every Gradio request reuses the same model; `sr`
# is the model's training sample rate from the config.
checkpoint_path = "models/llvc/G_500000.pth"
config_path = "experiments/llvc/config.json"
model, sr = load_model(checkpoint_path, config_path)
# UI description markdown. NOTE(review): the emoji characters below look
# mojibake-encoded — likely a copy/paste encoding issue; verify rendering.
description = """
ποΈ **LLVC Real-Time Voice Conversion**
Speak into your microphone, and the model will convert your voice **live** using a lightweight streaming inference pipeline.
**Features:**
- π Real-time voice streaming
- β‘ Efficient chunked inference
- π Displays real-time factor (RTF) and end-to-end latency
*Model automatically normalizes your input audio to prevent clipping.*
"""
# Wire the callback to a mic input and three outputs (audio, RTF, latency).
# live=True re-runs convert_voice as new audio arrives.
# NOTE(review): allow_flagging is deprecated in Gradio 4.x in favor of
# flagging_mode — confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=convert_voice,
    inputs=gr.Audio(sources=["microphone"], type="numpy", label="π€ Input Voice"),
    outputs=[
        gr.Audio(type="numpy", label="π Converted Voice"),
        gr.Number(label="β‘ Real-Time Factor (RTF)"),
        gr.Number(label="β±οΈ End-to-End Latency (ms)")
    ],
    live=True,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="pink"),
    title="π§ LLVC Streaming Voice Conversion",
    description=description,
    allow_flagging="never"
)
# share=True exposes a public tunnel URL in addition to the local server.
iface.launch(share=True)