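"""Log-F0-conditioned reconstruction helpers built on PitchFlowerSynthesizer and Vocos."""
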
import numpy as np
import torch
import torchaudio
from vocos import Vocos

from pitchflower.synthesizer import PitchFlowerSynthesizer
from .config import DEFAULT_CHECKPOINT_DIR, DEVICE, VOCOS_ID

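# Module-level cache: the heavyweight models are loaded once per process and reused.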
_synth = None
_vocos = None

def load_models():
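    """Lazily load and cache the synthesizer and the Vocos vocoder on DEVICE."""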
    global _synth, _vocos
    if _synth is None:
        _synth = PitchFlowerSynthesizer.from_pretrained(DEFAULT_CHECKPOINT_DIR).to(DEVICE).eval()
    if _vocos is None:
        _vocos = Vocos.from_pretrained(VOCOS_ID).to(DEVICE).eval()
    return _synth, _vocos

@torch.inference_mode()
def reconstruct_from_logf0(filepath, target_logf0_1d, steps, w_scale, target_sr):
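    """Reconstruct audio from `filepath` conditioned on a 1-D log-F0 contour.

    The input is mixed to mono, resampled to `target_sr` and duration-clamped,
    then the synthesizer produces mel frames that Vocos decodes back to a
    waveform, returned as a NumPy array clipped to [-1, 1].
    """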
    synth, vocos = load_models()
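    # Load the source audio and normalize it: mono, target sample rate, clamped duration.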
    wav, sr = torchaudio.load(filepath)
    wav = ensure_mono(wav)
    wav = resample_if_needed(wav, sr, target_sr)
    wav = clamp_duration(wav, target_sr)
    wav_dev = wav.to(DEVICE)

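    # Batch the target log-F0 contour on DEVICE and run the synthesizer;
    # `steps` and `w_scale` are forwarded to reconstruct_mels after casting.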
    target_b1t = torch.tensor(target_logf0_1d, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    mels = synth.reconstruct_mels(wav_dev, target_b1t, steps=int(steps), w=float(w_scale))
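    # Decode the mel frames with Vocos and clip the waveform to the valid [-1, 1] range.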
    out = vocos.decode(mels)[0].detach().cpu().numpy()
    out = np.clip(out, -1.0, 1.0)
    return out

# Imported at the bottom of the module to avoid a circular import with audio_utils.
from .audio_utils import ensure_mono, resample_if_needed, clamp_duration
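
# Usage sketch (illustrative values only; steps, w_scale and target_sr depend on the
# trained checkpoint and the Vocos model configured in .config):
#
#     import numpy as np
#     contour = np.full(200, np.log(220.0), dtype=np.float32)  # flat 220 Hz target
#     audio = reconstruct_from_logf0("speech.wav", contour, steps=32, w_scale=1.0, target_sr=24000)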