import os import re import sys if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import io import json import torchaudio import wave from pathlib import Path print("Starting...") import shutil import time import torch import torch._dynamo torch._dynamo.config.suppress_errors = True torch._dynamo.config.cache_size_limit = 64 torch._dynamo.config.suppress_errors = True torch.set_float32_matmul_precision("high") os.environ["KMP_DUPLICATE_LIB_OK"] = "True" import subprocess import soundfile as sf import ChatTTS import datetime from dotenv import load_dotenv load_dotenv() from flask import ( Flask, request, render_template, jsonify, send_from_directory, send_file, Response, stream_with_context, ) import logging from logging.handlers import RotatingFileHandler from waitress import serve from random import random from modelscope import snapshot_download import numpy as np import threading from uilib.cfg import WEB_ADDRESS, SPEAKER_DIR, LOGS_DIR, WAVS_DIR, MODEL_DIR, ROOT_DIR from uilib import utils, VERSION from ChatTTS.utils import select_device from uilib.utils import is_chinese_os, modelscope_status merge_size = int(os.getenv("merge_size", 10)) env_lang = os.getenv("lang", "") if env_lang == "zh": is_cn = True elif env_lang == "en": is_cn = False else: is_cn = is_chinese_os() if not shutil.which("ffmpeg"): print("请先安装ffmpeg") time.sleep(60) exit() chat = ChatTTS.Chat() device_str = os.getenv("device", "default") if device_str in ["default", "mps"]: device = select_device(min_memory=2047, experimental=True if device_str == "mps" else False) elif device_str == "cuda": device = select_device(min_memory=2047) elif device_str == "cpu": device = torch.device("cpu") chat.load( source="local" if not os.path.exists(MODEL_DIR + "/DVAE_full.pt") else "custom", custom_path=ROOT_DIR, device=device, compile=True if os.getenv("compile", "true").lower() != "false" else False, ) # 配置日志 # 禁用 Werkzeug 默认的日志处理器 log = logging.getLogger("werkzeug") log.handlers[:] = [] log.setLevel(logging.WARNING) app = Flask( __name__, static_folder=ROOT_DIR + "/static", static_url_path="/static", template_folder=ROOT_DIR + "/templates", ) root_log = logging.getLogger() # Flask的根日志记录器 root_log.handlers = [] root_log.setLevel(logging.WARNING) app.logger.setLevel(logging.WARNING) # 创建 RotatingFileHandler 对象,设置写入的文件路径和大小限制 file_handler = RotatingFileHandler( LOGS_DIR + f'/{datetime.datetime.now().strftime("%Y%m%d")}.log', maxBytes=1024 * 1024, backupCount=5, ) # 创建日志的格式 formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") # 设置文件处理器的级别和格式 file_handler.setLevel(logging.WARNING) file_handler.setFormatter(formatter) # 将文件处理器添加到日志记录器中 app.logger.addHandler(file_handler) app.jinja_env.globals.update(enumerate=enumerate) @app.route("/static/") def static_files(filename): return send_from_directory(app.config["STATIC_FOLDER"], filename) @app.route("/") def index(): speakers = utils.get_speakers() return render_template( f"index{'' if is_cn else 'en'}.html", weburl=WEB_ADDRESS, speakers=speakers, version=VERSION ) # 根据文本返回tts结果,返回 filename=文件名 url=可下载地址 # 请求端根据需要自行选择使用哪个 # params: # # text:待合成文字 # prompt: # voice:音色 # custom_voice:自定义音色值 # skip_refine: 1=跳过refine_text阶段,0=不跳过 # temperature # top_p # top_k # speed # text_seed # refine_max_new_token # infer_max_new_token # wav audio_queue = [] @app.route("/tts", methods=["GET", "POST"]) def tts(): global audio_queue # 原始字符串 text = request.args.get("text", "").strip() or request.form.get("text", "").strip() prompt = request.args.get("prompt", "").strip() or request.form.get("prompt", "") # 默认值 defaults = { "custom_voice": 0, "voice": "2222", "temperature": 0.3, "top_p": 0.7, "top_k": 20, "skip_refine": 0, "speed": 5, "text_seed": 42, "refine_max_new_token": 384, "infer_max_new_token": 2048, "wav": 0, "is_stream": 0, } # 获取 custom_voice = utils.get_parameter(request, "custom_voice", defaults["custom_voice"], int) voice = ( str(custom_voice) if custom_voice > 0 else utils.get_parameter(request, "voice", defaults["voice"], str) ) temperature = utils.get_parameter(request, "temperature", defaults["temperature"], float) top_p = utils.get_parameter(request, "top_p", defaults["top_p"], float) top_k = utils.get_parameter(request, "top_k", defaults["top_k"], int) skip_refine = utils.get_parameter(request, "skip_refine", defaults["skip_refine"], int) is_stream = utils.get_parameter(request, "is_stream", defaults["is_stream"], int) speed = utils.get_parameter(request, "speed", defaults["speed"], int) text_seed = utils.get_parameter(request, "text_seed", defaults["text_seed"], int) refine_max_new_token = utils.get_parameter( request, "refine_max_new_token", defaults["refine_max_new_token"], int ) infer_max_new_token = utils.get_parameter( request, "infer_max_new_token", defaults["infer_max_new_token"], int ) wav = utils.get_parameter(request, "wav", defaults["wav"], int) app.logger.info(f"[tts]{text=}\n{voice=},{skip_refine=}\n") if not text: return jsonify({"code": 1, "msg": "text params lost"}) # 固定音色 rand_spk = None # voice可能是 {voice}.csv or {voice}.pt or number voice = voice.replace(".csv", ".pt") seed_path = f"{SPEAKER_DIR}/{voice}" print(f"{voice=}") # if voice.endswith('.csv') and os.path.exists(seed_path): # rand_spk=utils.load_speaker(voice) # print(f'当前使用音色 {seed_path=}') # el if voice.endswith(".pt") and os.path.exists(seed_path): # 如果.env中未指定设备,则使用 ChatTTS相同算法找设备,否则使用指定设备 rand_spk = torch.load(seed_path, map_location=device) print(f"当前使用音色 {seed_path=}") # 否则 判断是否存在 {voice}.csv # elif os.path.exists(f'{SPEAKER_DIR}/{voice}.csv'): # rand_spk=utils.load_speaker(voice) # print(f'当前使用音色 {SPEAKER_DIR}/{voice}.csv') if rand_spk is None: print(f"当前使用音色:根据seed={voice}获取随机音色") voice_int = re.findall(r"^(\d+)", voice) if len(voice_int) > 0: voice = int(voice_int[0]) else: voice = 2222 torch.manual_seed(voice) # std, mean = chat.sample_random_speaker rand_spk = chat.sample_random_speaker() # rand_spk = torch.randn(768) * std + mean # 保存音色 torch.save(rand_spk, f"{SPEAKER_DIR}/{voice}.pt") # utils.save_speaker(voice,rand_spk) audio_files = [] start_time = time.time() # 中英按语言分行 text_list = [t.strip() for t in text.split("\n") if t.strip()] new_text = utils.split_text(text_list) if text_seed > 0: torch.manual_seed(text_seed) params_infer_code = ChatTTS.Chat.InferCodeParams( spk_emb=rand_spk, prompt=f"[speed_{speed}]", top_P=top_p, top_K=top_k, temperature=temperature, max_new_token=infer_max_new_token, ) params_refine_text = ChatTTS.Chat.RefineTextParams( prompt=prompt, top_P=top_p, top_K=top_k, temperature=temperature, max_new_token=refine_max_new_token, ) print(f"{prompt=}") # 将少于30个字符的行同其他行拼接 retext = [] short_text = "" for it in new_text: if len(it) < 30: short_text += f"{it} [uv_break] " if len(short_text) > 30: retext.append(short_text) short_text = "" else: retext.append(short_text + it) short_text = "" if len(short_text) > 30 or len(retext) < 1: retext.append(short_text) elif short_text: retext[-1] += f" [uv_break] {short_text}" new_text = retext new_text_list = [new_text[i : i + merge_size] for i in range(0, len(new_text), merge_size)] filename_list = [] audio_time = 0 inter_time = 0 for i, te in enumerate(new_text_list): print(f"{te=}") wavs = chat.infer( te, # use_decoder=False, stream=True if is_stream == 1 else False, skip_refine_text=skip_refine, do_text_normalization=False, do_homophone_replacement=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code, ) end_time = time.time() inference_time = end_time - start_time inference_time_rounded = round(inference_time, 2) inter_time += inference_time_rounded print(f"推理时长: {inference_time_rounded} 秒") for j, w in enumerate(wavs): filename = ( datetime.datetime.now().strftime("%H%M%S_") + f"use{inference_time_rounded}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}" + f"-{i}-{j}.wav" ) filename_list.append(filename) torchaudio.save(WAVS_DIR + "/" + filename, torch.from_numpy(w).unsqueeze(0), 24000) txt_tmp = "\n".join([f"file '{WAVS_DIR}/{it}'" for it in filename_list]) txt_name = f"{time.time()}.txt" with open(f"{WAVS_DIR}/{txt_name}", "w", encoding="utf-8") as f: f.write(txt_tmp) outname = ( datetime.datetime.now().strftime("%H%M%S_") + f"use{inter_time}s-audio{audio_time}s-seed{voice}-te{temperature}-tp{top_p}-tk{top_k}-textlen{len(text)}-{str(random())[2:7]}" + "-merge.wav" ) try: subprocess.run( [ "ffmpeg", "-hide_banner", "-ignore_unknown", "-y", "-f", "concat", "-safe", "0", "-i", f"{WAVS_DIR}/{txt_name}", "-c:a", "copy", WAVS_DIR + "/" + outname, ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8", check=True, text=True, creationflags=0 if sys.platform != "win32" else subprocess.CREATE_NO_WINDOW, ) except Exception as e: return jsonify({"code": 1, "msg": str(e)}) audio_path = WAVS_DIR + "/" + outname try: # 使用 soundfile audio_info = sf.info(audio_path) audio_duration = round(audio_info.duration, 2) except Exception as e: print(f"计算音频时长失败: {e}") audio_duration = -1 relative_url = f"/static/wavs/{outname}" audio_files.append( { "filename": audio_path, "url": f"http://{request.host}{relative_url}", "relative_url": relative_url, "inference_time": round(inter_time, 2), "audio_duration": audio_duration, } ) result_dict = {"code": 0, "msg": "ok", "audio_files": audio_files} try: if torch.cuda.is_available(): torch.cuda.empty_cache() except Exception: pass # 兼容pyVideoTrans接口调用 if len(audio_files) == 1: result_dict["filename"] = audio_files[0]["filename"] result_dict["url"] = audio_files[0]["url"] result_dict["relative_url"] = audio_files[0]["relative_url"] if wav > 0: return send_file(audio_files[0]["filename"], mimetype="audio/x-wav") else: return jsonify(result_dict) @app.route("/clear_wavs", methods=["POST"]) def clear_wavs(): dir_path = "static/wavs" # wav音频文件存储目录 success, message = utils.ClearWav(dir_path) if success: return jsonify({"code": 0, "msg": message}) else: return jsonify({"code": 1, "msg": message}) try: host = WEB_ADDRESS.split(":") print(f"Start:{WEB_ADDRESS}") threading.Thread(target=utils.openweb, args=(f"http://{WEB_ADDRESS}",)).start() serve(app, host=host[0], port=int(host[1])) except Exception as e: print(e)