# better app version (commit 18f8e1a, author: diegotg343)
from .config import DEFAULT_CHECKPOINT_DIR, VOCOS_ID, DEVICE
from pitchflower.synthesizer import PitchFlowerSynthesizer
from vocos import Vocos
import torch, numpy as np, torchaudio
_synth = None
_vocos = None
def load_models():
global _synth, _vocos
if _synth is None:
_synth = PitchFlowerSynthesizer.from_pretrained(DEFAULT_CHECKPOINT_DIR).to(DEVICE).eval()
if _vocos is None:
_vocos = Vocos.from_pretrained(VOCOS_ID).to(DEVICE).eval()
return _synth, _vocos
@torch.inference_mode()
def reconstruct_from_logf0(filepath, target_logf0_1d, steps, w_scale, target_sr):
synth, vocos = load_models()
wav, sr = torchaudio.load(filepath)
wav = ensure_mono(wav)
wav = resample_if_needed(wav, sr, target_sr)
wav = clamp_duration(wav, target_sr)
wav_dev = wav.to(DEVICE)
target_b1t = torch.tensor(target_logf0_1d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
mels = synth.reconstruct_mels(wav_dev, target_b1t, steps=int(steps), w=float(w_scale))
out = vocos.decode(mels)[0].detach().cpu().numpy()
out = np.clip(out, -1.0, 1.0)
return out
# lightweight imports from audio_utils to avoid cycle
from .audio_utils import ensure_mono, resample_if_needed, clamp_duration