mirror of
https://github.com/jianchang512/ChatTTS-ui.git
synced 2026-01-12 00:06:33 +08:00
407 lines
12 KiB
Python
407 lines
12 KiB
Python
import os
|
||
import re
|
||
import sys
|
||
|
||
if sys.platform == "darwin":
|
||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||
import io
|
||
import json
|
||
import torchaudio
|
||
import wave
|
||
from pathlib import Path
|
||
|
||
print("Starting...")
|
||
import shutil
|
||
import time
|
||
|
||
import torch
|
||
import torch._dynamo
|
||
|
||
torch._dynamo.config.suppress_errors = True
|
||
torch._dynamo.config.cache_size_limit = 64
|
||
torch._dynamo.config.suppress_errors = True
|
||
torch.set_float32_matmul_precision("high")
|
||
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
||
import subprocess
|
||
import soundfile as sf
|
||
import ChatTTS
|
||
import datetime
|
||
from dotenv import load_dotenv
|
||
|
||
load_dotenv()
|
||
from flask import (
|
||
Flask,
|
||
request,
|
||
render_template,
|
||
jsonify,
|
||
send_from_directory,
|
||
send_file,
|
||
Response,
|
||
stream_with_context,
|
||
)
|
||
import logging
|
||
from logging.handlers import RotatingFileHandler
|
||
from waitress import serve
|
||
from random import random
|
||
from modelscope import snapshot_download
|
||
import numpy as np
|
||
import threading
|
||
from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR
|
||
from uilib import utils, VERSION
|
||
from ChatTTS.utils import select_device
|
||
from uilib.utils import is_chinese_os, modelscope_status
|
||
|
||
merge_size = int(os.getenv("merge_size", 10))
|
||
env_lang = os.getenv("lang", "")
|
||
if env_lang == "zh":
|
||
is_cn = True
|
||
elif env_lang == "en":
|
||
is_cn = False
|
||
else:
|
||
is_cn = is_chinese_os()
|
||
|
||
if not shutil.which("ffmpeg"):
|
||
print("请先安装ffmpeg")
|
||
time.sleep(60)
|
||
exit()
|
||
|
||
|
||
chat = ChatTTS.Chat()
|
||
device_str = os.getenv("device", "default")
|
||
|
||
if device_str in ["default", "mps"]:
|
||
device = select_device(min_memory=2047, experimental=True if device_str == "mps" else False)
|
||
elif device_str == "cuda":
|
||
device = select_device(min_memory=2047)
|
||
elif device_str == "cpu":
|
||
device = torch.device("cpu")
|
||
|
||
|
||
chat.load(
|
||
source="local" if not os.path.exists(MODEL_DIR + "/DVAE_full.pt") else "custom",
|
||
custom_path=ROOT_DIR,
|
||
device=device,
|
||
compile=True if os.getenv("compile", "true").lower() != "false" else False,
|
||
)
|
||
|
||
|
||
# 配置日志
|
||
# 禁用 Werkzeug 默认的日志处理器
|
||
log = logging.getLogger("werkzeug")
|
||
log.handlers[:] = []
|
||
log.setLevel(logging.WARNING)
|
||
|
||
app = Flask(
|
||
__name__,
|
||
static_folder=ROOT_DIR + "/static",
|
||
static_url_path="/static",
|
||
template_folder=ROOT_DIR + "/templates",
|
||
)
|
||
|
||
root_log = logging.getLogger() # Flask的根日志记录器
|
||
root_log.handlers = []
|
||
root_log.setLevel(logging.WARNING)
|
||
app.logger.setLevel(logging.WARNING)
|
||
# 创建 RotatingFileHandler 对象,设置写入的文件路径和大小限制
|
||
file_handler = RotatingFileHandler(
|
||
LOGS_DIR + f'/{datetime.datetime.now().strftime("%Y%m%d")}.log',
|
||
maxBytes=1024 * 1024,
|
||
backupCount=5,
|
||
)
|
||
# 创建日志的格式
|
||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||
# 设置文件处理器的级别和格式
|
||
file_handler.setLevel(logging.WARNING)
|
||
file_handler.setFormatter(formatter)
|
||
# 将文件处理器添加到日志记录器中
|
||
app.logger.addHandler(file_handler)
|
||
app.jinja_env.globals.update(enumerate=enumerate)
|
||
|
||
|
||
@app.route("/static/<path:filename>")
|
||
def static_files(filename):
|
||
return send_from_directory(app.config["STATIC_FOLDER"], filename)
|
||
|
||
|
||
@app.route("/")
|
||
def index():
|
||
speakers = utils.get_speakers()
|
||
return render_template(
|
||
f"index{'' if is_cn else 'en'}.html", weburl=WEB_ADDRESS, speakers=speakers, version=VERSION
|
||
)
|
||
|
||
|
||
# 根据文本返回tts结果,返回 filename=文件名 url=可下载地址
|
||
# 请求端根据需要自行选择使用哪个
|
||
# params:
|
||
#
|
||
# text:待合成文字
|
||
# prompt:
|
||
# voice:音色
|
||
# custom_voice:自定义音色值
|
||
# skip_refine: 1=跳过refine_text阶段,0=不跳过
|
||
# temperature
|
||
# top_p
|
||
# top_k
|
||
# speed
|
||
# text_seed
|
||
# refine_max_new_token
|
||
# infer_max_new_token
|
||
# wav
|
||
|
||
audio_queue = []
|
||
|
||
|
||
@app.route("/tts", methods=["GET", "POST"])
|
||
def tts():
|
||
global audio_queue
|
||
# 原始字符串
|
||
text = request.args.get("text", "").strip() or request.form.get("text", "").strip()
|
||
prompt = request.args.get("prompt", "").strip() or request.form.get("prompt", "")
|
||
|
||
# 默认值
|
||
defaults = {
|
||
"custom_voice": 0,
|
||
"voice": "2222",
|
||
"temperature": 0.3,
|
||
"top_p": 0.7,
|
||
"top_k": 20,
|
||
"skip_refine": 0,
|
||
"speed": 5,
|
||
"text_seed": 42,
|
||
"refine_max_new_token": 384,
|
||
"infer_max_new_token": 2048,
|
||
"wav": 0,
|
||
"is_stream": 0,
|
||
}
|
||
|
||
# 获取
|
||
custom_voice = utils.get_parameter(request, "custom_voice", defaults["custom_voice"], int)
|
||
voice = (
|
||
str(custom_voice)
|
||
if custom_voice > 0
|
||
else utils.get_parameter(request, "voice", defaults["voice"], str)
|
||
)
|
||
temperature = utils.get_parameter(request, "temperature", defaults["temperature"], float)
|
||
top_p = utils.get_parameter(request, "top_p", defaults["top_p"], float)
|
||
top_k = utils.get_parameter(request, "top_k", defaults["top_k"], int)
|
||
skip_refine = utils.get_parameter(request, "skip_refine", defaults["skip_refine"], int)
|
||
is_stream = utils.get_parameter(request, "is_stream", defaults["is_stream"], int)
|
||
speed = utils.get_parameter(request, "speed", defaults["speed"], int)
|
||
text_seed = utils.get_parameter(request, "text_seed", defaults["text_seed"], int)
|
||
refine_max_new_token = utils.get_parameter(
|
||
request, "refine_max_new_token", defaults["refine_max_new_token"], int
|
||
)
|
||
infer_max_new_token = utils.get_parameter(
|
||
request, "infer_max_new_token", defaults["infer_max_new_token"], int
|
||
)
|
||
wav = utils.get_parameter(request, "wav", defaults["wav"], int)
|
||
|
||
app.logger.info(f"[tts]{text=}\n{voice=},{skip_refine=}\n")
|
||
if not text:
|
||
return jsonify({"code": 1, "msg": "text params lost"})
|
||
# 固定音色
|
||
rand_spk = None
|
||
# voice可能是 {voice}.csv or {voice}.pt or number
|
||
voice = voice.replace(".csv", ".pt")
|
||
seed_path = f"{SPEAKER_DIR}/{voice}"
|
||
print(f"{voice=}")
|
||
# if voice.endswith('.csv') and os.path.exists(seed_path):
|
||
# rand_spk=utils.load_speaker(voice)
|
||
# print(f'当前使用音色 {seed_path=}')
|
||
# el
|
||
|
||
if voice.endswith(".pt") and os.path.exists(seed_path):
|
||
# 如果.env中未指定设备,则使用 ChatTTS相同算法找设备,否则使用指定设备
|
||
rand_spk = torch.load(seed_path, map_location=device)
|
||
print(f"当前使用音色 {seed_path=}")
|
||
# 否则 判断是否存在 {voice}.csv
|
||
# elif os.path.exists(f'{SPEAKER_DIR}/{voice}.csv'):
|
||
# rand_spk=utils.load_speaker(voice)
|
||
# print(f'当前使用音色 {SPEAKER_DIR}/{voice}.csv')
|
||
|
||
if rand_spk is None:
|
||
print(f"当前使用音色:根据seed={voice}获取随机音色")
|
||
voice_int = re.findall(r"^(\d+)", voice)
|
||
if len(voice_int) > 0:
|
||
voice = int(voice_int[0])
|
||
else:
|
||
voice = 2222
|
||
torch.manual_seed(voice)
|
||
# std, mean = chat.sample_random_speaker
|
||
rand_spk = chat.sample_random_speaker()
|
||
# rand_spk = torch.randn(768) * std + mean
|
||
# 保存音色
|
||
torch.save(rand_spk, f"{SPEAKER_DIR}/{voice}.pt")
|
||
# utils.save_speaker(voice,rand_spk)
|
||
|
||
audio_files = []
|
||
|
||
start_time = time.time()
|
||
|
||
# 中英按语言分行
|
||
text_list = [t.strip() for t in text.split("\n") if t.strip()]
|
||
new_text = utils.split_text(text_list)
|
||
if text_seed > 0:
|
||
torch.manual_seed(text_seed)
|
||
|
||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||
spk_emb=rand_spk,
|
||
prompt=f"[speed_{speed}]",
|
||
top_P=top_p,
|
||
top_K=top_k,
|
||
temperature=temperature,
|
||
max_new_token=infer_max_new_token,
|
||
)
|
||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||
prompt=prompt,
|
||
top_P=top_p,
|
||
top_K=top_k,
|
||
temperature=temperature,
|
||
max_new_token=refine_max_new_token,
|
||
)
|
||
print(f"{prompt=}")
|
||
# 将少于30个字符的行同其他行拼接
|
||
retext = []
|
||
short_text = ""
|
||
for it in new_text:
|
||
if len(it) < 30:
|
||
short_text += f"{it} [uv_break] "
|
||
if len(short_text) > 30:
|
||
retext.append(short_text)
|
||
short_text = ""
|
||
else:
|
||
retext.append(short_text + it)
|
||
short_text = ""
|
||
if len(short_text) > 30 or len(retext) < 1:
|
||
retext.append(short_text)
|
||
elif short_text:
|
||
retext[-1] += f" [uv_break] {short_text}"
|
||
|
||
new_text = retext
|
||
|
||
new_text_list = [new_text[i : i + merge_size] for i in range(0, len(new_text), merge_size)]
|
||
filename_list = []
|
||
|
||
audio_time = 0
|
||
inter_time = 0
|
||
|
||
for i, te in enumerate(new_text_list):
|
||
print(f"{te=}")
|
||
wavs = chat.infer(
|
||
te,
|
||
# use_decoder=False,
|
||
stream=True if is_stream == 1 else False,
|
||
skip_refine_text=skip_refine,
|
||
do_text_normalization=False,
|
||
do_homophone_replacement=True,
|
||
params_refine_text=params_refine_text,
|
||
params_infer_code=params_infer_code,
|
||
)
|
||
|
||
end_time = time.time()
|
||
inference_time = end_time - start_time
|
||
inference_time_rounded = round(inference_time, 2)
|
||
inter_time += inference_time_rounded
|
||
print(f"推理时长: {inference_time_rounded} 秒")
|
||
|
||
for j, w in enumerate(wavs):
|
||
filename = (
|
||
datetime.datetime.now().strftime("%H%M%S_")
|
||
+ f"use{inference_time_rounded}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}"
|
||
+ f"-{i}-{j}.wav"
|
||
)
|
||
filename_list.append(filename)
|
||
torchaudio.save(WAVS_DIR + "/" + filename, torch.from_numpy(w).unsqueeze(0), 24000)
|
||
|
||
txt_tmp = "\n".join([f"file '{WAVS_DIR}/{it}'" for it in filename_list])
|
||
txt_name = f"{time.time()}.txt"
|
||
with open(f"{WAVS_DIR}/{txt_name}", "w", encoding="utf-8") as f:
|
||
f.write(txt_tmp)
|
||
outname = (
|
||
datetime.datetime.now().strftime("%H%M%S_")
|
||
+ f"use{inter_time}s-audio{audio_time}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}"
|
||
+ "-merge.wav"
|
||
)
|
||
try:
|
||
subprocess.run(
|
||
[
|
||
"ffmpeg",
|
||
"-hide_banner",
|
||
"-ignore_unknown",
|
||
"-y",
|
||
"-f",
|
||
"concat",
|
||
"-safe",
|
||
"0",
|
||
"-i",
|
||
f"{WAVS_DIR}/{txt_name}",
|
||
"-c:a",
|
||
"copy",
|
||
WAVS_DIR + "/" + outname,
|
||
],
|
||
stdout=subprocess.PIPE,
|
||
stderr=subprocess.PIPE,
|
||
encoding="utf-8",
|
||
check=True,
|
||
text=True,
|
||
creationflags=0 if sys.platform != "win32" else subprocess.CREATE_NO_WINDOW,
|
||
)
|
||
except Exception as e:
|
||
return jsonify({"code": 1, "msg": str(e)})
|
||
|
||
audio_path = WAVS_DIR + "/" + outname
|
||
try:
|
||
# 使用 soundfile
|
||
audio_info = sf.info(audio_path)
|
||
audio_duration = round(audio_info.duration, 2)
|
||
except Exception as e:
|
||
print(f"计算音频时长失败: {e}")
|
||
audio_duration = -1
|
||
|
||
relative_url = f"/static/wavs/{outname}"
|
||
audio_files.append(
|
||
{
|
||
"filename": audio_path,
|
||
"url": f"http://{request.host}{relative_url}",
|
||
"relative_url": relative_url,
|
||
"inference_time": round(inter_time, 2),
|
||
"audio_duration": audio_duration,
|
||
}
|
||
)
|
||
result_dict = {"code": 0, "msg": "ok", "audio_files": audio_files}
|
||
try:
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
except Exception:
|
||
pass
|
||
# 兼容pyVideoTrans接口调用
|
||
if len(audio_files) == 1:
|
||
result_dict["filename"] = audio_files[0]["filename"]
|
||
result_dict["url"] = audio_files[0]["url"]
|
||
result_dict["relative_url"] = audio_files[0]["relative_url"]
|
||
|
||
if wav > 0:
|
||
return send_file(audio_files[0]["filename"], mimetype="audio/x-wav")
|
||
else:
|
||
return jsonify(result_dict)
|
||
|
||
|
||
@app.route("/clear_wavs", methods=["POST"])
|
||
def clear_wavs():
|
||
dir_path = "static/wavs" # wav音频文件存储目录
|
||
success, message = utils.ClearWav(dir_path)
|
||
if success:
|
||
return jsonify({"code": 0, "msg": message})
|
||
else:
|
||
return jsonify({"code": 1, "msg": message})
|
||
|
||
|
||
try:
|
||
host = WEB_ADDRESS.split(":")
|
||
print(f"Start:{WEB_ADDRESS}")
|
||
threading.Thread(target=utils.openweb, args=(f"http://{WEB_ADDRESS}",)).start()
|
||
serve(app, host=host[0], port=int(host[1]))
|
||
except Exception as e:
|
||
print(e)
|