"""Lazily-loaded PitchFlower + Vocos pipeline for reconstructing audio with a target log-F0 contour."""

import numpy as np
import torch
import torchaudio
from vocos import Vocos

from pitchflower.synthesizer import PitchFlowerSynthesizer

from .config import DEFAULT_CHECKPOINT_DIR, VOCOS_ID, DEVICE

# Module-level model cache, populated on the first call to load_models() so
# each checkpoint is loaded at most once per process.
_synth = None
_vocos = None


def load_models():
    """Return the ``(synth, vocos)`` model pair, loading each on first use.

    Each model is built with ``from_pretrained``, moved to ``DEVICE`` and put
    into eval mode. It is then cached in the module globals ``_synth`` /
    ``_vocos``, so later calls return the same instances without reloading.
    """
    global _synth, _vocos
    if _synth is None:
        _synth = PitchFlowerSynthesizer.from_pretrained(DEFAULT_CHECKPOINT_DIR).to(DEVICE).eval()
    if _vocos is None:
        _vocos = Vocos.from_pretrained(VOCOS_ID).to(DEVICE).eval()
    return _synth, _vocos


@torch.inference_mode()
def reconstruct_from_logf0(filepath, target_logf0_1d, steps, w_scale, target_sr) -> np.ndarray:
    """Re-synthesise the audio in ``filepath`` following a target log-F0 contour.

    The file is loaded with ``torchaudio.load`` and passed through the
    ``audio_utils`` helpers ``ensure_mono``, ``resample_if_needed`` (to
    ``target_sr``) and ``clamp_duration``. Mel frames are then produced with
    ``synth.reconstruct_mels`` and decoded to a waveform with Vocos.

    Args:
        filepath: Audio file path, passed straight to ``torchaudio.load``.
        target_logf0_1d: Target contour (list, NumPy array or tensor),
            converted to a float32 tensor on ``DEVICE`` with a leading batch
            axis. Presumably one log-F0 value per frame; the expected frame
            rate/length is defined by ``reconstruct_mels``. TODO confirm.
        steps: Cast to ``int`` and forwarded as ``steps=`` to
            ``reconstruct_mels``.
        w_scale: Cast to ``float`` and forwarded as ``w=`` to
            ``reconstruct_mels``.
        target_sr: Sample rate the input is resampled to. It is also passed
            to ``clamp_duration``.

    Returns:
        The first batch element of ``vocos.decode(mels)`` as a NumPy array,
        clipped to ``[-1.0, 1.0]``.
        NOTE(review): the output rate is whatever the Vocos checkpoint
        produces; this code does not check that it equals ``target_sr``.
    """
    synth, vocos = load_models()

    wav, sr = torchaudio.load(filepath)
    wav = ensure_mono(wav)
    wav = resample_if_needed(wav, sr, target_sr)
    wav = clamp_duration(wav, target_sr)
    wav_dev = wav.to(DEVICE)

    # torch.as_tensor accepts lists, NumPy arrays and existing tensors alike.
    # torch.tensor() always copies and warns when handed a tensor.
    # unsqueeze(0) adds the batch axis: a 1-D contour of length T becomes (1, T).
    target_logf0 = torch.as_tensor(
        target_logf0_1d, dtype=torch.float32, device=DEVICE
    ).unsqueeze(0)

    mels = synth.reconstruct_mels(wav_dev, target_logf0, steps=int(steps), w=float(w_scale))

    # [0] selects the single batch element before moving it to host memory.
    out = vocos.decode(mels)[0].detach().cpu().numpy()
    # Keep samples inside the nominal [-1, 1] range for float audio.
    return np.clip(out, -1.0, 1.0)


# Imported at the bottom to avoid an import cycle with audio_utils. These names
# are only looked up when reconstruct_from_logf0 runs, by which point the
# import has completed.
from .audio_utils import ensure_mono, resample_if_needed, clamp_duration  # noqa: E402