Spaces:
No application file
No application file
Upload 25 files
Browse files- .gitattributes +1 -0
- BUILD +44 -0
- Dockerfile +30 -0
- WORKSPACE +154 -0
- alphabet.txt +57 -0
- app (1).py +63 -0
- bazelisk-linux-amd64 +3 -0
- build_ext.sh +3 -0
- extract_model.py +5 -0
- gitignore +1 -0
- inference.py +91 -0
- install_espeak_ng.sh +12 -0
- mynumbers.py +73 -0
- packages.txt +7 -0
- pooch.py +10 -0
- requirements (1).txt +12 -0
- tacotron.py +451 -0
- tacotron.toml +32 -0
- tacotrons_ljs_24k_v1_0300000.ckpt +3 -0
- text.py +92 -0
- utils.py +74 -0
- wavegru.py +300 -0
- wavegru.yaml +14 -0
- wavegru_cpp.py +42 -0
- wavegru_mod.cc +150 -0
- wavegru_vocoder_1024_v4_1320000.ckpt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
bazelisk-linux-amd64 filter=lfs diff=lfs merge=lfs -text
|
BUILD
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# [internal] load cc_fuzz_target.bzl
|
| 2 |
+
# [internal] load cc_proto_library.bzl
|
| 3 |
+
# [internal] load android_cc_test:def.bzl
|
| 4 |
+
|
| 5 |
+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
| 6 |
+
|
| 7 |
+
package(default_visibility = [":__subpackages__"])
|
| 8 |
+
|
| 9 |
+
licenses(["notice"])
|
| 10 |
+
|
| 11 |
+
# To run all cc_tests in this directory:
|
| 12 |
+
# bazel test //:all
|
| 13 |
+
|
| 14 |
+
# [internal] Command to run dsp_util_android_test.
|
| 15 |
+
|
| 16 |
+
# [internal] Command to run lyra_integration_android_test.
|
| 17 |
+
|
| 18 |
+
exports_files(
|
| 19 |
+
srcs = [
|
| 20 |
+
"wavegru_mod.cc",
|
| 21 |
+
],
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
pybind_extension(
|
| 25 |
+
name = "wavegru_mod", # This name is not actually created!
|
| 26 |
+
srcs = ["wavegru_mod.cc"],
|
| 27 |
+
deps = [
|
| 28 |
+
"//sparse_matmul",
|
| 29 |
+
],
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
py_library(
|
| 33 |
+
name = "wavegru_mod",
|
| 34 |
+
data = [":wavegru_mod.so"],
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
py_binary(
|
| 38 |
+
name = "wavegru",
|
| 39 |
+
srcs = ["wavegru.py"],
|
| 40 |
+
deps = [
|
| 41 |
+
":wavegru_mod"
|
| 42 |
+
],
|
| 43 |
+
)
|
| 44 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# you will also find guides on how best to write your Dockerfile
|
| 3 |
+
|
| 4 |
+
FROM python:3.11
|
| 5 |
+
|
| 6 |
+
RUN apt update; apt install libsndfile1-dev make autoconf automake libtool gcc pkg-config -y
|
| 7 |
+
|
| 8 |
+
WORKDIR /code
|
| 9 |
+
|
| 10 |
+
COPY ./requirements.txt /code/requirements.txt
|
| 11 |
+
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 13 |
+
|
| 14 |
+
# Set up a new user named "user" with user ID 1000
|
| 15 |
+
RUN useradd -m -u 1000 user
|
| 16 |
+
|
| 17 |
+
# Switch to the "user" user
|
| 18 |
+
USER user
|
| 19 |
+
|
| 20 |
+
# Set home to the user's home directory
|
| 21 |
+
ENV HOME=/home/user \
|
| 22 |
+
PATH=/home/user/.local/bin:$PATH
|
| 23 |
+
|
| 24 |
+
# Set the working directory to the user's home directory
|
| 25 |
+
WORKDIR $HOME/app
|
| 26 |
+
|
| 27 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
| 28 |
+
COPY --chown=user . $HOME/app
|
| 29 |
+
|
| 30 |
+
CMD ["python", "app.py"]
|
WORKSPACE
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
########################
|
| 2 |
+
# Platform Independent #
|
| 3 |
+
########################
|
| 4 |
+
|
| 5 |
+
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
|
| 6 |
+
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
| 7 |
+
|
| 8 |
+
# GoogleTest/GoogleMock framework.
|
| 9 |
+
git_repository(
|
| 10 |
+
name = "com_google_googletest",
|
| 11 |
+
remote = "https://github.com/google/googletest.git",
|
| 12 |
+
tag = "release-1.10.0",
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# Google benchmark.
|
| 16 |
+
http_archive(
|
| 17 |
+
name = "com_github_google_benchmark",
|
| 18 |
+
urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"], # 2020-11-26T11:14:03Z
|
| 19 |
+
strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
|
| 20 |
+
sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# proto_library, cc_proto_library, and java_proto_library rules implicitly
|
| 24 |
+
# depend on @com_google_protobuf for protoc and proto runtimes.
|
| 25 |
+
# This statement defines the @com_google_protobuf repo.
|
| 26 |
+
git_repository(
|
| 27 |
+
name = "com_google_protobuf",
|
| 28 |
+
remote = "https://github.com/protocolbuffers/protobuf.git",
|
| 29 |
+
tag = "v3.15.4",
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
|
| 33 |
+
protobuf_deps()
|
| 34 |
+
|
| 35 |
+
# Google Abseil Libs
|
| 36 |
+
git_repository(
|
| 37 |
+
name = "com_google_absl",
|
| 38 |
+
remote = "https://github.com/abseil/abseil-cpp.git",
|
| 39 |
+
branch = "lts_2020_09_23",
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Filesystem
|
| 43 |
+
# The new_* prefix is used because it is not a bazel project and there is
|
| 44 |
+
# no BUILD file in that repo.
|
| 45 |
+
FILESYSTEM_BUILD = """
|
| 46 |
+
cc_library(
|
| 47 |
+
name = "filesystem",
|
| 48 |
+
hdrs = glob(["include/ghc/*"]),
|
| 49 |
+
visibility = ["//visibility:public"],
|
| 50 |
+
)
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
new_git_repository(
|
| 54 |
+
name = "gulrak_filesystem",
|
| 55 |
+
remote = "https://github.com/gulrak/filesystem.git",
|
| 56 |
+
tag = "v1.3.6",
|
| 57 |
+
build_file_content = FILESYSTEM_BUILD
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Audio DSP
|
| 61 |
+
git_repository(
|
| 62 |
+
name = "com_google_audio_dsp",
|
| 63 |
+
remote = "https://github.com/google/multichannel-audio-tools.git",
|
| 64 |
+
# There are no tags for this repo, we are synced to bleeding edge.
|
| 65 |
+
branch = "master",
|
| 66 |
+
repo_mapping = {
|
| 67 |
+
"@com_github_glog_glog" : "@com_google_glog"
|
| 68 |
+
}
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
http_archive(
|
| 73 |
+
name = "pybind11_bazel",
|
| 74 |
+
strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
|
| 75 |
+
urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
|
| 76 |
+
)
|
| 77 |
+
# We still require the pybind library.
|
| 78 |
+
http_archive(
|
| 79 |
+
name = "pybind11",
|
| 80 |
+
build_file = "@pybind11_bazel//:pybind11.BUILD",
|
| 81 |
+
strip_prefix = "pybind11-2.9.0",
|
| 82 |
+
urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"],
|
| 83 |
+
)
|
| 84 |
+
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
|
| 85 |
+
python_configure(name = "local_config_python")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Transitive dependencies of Audio DSP.
|
| 90 |
+
http_archive(
|
| 91 |
+
name = "eigen_archive",
|
| 92 |
+
build_file = "eigen.BUILD",
|
| 93 |
+
sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
|
| 94 |
+
strip_prefix = "eigen-eigen-049af2f56331",
|
| 95 |
+
urls = [
|
| 96 |
+
"http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
| 97 |
+
"https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
| 98 |
+
],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
http_archive(
|
| 102 |
+
name = "fft2d",
|
| 103 |
+
build_file = "fft2d.BUILD",
|
| 104 |
+
sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
|
| 105 |
+
urls = [
|
| 106 |
+
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
|
| 107 |
+
],
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Google logging
|
| 111 |
+
git_repository(
|
| 112 |
+
name = "com_google_glog",
|
| 113 |
+
remote = "https://github.com/google/glog.git",
|
| 114 |
+
tag = "v0.5.0"
|
| 115 |
+
)
|
| 116 |
+
# Dependency for glog
|
| 117 |
+
git_repository(
|
| 118 |
+
name = "com_github_gflags_gflags",
|
| 119 |
+
remote = "https://github.com/mchinen/gflags.git",
|
| 120 |
+
branch = "android_linking_fix"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Bazel/build rules
|
| 124 |
+
|
| 125 |
+
http_archive(
|
| 126 |
+
name = "bazel_skylib",
|
| 127 |
+
urls = [
|
| 128 |
+
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
| 129 |
+
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
| 130 |
+
],
|
| 131 |
+
sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
|
| 132 |
+
)
|
| 133 |
+
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
|
| 134 |
+
bazel_skylib_workspace()
|
| 135 |
+
|
| 136 |
+
http_archive(
|
| 137 |
+
name = "rules_android",
|
| 138 |
+
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
|
| 139 |
+
strip_prefix = "rules_android-0.1.1",
|
| 140 |
+
urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Google Maven Repository
|
| 144 |
+
GMAVEN_TAG = "20180625-1"
|
| 145 |
+
|
| 146 |
+
http_archive(
|
| 147 |
+
name = "gmaven_rules",
|
| 148 |
+
strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
|
| 149 |
+
url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
|
| 153 |
+
|
| 154 |
+
gmaven_rules()
|
alphabet.txt
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_
|
| 2 |
+
■
|
| 3 |
+
|
| 4 |
+
!
|
| 5 |
+
"
|
| 6 |
+
,
|
| 7 |
+
.
|
| 8 |
+
:
|
| 9 |
+
;
|
| 10 |
+
?
|
| 11 |
+
a
|
| 12 |
+
b
|
| 13 |
+
d
|
| 14 |
+
e
|
| 15 |
+
f
|
| 16 |
+
h
|
| 17 |
+
i
|
| 18 |
+
j
|
| 19 |
+
k
|
| 20 |
+
l
|
| 21 |
+
m
|
| 22 |
+
n
|
| 23 |
+
o
|
| 24 |
+
p
|
| 25 |
+
r
|
| 26 |
+
s
|
| 27 |
+
t
|
| 28 |
+
u
|
| 29 |
+
v
|
| 30 |
+
w
|
| 31 |
+
x
|
| 32 |
+
z
|
| 33 |
+
æ
|
| 34 |
+
ð
|
| 35 |
+
ŋ
|
| 36 |
+
ɐ
|
| 37 |
+
ɑ
|
| 38 |
+
ɔ
|
| 39 |
+
ə
|
| 40 |
+
ɚ
|
| 41 |
+
ɛ
|
| 42 |
+
ɜ
|
| 43 |
+
ɡ
|
| 44 |
+
ɪ
|
| 45 |
+
ɹ
|
| 46 |
+
ɾ
|
| 47 |
+
ʃ
|
| 48 |
+
ʊ
|
| 49 |
+
ʌ
|
| 50 |
+
ʒ
|
| 51 |
+
ʔ
|
| 52 |
+
ˈ
|
| 53 |
+
ˌ
|
| 54 |
+
ː
|
| 55 |
+
̩
|
| 56 |
+
θ
|
| 57 |
+
ᵻ
|
app (1).py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## build wavegru-cpp
|
| 2 |
+
# import os
|
| 3 |
+
# os.system("./bazelisk-linux-amd64 clean --expunge")
|
| 4 |
+
# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
|
| 5 |
+
|
| 6 |
+
# install espeak
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
if not os.path.isfile("./wavegru_mod.so"):
|
| 10 |
+
os.system("bash ./build_ext.sh")
|
| 11 |
+
|
| 12 |
+
if not os.path.isdir("./espeak"):
|
| 13 |
+
os.system("bash ./install_espeak_ng.sh")
|
| 14 |
+
|
| 15 |
+
import gradio as gr
|
| 16 |
+
from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
|
| 17 |
+
from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
| 21 |
+
"./alphabet.txt", "./tacotron.toml", "./tacotrons_ljs_24k_v1_0300000.ckpt"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
wavegru_config, wavegru_net = load_wavegru_net(
|
| 25 |
+
"./wavegru.yaml", "./wavegru_vocoder_1024_v4_1320000.ckpt"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
| 29 |
+
wavecpp = load_wavegru_cpp(
|
| 30 |
+
wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1]
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def speak(text):
    """Synthesize speech for *text*.

    Returns a (sample_rate_hz, waveform) pair suitable for a gradio
    "audio" output. The sample rate is fixed at 24 kHz.
    """
    mel_spec = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
    waveform = mel_to_wav(wavegru_net, wavecpp, mel_spec, wavegru_config)
    return 24_000, waveform
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
title = "WaveGRU-TTS"
|
| 41 |
+
description = "WaveGRU text-to-speech demo."
|
| 42 |
+
|
| 43 |
+
gr.Interface(
|
| 44 |
+
fn=speak,
|
| 45 |
+
inputs="text",
|
| 46 |
+
examples=[
|
| 47 |
+
"This is a test!",
|
| 48 |
+
"President Trump met with other leaders at the Group of 20 conference.",
|
| 49 |
+
"The buses aren't the problem, they actually provide a solution.",
|
| 50 |
+
"Generative adversarial network or variational auto-encoder.",
|
| 51 |
+
"Basilar membrane and otolaryngology are not auto-correlations.",
|
| 52 |
+
"There are several variations on the full gated unit, with gating done using the previous hidden state and the bias in various combinations, and a simplified form called minimal gated unit.",
|
| 53 |
+
"October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.",
|
| 54 |
+
"Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.",
|
| 55 |
+
'Uncle Vernon entered the kitchen as Harry was turning over the bacon. "Comb your hair!" he barked, by way of a morning greeting. About once a week, Uncle Vernon looked over the top of his newspaper and shouted that Harry needed a haircut. Harry must have had more haircuts than the rest of the boys in his class put together, but it made no difference, his hair simply grew that way - all over the place.',
|
| 56 |
+
],
|
| 57 |
+
outputs="audio",
|
| 58 |
+
title=title,
|
| 59 |
+
description=description,
|
| 60 |
+
theme="default",
|
| 61 |
+
allow_screenshot=False,
|
| 62 |
+
allow_flagging="never",
|
| 63 |
+
).launch(enable_queue=True)
|
bazelisk-linux-amd64
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:231ec5ca8115e94c75a1f4fbada1a062b48822ca04f21f26e4cb1cd8973cd458
|
| 3 |
+
size 5152768
|
build_ext.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
chmod +x ./bazelisk-linux-amd64
|
| 2 |
+
USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native
|
| 3 |
+
cp -f bazel-bin/wavegru_mod.so .
|
extract_model.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
|
| 3 |
+
dic = pickle.load(open("./tacotrons_ljs_24k_v1_0300000.ckpt", "rb"))
|
| 4 |
+
del dic["optim_state_dict"]
|
| 5 |
+
pickle.dump(dic, open("./tacotrons_ljs_24k_v1_0300000.ckpt", "wb"))
|
gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.venv
|
inference.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import jax
|
| 4 |
+
import jax.numpy as jnp
|
| 5 |
+
import librosa
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pax
|
| 8 |
+
|
| 9 |
+
from text import english_cleaners
|
| 10 |
+
from utils import (
|
| 11 |
+
create_tacotron_model,
|
| 12 |
+
load_tacotron_ckpt,
|
| 13 |
+
load_tacotron_config,
|
| 14 |
+
load_wavegru_ckpt,
|
| 15 |
+
load_wavegru_config,
|
| 16 |
+
)
|
| 17 |
+
from wavegru import WaveGRU
|
| 18 |
+
|
| 19 |
+
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51"
|
| 20 |
+
from phonemizer.backend import EspeakBackend
|
| 21 |
+
|
| 22 |
+
backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_tacotron_model(alphabet_file, config_file, model_file):
    """Load the Tacotron alphabet, network and config from disk.

    Returns (alphabet, net, config) where *alphabet* is the list of
    symbols (one per line of *alphabet_file*), *net* is the model in
    eval mode placed on the default JAX device, and *config* is the
    parsed configuration.
    """
    with open(alphabet_file, "r", encoding="utf-8") as fp:
        symbols = fp.read().split("\n")

    cfg = load_tacotron_config(config_file)
    model = create_tacotron_model(cfg)
    # Checkpoint loader returns (step, net, optim_state); we only keep the net.
    _, model, _ = load_tacotron_ckpt(model, None, model_file)
    model = jax.device_put(model.eval())
    return symbols, model, cfg
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def text_to_mel(net, text, alphabet, config):
    """Convert *text* to a mel spectrogram with the Tacotron network.

    The text is cleaned, phonemized with espeak, terminated with the
    configured end character and padded to a multiple of 100 symbols
    before being tokenized against *alphabet*.
    """
    text = english_cleaners(text)
    text = backend.phonemize([text], strip=True)[0]
    text = text + config["END_CHARACTER"]
    # Pad so the sequence length is a multiple of 100.
    text = text + config["PAD"] * (100 - (len(text) % 100))
    # NOTE(review): characters missing from the alphabet are silently
    # dropped, matching the original behavior.
    token_ids = [alphabet.index(ch) for ch in text if ch in alphabet]
    token_ids = jnp.array(token_ids, dtype=jnp.int32)
    return tacotron_inference_fn(net, token_ids[None])
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_wavegru_net(config_file, model_file):
    """Load the WaveGRU vocoder network and its config from disk.

    Returns (config, net) with *net* in eval mode on the default
    JAX device.
    """
    cfg = load_wavegru_config(config_file)
    model = WaveGRU(
        mel_dim=cfg["mel_dim"],
        rnn_dim=cfg["rnn_dim"],
        upsample_factors=cfg["upsample_factors"],
        has_linear_output=True,
    )
    # Checkpoint loader returns (step, net, optim_state); keep only the net.
    _, model, _ = load_wavegru_ckpt(model, None, model_file)
    model = jax.device_put(model.eval())
    return cfg, model
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def mel_to_wav(net, netcpp, mel, config):
    """Convert a mel spectrogram to an int16 waveform.

    Runs the JAX upsampling network *net*, samples with the C++
    WaveGRU *netcpp*, then undoes 8-bit mu-law companding and the
    pre-emphasis filter before scaling to 16-bit PCM.
    """
    if len(mel.shape) == 2:
        mel = mel[None]  # add a batch axis
    # Pad in time so the conditioning network has full context at the edges.
    pad = config["num_pad_frames"] // 2 + 2
    padded = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")

    features = wavegru_inference(net, padded)
    features = jax.device_get(features[0])
    samples = np.array(netcpp.inference(features, 1.0))

    # Undo mu-law companding (samples are 0..255 centered at 127) and
    # the 0.86 pre-emphasis applied during training.
    samples = librosa.mu_expand(samples - 127, mu=255)
    samples = librosa.effects.deemphasis(samples, coef=0.86)
    samples = samples * 2.0
    # Peak-normalize only when the signal would otherwise clip.
    samples = samples / max(1.0, np.max(np.abs(samples)))
    samples = samples * 2**15
    samples = np.clip(samples, a_min=-(2**15), a_max=(2**15) - 1)
    return samples.astype(np.int16)
|
install_espeak_ng.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
(
|
| 2 |
+
rm -rf espeak
|
| 3 |
+
mkdir -p espeak
|
| 4 |
+
cd espeak
|
| 5 |
+
wget https://github.com/espeak-ng/espeak-ng/archive/refs/tags/1.51.zip
|
| 6 |
+
unzip -qq 1.51.zip
|
| 7 |
+
cd espeak-ng-1.51
|
| 8 |
+
./autogen.sh
|
| 9 |
+
./configure --prefix=`pwd`/../usr
|
| 10 |
+
make
|
| 11 |
+
make install
|
| 12 |
+
)
|
mynumbers.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" from https://github.com/keithito/tacotron """
|
| 2 |
+
|
| 3 |
+
import inflect
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
_inflect = inflect.engine()
|
| 8 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
| 9 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
| 10 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
| 11 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
| 12 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
| 13 |
+
_number_re = re.compile(r"[0-9]+")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _remove_commas(m):
|
| 17 |
+
return m.group(1).replace(",", "")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _expand_decimal_point(m):
|
| 21 |
+
return m.group(1).replace(".", " point ")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _expand_dollars(m):
|
| 25 |
+
match = m.group(1)
|
| 26 |
+
parts = match.split(".")
|
| 27 |
+
if len(parts) > 2:
|
| 28 |
+
return match + " dollars" # Unexpected format
|
| 29 |
+
dollars = int(parts[0]) if parts[0] else 0
|
| 30 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
| 31 |
+
if dollars and cents:
|
| 32 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
| 33 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
| 34 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
| 35 |
+
elif dollars:
|
| 36 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
| 37 |
+
return "%s %s" % (dollars, dollar_unit)
|
| 38 |
+
elif cents:
|
| 39 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
| 40 |
+
return "%s %s" % (cents, cent_unit)
|
| 41 |
+
else:
|
| 42 |
+
return "zero dollars"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _expand_ordinal(m):
|
| 46 |
+
return _inflect.number_to_words(m.group(0))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _expand_number(m):
|
| 50 |
+
num = int(m.group(0))
|
| 51 |
+
if num > 1000 and num < 3000:
|
| 52 |
+
if num == 2000:
|
| 53 |
+
return "two thousand"
|
| 54 |
+
elif num > 2000 and num < 2010:
|
| 55 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
| 56 |
+
elif num % 100 == 0:
|
| 57 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
| 58 |
+
else:
|
| 59 |
+
return _inflect.number_to_words(
|
| 60 |
+
num, andword="", zero="oh", group=2
|
| 61 |
+
).replace(", ", " ")
|
| 62 |
+
else:
|
| 63 |
+
return _inflect.number_to_words(num, andword="")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def normalize_numbers(text):
    """Expand numbers, currency amounts and ordinals in *text* into words.

    Substitutions are applied in order: commas removed first so the
    later patterns see plain digits; currency before decimals so
    '$2.50' is read as money rather than 'two point five zero'.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, repl in substitutions:
        text = re.sub(pattern, repl, text)
    return text
|
packages.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libsndfile1-dev
|
| 2 |
+
make
|
| 3 |
+
autoconf
|
| 4 |
+
automake
|
| 5 |
+
libtool
|
| 6 |
+
gcc
|
| 7 |
+
pkg-config
|
pooch.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def os_cache(x):
    """Stub for pooch.os_cache: return the argument unchanged."""
    return x
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def create(*args, **kwargs):
    """Stub for pooch.create: ignore all arguments.

    Returns an object whose load_registry() is a no-op, so code that
    expects a pooch downloader can run without the real dependency.
    """

    class _Stub:
        def load_registry(self, *args, **kwargs):
            return None

    return _Stub()
|
requirements (1).txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
inflect
|
| 2 |
+
jax
|
| 3 |
+
jaxlib
|
| 4 |
+
jinja2
|
| 5 |
+
librosa
|
| 6 |
+
numpy
|
| 7 |
+
pax3
|
| 8 |
+
pyyaml
|
| 9 |
+
toml
|
| 10 |
+
unidecode
|
| 11 |
+
phonemizer
|
| 12 |
+
gradio==3.42.0
|
tacotron.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tacotron + stepwise monotonic attention
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import jax
|
| 6 |
+
import jax.numpy as jnp
|
| 7 |
+
import pax
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
|
| 11 |
+
"""
|
| 12 |
+
Conv >> LayerNorm >> activation >> Dropout
|
| 13 |
+
"""
|
| 14 |
+
f = pax.Sequential(
|
| 15 |
+
pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
|
| 16 |
+
pax.LayerNorm(out_ft, -1, True, True),
|
| 17 |
+
)
|
| 18 |
+
if activation_fn is not None:
|
| 19 |
+
f >>= activation_fn
|
| 20 |
+
if use_dropout:
|
| 21 |
+
f >>= pax.Dropout(0.5)
|
| 22 |
+
return f
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class HighwayBlock(pax.Module):
|
| 26 |
+
"""
|
| 27 |
+
Highway block
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, dim: int) -> None:
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.dim = dim
|
| 33 |
+
self.fc = pax.Linear(dim, 2 * dim)
|
| 34 |
+
|
| 35 |
+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
| 36 |
+
t, h = jnp.split(self.fc(x), 2, axis=-1)
|
| 37 |
+
t = jax.nn.sigmoid(t - 1.0) # bias toward keeping x
|
| 38 |
+
h = jax.nn.relu(h)
|
| 39 |
+
x = x * (1.0 - t) + h * t
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class BiGRU(pax.Module):
|
| 44 |
+
"""
|
| 45 |
+
Bidirectional GRU
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, dim):
|
| 49 |
+
super().__init__()
|
| 50 |
+
|
| 51 |
+
self.rnn_fwd = pax.GRU(dim, dim)
|
| 52 |
+
self.rnn_bwd = pax.GRU(dim, dim)
|
| 53 |
+
|
| 54 |
+
def __call__(self, x, reset_masks):
|
| 55 |
+
N = x.shape[0]
|
| 56 |
+
x_fwd = x
|
| 57 |
+
x_bwd = jnp.flip(x, axis=1)
|
| 58 |
+
x_fwd_states = self.rnn_fwd.initial_state(N)
|
| 59 |
+
x_bwd_states = self.rnn_bwd.initial_state(N)
|
| 60 |
+
x_fwd_states, x_fwd = pax.scan(
|
| 61 |
+
self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
reset_masks = jnp.flip(reset_masks, axis=1)
|
| 65 |
+
x_bwd_states0 = x_bwd_states
|
| 66 |
+
|
| 67 |
+
def rnn_reset_core(prev, inputs):
|
| 68 |
+
x, reset_mask = inputs
|
| 69 |
+
|
| 70 |
+
def reset_state(x0, xt):
|
| 71 |
+
return jnp.where(reset_mask, x0, xt)
|
| 72 |
+
|
| 73 |
+
state, _ = self.rnn_bwd(prev, x)
|
| 74 |
+
state = jax.tree_map(reset_state, x_bwd_states0, state)
|
| 75 |
+
return state, state.hidden
|
| 76 |
+
|
| 77 |
+
x_bwd_states, x_bwd = pax.scan(
|
| 78 |
+
rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
|
| 79 |
+
)
|
| 80 |
+
x_bwd = jnp.flip(x_bwd, axis=1)
|
| 81 |
+
x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class CBHG(pax.Module):
|
| 86 |
+
"""
|
| 87 |
+
Conv Bank >> Highway net >> GRU
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, dim):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)]
|
| 93 |
+
self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
|
| 94 |
+
self.conv_projection_2 = conv_block(dim, dim, 3, None, False)
|
| 95 |
+
|
| 96 |
+
self.highway = pax.Sequential(
|
| 97 |
+
HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim)
|
| 98 |
+
)
|
| 99 |
+
self.rnn = BiGRU(dim)
|
| 100 |
+
|
| 101 |
+
def __call__(self, x, x_mask):
|
| 102 |
+
conv_input = x * x_mask
|
| 103 |
+
fts = [f(conv_input) for f in self.convs]
|
| 104 |
+
residual = jnp.concatenate(fts, axis=-1)
|
| 105 |
+
residual = pax.max_pool(residual, 2, 1, "SAME", -1)
|
| 106 |
+
residual = self.conv_projection_1(residual * x_mask)
|
| 107 |
+
residual = self.conv_projection_2(residual * x_mask)
|
| 108 |
+
x = x + residual
|
| 109 |
+
x = self.highway(x)
|
| 110 |
+
x = self.rnn(x * x_mask, reset_masks=1 - x_mask)
|
| 111 |
+
return x * x_mask
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class PreNet(pax.Module):
|
| 115 |
+
"""
|
| 116 |
+
Linear >> relu >> dropout >> Linear >> relu >> dropout
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.fc1 = pax.Linear(input_dim, hidden_dim)
|
| 122 |
+
self.fc2 = pax.Linear(hidden_dim, output_dim)
|
| 123 |
+
self.rng_seq = pax.RngSeq()
|
| 124 |
+
self.always_dropout = always_dropout
|
| 125 |
+
|
| 126 |
+
def __call__(self, x, k1=None, k2=None):
|
| 127 |
+
x = self.fc1(x)
|
| 128 |
+
x = jax.nn.relu(x)
|
| 129 |
+
if self.always_dropout or self.training:
|
| 130 |
+
if k1 is None:
|
| 131 |
+
k1 = self.rng_seq.next_rng_key()
|
| 132 |
+
x = pax.dropout(k1, 0.5, x)
|
| 133 |
+
x = self.fc2(x)
|
| 134 |
+
x = jax.nn.relu(x)
|
| 135 |
+
if self.always_dropout or self.training:
|
| 136 |
+
if k2 is None:
|
| 137 |
+
k2 = self.rng_seq.next_rng_key()
|
| 138 |
+
x = pax.dropout(k2, 0.5, x)
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Tacotron(pax.Module):
|
| 143 |
+
"""
|
| 144 |
+
Tacotron TTS model.
|
| 145 |
+
|
| 146 |
+
It uses stepwise monotonic attention for robust attention.
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
mel_dim: int,
|
| 152 |
+
attn_bias,
|
| 153 |
+
rr,
|
| 154 |
+
max_rr,
|
| 155 |
+
mel_min,
|
| 156 |
+
sigmoid_noise,
|
| 157 |
+
pad_token,
|
| 158 |
+
prenet_dim,
|
| 159 |
+
attn_hidden_dim,
|
| 160 |
+
attn_rnn_dim,
|
| 161 |
+
rnn_dim,
|
| 162 |
+
postnet_dim,
|
| 163 |
+
text_dim,
|
| 164 |
+
):
|
| 165 |
+
"""
|
| 166 |
+
New Tacotron model
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
mel_dim (int): dimension of log mel-spectrogram features.
|
| 170 |
+
attn_bias (float): control how "slow" the attention will
|
| 171 |
+
move forward at initialization.
|
| 172 |
+
rr (int): the reduction factor.
|
| 173 |
+
Number of predicted frame at each time step. Default is 2.
|
| 174 |
+
max_rr (int): max value of rr.
|
| 175 |
+
mel_min (float): the minimum value of mel features.
|
| 176 |
+
The <go> frame is filled by `log(mel_min)` values.
|
| 177 |
+
sigmoid_noise (float): the variance of gaussian noise added
|
| 178 |
+
to attention scores in training.
|
| 179 |
+
pad_token (int): the pad value at the end of text sequences.
|
| 180 |
+
prenet_dim (int): dimension of prenet output.
|
| 181 |
+
attn_hidden_dim (int): dimension of attention hidden vectors.
|
| 182 |
+
attn_rnn_dim (int): number of cells in the attention RNN.
|
| 183 |
+
rnn_dim (int): number of cells in the decoder RNNs.
|
| 184 |
+
postnet_dim (int): number of features in the postnet convolutions.
|
| 185 |
+
text_dim (int): dimension of text embedding vectors.
|
| 186 |
+
"""
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.text_dim = text_dim
|
| 189 |
+
assert rr <= max_rr
|
| 190 |
+
self.rr = rr
|
| 191 |
+
self.max_rr = max_rr
|
| 192 |
+
self.mel_dim = mel_dim
|
| 193 |
+
self.mel_min = mel_min
|
| 194 |
+
self.sigmoid_noise = sigmoid_noise
|
| 195 |
+
self.pad_token = pad_token
|
| 196 |
+
self.prenet_dim = prenet_dim
|
| 197 |
+
|
| 198 |
+
# encoder submodules
|
| 199 |
+
self.encoder_embed = pax.Embed(256, text_dim)
|
| 200 |
+
self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
|
| 201 |
+
self.encoder_cbhg = CBHG(prenet_dim)
|
| 202 |
+
|
| 203 |
+
# random key generator
|
| 204 |
+
self.rng_seq = pax.RngSeq()
|
| 205 |
+
|
| 206 |
+
# pre-net
|
| 207 |
+
self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
|
| 208 |
+
|
| 209 |
+
# decoder submodules
|
| 210 |
+
self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
|
| 211 |
+
self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
|
| 212 |
+
self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
|
| 213 |
+
|
| 214 |
+
self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
|
| 215 |
+
self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
|
| 216 |
+
self.attn_V_bias = jnp.array(attn_bias)
|
| 217 |
+
self.attn_log = jnp.zeros((1,))
|
| 218 |
+
self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
|
| 219 |
+
self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
|
| 220 |
+
self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
|
| 221 |
+
# mel + end-of-sequence token
|
| 222 |
+
self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
|
| 223 |
+
|
| 224 |
+
# post-net
|
| 225 |
+
self.post_net = pax.Sequential(
|
| 226 |
+
conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
|
| 227 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
| 228 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
| 229 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
| 230 |
+
conv_block(postnet_dim, mel_dim, 5, None, True),
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
|
| 234 |
+
|
| 235 |
+
def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
|
| 236 |
+
"""
|
| 237 |
+
Encode text to a sequence of real vectors
|
| 238 |
+
"""
|
| 239 |
+
N, L = text.shape
|
| 240 |
+
text_mask = (text != self.pad_token)[..., None]
|
| 241 |
+
x = self.encoder_embed(text)
|
| 242 |
+
x = self.encoder_pre_net(x)
|
| 243 |
+
x = self.encoder_cbhg(x, text_mask)
|
| 244 |
+
return x
|
| 245 |
+
|
| 246 |
+
def go_frame(self, batch_size: int) -> jnp.ndarray:
|
| 247 |
+
"""
|
| 248 |
+
return the go frame
|
| 249 |
+
"""
|
| 250 |
+
return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
|
| 251 |
+
|
| 252 |
+
def decoder_initial_state(self, N: int, L: int):
|
| 253 |
+
"""
|
| 254 |
+
setup decoder initial state
|
| 255 |
+
"""
|
| 256 |
+
attn_context = jnp.zeros((N, self.prenet_dim * 2))
|
| 257 |
+
attn_pr = jax.nn.one_hot(
|
| 258 |
+
jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
|
| 262 |
+
decoder_rnn_states = (
|
| 263 |
+
self.decoder_rnn1.initial_state(N),
|
| 264 |
+
self.decoder_rnn2.initial_state(N),
|
| 265 |
+
)
|
| 266 |
+
return attn_state, decoder_rnn_states
|
| 267 |
+
|
| 268 |
+
def monotonic_attention(self, prev_state, inputs, envs):
|
| 269 |
+
"""
|
| 270 |
+
Stepwise monotonic attention
|
| 271 |
+
"""
|
| 272 |
+
attn_rnn_state, attn_context, prev_attn_pr = prev_state
|
| 273 |
+
x, attn_rng_key = inputs
|
| 274 |
+
text, text_key = envs
|
| 275 |
+
attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
|
| 276 |
+
attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
|
| 277 |
+
attn_query_input = attn_rnn_output
|
| 278 |
+
attn_query = self.attn_query_fc(attn_query_input)
|
| 279 |
+
attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
|
| 280 |
+
score = self.attn_V(attn_hidden)
|
| 281 |
+
score = jnp.squeeze(score, axis=-1)
|
| 282 |
+
weight_norm = jnp.linalg.norm(self.attn_V.weight)
|
| 283 |
+
score = score * (self.attn_V_weight_norm / weight_norm)
|
| 284 |
+
score = score + self.attn_V_bias
|
| 285 |
+
noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
|
| 286 |
+
pr_stay = jax.nn.sigmoid(score + noise)
|
| 287 |
+
pr_move = 1.0 - pr_stay
|
| 288 |
+
pr_new_location = pr_move * prev_attn_pr
|
| 289 |
+
pr_new_location = jnp.pad(
|
| 290 |
+
pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
|
| 291 |
+
)
|
| 292 |
+
attn_pr = pr_stay * prev_attn_pr + pr_new_location
|
| 293 |
+
attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
|
| 294 |
+
new_state = (attn_rnn_state, attn_context, attn_pr)
|
| 295 |
+
return new_state, attn_rnn_output
|
| 296 |
+
|
| 297 |
+
def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
|
| 298 |
+
"""
|
| 299 |
+
Return a zoneout lstm core.
|
| 300 |
+
|
| 301 |
+
It will zoneout the new hidden states and keep the new cell states unchanged.
|
| 302 |
+
"""
|
| 303 |
+
|
| 304 |
+
def core(state, x):
|
| 305 |
+
new_state, _ = lstm_core(state, x)
|
| 306 |
+
h_old = state.hidden
|
| 307 |
+
h_new = new_state.hidden
|
| 308 |
+
mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
|
| 309 |
+
h_new = h_old * mask + h_new * (1.0 - mask)
|
| 310 |
+
return pax.LSTMState(h_new, new_state.cell), h_new
|
| 311 |
+
|
| 312 |
+
return core
|
| 313 |
+
|
| 314 |
+
def decoder_step(
|
| 315 |
+
self,
|
| 316 |
+
attn_state,
|
| 317 |
+
decoder_rnn_states,
|
| 318 |
+
rng_key,
|
| 319 |
+
mel,
|
| 320 |
+
text,
|
| 321 |
+
text_key,
|
| 322 |
+
call_pre_net=False,
|
| 323 |
+
):
|
| 324 |
+
"""
|
| 325 |
+
One decoder step
|
| 326 |
+
"""
|
| 327 |
+
if call_pre_net:
|
| 328 |
+
k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
|
| 329 |
+
mel = self.decoder_pre_net(mel, k1, k2)
|
| 330 |
+
else:
|
| 331 |
+
zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
|
| 332 |
+
attn_inputs = (mel, rng_key)
|
| 333 |
+
attn_envs = (text, text_key)
|
| 334 |
+
attn_state, attn_rnn_output = self.monotonic_attention(
|
| 335 |
+
attn_state, attn_inputs, attn_envs
|
| 336 |
+
)
|
| 337 |
+
(_, attn_context, attn_pr) = attn_state
|
| 338 |
+
(decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
|
| 339 |
+
decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
|
| 340 |
+
decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
|
| 341 |
+
decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
|
| 342 |
+
decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
|
| 343 |
+
decoder_rnn_state1, decoder_rnn1_input
|
| 344 |
+
)
|
| 345 |
+
decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
|
| 346 |
+
decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
|
| 347 |
+
decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
|
| 348 |
+
decoder_rnn_state2, decoder_rnn2_input
|
| 349 |
+
)
|
| 350 |
+
x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
|
| 351 |
+
decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
|
| 352 |
+
return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]
|
| 353 |
+
|
| 354 |
+
@jax.jit
|
| 355 |
+
def inference_step(
|
| 356 |
+
self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
| 357 |
+
):
|
| 358 |
+
"""one inference step"""
|
| 359 |
+
attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
|
| 360 |
+
attn_state,
|
| 361 |
+
decoder_rnn_states,
|
| 362 |
+
rng_key,
|
| 363 |
+
mel,
|
| 364 |
+
text,
|
| 365 |
+
text_key,
|
| 366 |
+
call_pre_net=True,
|
| 367 |
+
)
|
| 368 |
+
x = self.output_fc(x)
|
| 369 |
+
N, D2 = x.shape
|
| 370 |
+
x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
|
| 371 |
+
x = x[:, : self.rr, :]
|
| 372 |
+
x = jnp.reshape(x, (N, self.rr, -1))
|
| 373 |
+
mel = x[..., :-1]
|
| 374 |
+
eos_logit = x[..., -1]
|
| 375 |
+
eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
|
| 376 |
+
eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr)
|
| 377 |
+
rng_key, eos_rng_key = jax.random.split(rng_key)
|
| 378 |
+
eos = jax.random.bernoulli(eos_rng_key, p=eos_pr)
|
| 379 |
+
return attn_state, decoder_rnn_states, rng_key, (mel, eos)
|
| 380 |
+
|
| 381 |
+
def inference(self, text, seed=42, max_len=1000):
|
| 382 |
+
"""
|
| 383 |
+
text to mel
|
| 384 |
+
"""
|
| 385 |
+
text = self.encode_text(text)
|
| 386 |
+
text_key = self.text_key_fc(text)
|
| 387 |
+
N, L, D = text.shape
|
| 388 |
+
assert N == 1
|
| 389 |
+
mel = self.go_frame(N)
|
| 390 |
+
|
| 391 |
+
attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
|
| 392 |
+
rng_key = jax.random.PRNGKey(seed)
|
| 393 |
+
mels = []
|
| 394 |
+
count = 0
|
| 395 |
+
while True:
|
| 396 |
+
count = count + 1
|
| 397 |
+
attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
|
| 398 |
+
attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
| 399 |
+
)
|
| 400 |
+
mels.append(mel)
|
| 401 |
+
if eos.item() or count > max_len:
|
| 402 |
+
break
|
| 403 |
+
|
| 404 |
+
mel = mel[:, -1, :]
|
| 405 |
+
|
| 406 |
+
mels = jnp.concatenate(mels, axis=1)
|
| 407 |
+
mel = mel + self.post_net(mel)
|
| 408 |
+
return mels
|
| 409 |
+
|
| 410 |
+
def decode(self, mel, text):
|
| 411 |
+
"""
|
| 412 |
+
Attention mechanism + Decoder
|
| 413 |
+
"""
|
| 414 |
+
text_key = self.text_key_fc(text)
|
| 415 |
+
|
| 416 |
+
def scan_fn(prev_states, inputs):
|
| 417 |
+
attn_state, decoder_rnn_states = prev_states
|
| 418 |
+
x, rng_key = inputs
|
| 419 |
+
attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
|
| 420 |
+
attn_state, decoder_rnn_states, rng_key, x, text, text_key
|
| 421 |
+
)
|
| 422 |
+
states = (attn_state, decoder_rnn_states)
|
| 423 |
+
return states, (output, attn_pr)
|
| 424 |
+
|
| 425 |
+
N, L, D = text.shape
|
| 426 |
+
decoder_states = self.decoder_initial_state(N, L)
|
| 427 |
+
rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
|
| 428 |
+
rng_keys = jnp.stack(rng_keys, axis=1)
|
| 429 |
+
decoder_states, (x, attn_log) = pax.scan(
|
| 430 |
+
scan_fn,
|
| 431 |
+
decoder_states,
|
| 432 |
+
(mel, rng_keys),
|
| 433 |
+
time_major=False,
|
| 434 |
+
)
|
| 435 |
+
self.attn_log = attn_log
|
| 436 |
+
del decoder_states
|
| 437 |
+
x = self.output_fc(x)
|
| 438 |
+
|
| 439 |
+
N, T2, D2 = x.shape
|
| 440 |
+
x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
|
| 441 |
+
x = x[:, :, : self.rr, :]
|
| 442 |
+
x = jnp.reshape(x, (N, T2 * self.rr, -1))
|
| 443 |
+
mel = x[..., :-1]
|
| 444 |
+
eos = x[..., -1]
|
| 445 |
+
return mel, eos
|
| 446 |
+
|
| 447 |
+
def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
|
| 448 |
+
text = self.encode_text(text)
|
| 449 |
+
mel = self.decoder_pre_net(mel)
|
| 450 |
+
mel, eos = self.decode(mel, text)
|
| 451 |
+
return mel, mel + self.post_net(mel), eos
|
tacotron.toml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tacotron]
|
| 2 |
+
|
| 3 |
+
# training
|
| 4 |
+
BATCH_SIZE = 64
|
| 5 |
+
LR=1024e-6 # learning rate
|
| 6 |
+
MODEL_PREFIX = "mono_tts_cbhg_small"
|
| 7 |
+
LOG_DIR = "./logs"
|
| 8 |
+
CKPT_DIR = "./ckpts"
|
| 9 |
+
USE_MP = false # use mixed-precision training
|
| 10 |
+
|
| 11 |
+
# data
|
| 12 |
+
TF_DATA_DIR = "./tf_data" # tensorflow data directory
|
| 13 |
+
TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
|
| 14 |
+
SAMPLE_RATE = 24000 # convert to this sample rate if needed
|
| 15 |
+
MEL_DIM = 80 # the dimension of melspectrogram features
|
| 16 |
+
MEL_MIN = 1e-5
|
| 17 |
+
PAD = "_" # padding character
|
| 18 |
+
PAD_TOKEN = 0
|
| 19 |
+
END_CHARACTER = "■" # to signal the end of the transcript
|
| 20 |
+
TEST_DATA_SIZE = 1024
|
| 21 |
+
|
| 22 |
+
# model
|
| 23 |
+
RR = 1 # reduction factor
|
| 24 |
+
MAX_RR=2
|
| 25 |
+
ATTN_BIAS = 0.0 # control how slow the attention moves forward
|
| 26 |
+
SIGMOID_NOISE = 2.0
|
| 27 |
+
PRENET_DIM = 128
|
| 28 |
+
TEXT_DIM = 256
|
| 29 |
+
RNN_DIM = 512
|
| 30 |
+
ATTN_RNN_DIM = 256
|
| 31 |
+
ATTN_HIDDEN_DIM = 128
|
| 32 |
+
POSTNET_DIM = 512
|
tacotrons_ljs_24k_v1_0300000.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73368d52c6c519682b73fa9676fe2eaed1712aa559026034ab6a36b2bfd8f8c0
|
| 3 |
+
size 53561547
|
text.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" from https://github.com/keithito/tacotron """
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
| 5 |
+
|
| 6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
| 7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
| 8 |
+
1. "english_cleaners" for English text
|
| 9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
| 10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
| 11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
| 12 |
+
the symbols in symbols.py to match your data).
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from mynumbers import normalize_numbers
|
| 17 |
+
from unidecode import unidecode
|
| 18 |
+
|
| 19 |
+
# Regular expression matching whitespace:
|
| 20 |
+
_whitespace_re = re.compile(r"\s+")
|
| 21 |
+
|
| 22 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
| 23 |
+
_abbreviations = [
|
| 24 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
| 25 |
+
for x in [
|
| 26 |
+
("mrs", "misess"),
|
| 27 |
+
("mr", "mister"),
|
| 28 |
+
("dr", "doctor"),
|
| 29 |
+
("st", "saint"),
|
| 30 |
+
("co", "company"),
|
| 31 |
+
("jr", "junior"),
|
| 32 |
+
("maj", "major"),
|
| 33 |
+
("gen", "general"),
|
| 34 |
+
("drs", "doctors"),
|
| 35 |
+
("rev", "reverend"),
|
| 36 |
+
("lt", "lieutenant"),
|
| 37 |
+
("hon", "honorable"),
|
| 38 |
+
("sgt", "sergeant"),
|
| 39 |
+
("capt", "captain"),
|
| 40 |
+
("esq", "esquire"),
|
| 41 |
+
("ltd", "limited"),
|
| 42 |
+
("col", "colonel"),
|
| 43 |
+
("ft", "fort"),
|
| 44 |
+
]
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def expand_abbreviations(text):
|
| 49 |
+
for regex, replacement in _abbreviations:
|
| 50 |
+
text = re.sub(regex, replacement, text)
|
| 51 |
+
return text
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def expand_numbers(text):
|
| 55 |
+
return normalize_numbers(text)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def lowercase(text):
|
| 59 |
+
return text.lower()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def collapse_whitespace(text):
|
| 63 |
+
return re.sub(_whitespace_re, " ", text)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def convert_to_ascii(text):
|
| 67 |
+
return unidecode(text)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def basic_cleaners(text):
|
| 71 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
| 72 |
+
text = lowercase(text)
|
| 73 |
+
text = collapse_whitespace(text)
|
| 74 |
+
return text
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def transliteration_cleaners(text):
|
| 78 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
| 79 |
+
text = convert_to_ascii(text)
|
| 80 |
+
text = lowercase(text)
|
| 81 |
+
text = collapse_whitespace(text)
|
| 82 |
+
return text
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def english_cleaners(text):
|
| 86 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
| 87 |
+
text = convert_to_ascii(text)
|
| 88 |
+
text = lowercase(text)
|
| 89 |
+
text = expand_numbers(text)
|
| 90 |
+
text = expand_abbreviations(text)
|
| 91 |
+
text = collapse_whitespace(text)
|
| 92 |
+
return text
|
utils.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions
|
| 3 |
+
"""
|
| 4 |
+
import pickle
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import pax
|
| 8 |
+
import toml
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
from tacotron import Tacotron
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_tacotron_config(config_file=Path("tacotron.toml")):
|
| 15 |
+
"""
|
| 16 |
+
Load the project configurations
|
| 17 |
+
"""
|
| 18 |
+
return toml.load(config_file)["tacotron"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
|
| 22 |
+
"""
|
| 23 |
+
load checkpoint from disk
|
| 24 |
+
"""
|
| 25 |
+
with open(path, "rb") as f:
|
| 26 |
+
dic = pickle.load(f)
|
| 27 |
+
if net is not None:
|
| 28 |
+
net = net.load_state_dict(dic["model_state_dict"])
|
| 29 |
+
if optim is not None:
|
| 30 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
| 31 |
+
return dic["step"], net, optim
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_tacotron_model(config):
|
| 35 |
+
"""
|
| 36 |
+
return a random initialized Tacotron model
|
| 37 |
+
"""
|
| 38 |
+
return Tacotron(
|
| 39 |
+
mel_dim=config["MEL_DIM"],
|
| 40 |
+
attn_bias=config["ATTN_BIAS"],
|
| 41 |
+
rr=config["RR"],
|
| 42 |
+
max_rr=config["MAX_RR"],
|
| 43 |
+
mel_min=config["MEL_MIN"],
|
| 44 |
+
sigmoid_noise=config["SIGMOID_NOISE"],
|
| 45 |
+
pad_token=config["PAD_TOKEN"],
|
| 46 |
+
prenet_dim=config["PRENET_DIM"],
|
| 47 |
+
attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
|
| 48 |
+
attn_rnn_dim=config["ATTN_RNN_DIM"],
|
| 49 |
+
rnn_dim=config["RNN_DIM"],
|
| 50 |
+
postnet_dim=config["POSTNET_DIM"],
|
| 51 |
+
text_dim=config["TEXT_DIM"],
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_wavegru_config(config_file):
|
| 56 |
+
"""
|
| 57 |
+
Load project configurations
|
| 58 |
+
"""
|
| 59 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
| 60 |
+
return yaml.safe_load(f)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_wavegru_ckpt(net, optim, ckpt_file):
|
| 64 |
+
"""
|
| 65 |
+
load training checkpoint from file
|
| 66 |
+
"""
|
| 67 |
+
with open(ckpt_file, "rb") as f:
|
| 68 |
+
dic = pickle.load(f)
|
| 69 |
+
|
| 70 |
+
if net is not None:
|
| 71 |
+
net = net.load_state_dict(dic["net_state_dict"])
|
| 72 |
+
if optim is not None:
|
| 73 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
| 74 |
+
return dic["step"], net, optim
|
wavegru.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WaveGRU model: melspectrogram => mu-law encoded waveform
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
import jax
|
| 8 |
+
import jax.numpy as jnp
|
| 9 |
+
import pax
|
| 10 |
+
from pax import GRUState
|
| 11 |
+
from tqdm.cli import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ReLU(pax.Module):
|
| 15 |
+
def __call__(self, x):
|
| 16 |
+
return jax.nn.relu(x)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def dilated_residual_conv_block(dim, kernel, stride, dilation):
|
| 20 |
+
"""
|
| 21 |
+
Use dilated convs to enlarge the receptive field
|
| 22 |
+
"""
|
| 23 |
+
return pax.Sequential(
|
| 24 |
+
pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
|
| 25 |
+
pax.LayerNorm(dim, -1, True, True),
|
| 26 |
+
ReLU(),
|
| 27 |
+
pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
|
| 28 |
+
pax.LayerNorm(dim, -1, True, True),
|
| 29 |
+
ReLU(),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def tile_1d(x, factor):
|
| 34 |
+
"""
|
| 35 |
+
Tile tensor of shape N, L, D into N, L*factor, D
|
| 36 |
+
"""
|
| 37 |
+
N, L, D = x.shape
|
| 38 |
+
x = x[:, :, None, :]
|
| 39 |
+
x = jnp.tile(x, (1, 1, factor, 1))
|
| 40 |
+
x = jnp.reshape(x, (N, L * factor, D))
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def up_block(in_dim, out_dim, factor, relu=True):
|
| 45 |
+
"""
|
| 46 |
+
Tile >> Conv >> BatchNorm >> ReLU
|
| 47 |
+
"""
|
| 48 |
+
f = pax.Sequential(
|
| 49 |
+
lambda x: tile_1d(x, factor),
|
| 50 |
+
pax.Conv1D(
|
| 51 |
+
in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
|
| 52 |
+
),
|
| 53 |
+
pax.LayerNorm(out_dim, -1, True, True),
|
| 54 |
+
)
|
| 55 |
+
if relu:
|
| 56 |
+
f >>= ReLU()
|
| 57 |
+
return f
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Upsample(pax.Module):
|
| 61 |
+
"""
|
| 62 |
+
Upsample melspectrogram to match raw audio sample rate.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
|
| 67 |
+
):
|
| 68 |
+
super().__init__()
|
| 69 |
+
self.input_conv = pax.Sequential(
|
| 70 |
+
pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
|
| 71 |
+
pax.LayerNorm(hidden_dim, -1, True, True),
|
| 72 |
+
)
|
| 73 |
+
self.upsample_factors = upsample_factors
|
| 74 |
+
self.dilated_convs = [
|
| 75 |
+
dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
|
| 76 |
+
]
|
| 77 |
+
self.up_factors = upsample_factors[:-1]
|
| 78 |
+
self.up_blocks = [
|
| 79 |
+
up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
|
| 80 |
+
]
|
| 81 |
+
self.up_blocks.append(
|
| 82 |
+
up_block(
|
| 83 |
+
hidden_dim,
|
| 84 |
+
hidden_dim if has_linear_output else 3 * rnn_dim,
|
| 85 |
+
self.up_factors[-1],
|
| 86 |
+
relu=False,
|
| 87 |
+
)
|
| 88 |
+
)
|
| 89 |
+
if has_linear_output:
|
| 90 |
+
self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
|
| 91 |
+
self.has_linear_output = has_linear_output
|
| 92 |
+
|
| 93 |
+
self.final_tile = upsample_factors[-1]
|
| 94 |
+
|
| 95 |
+
def __call__(self, x, no_repeat=False):
|
| 96 |
+
x = self.input_conv(x)
|
| 97 |
+
for residual in self.dilated_convs:
|
| 98 |
+
y = residual(x)
|
| 99 |
+
pad = (x.shape[1] - y.shape[1]) // 2
|
| 100 |
+
x = x[:, pad:-pad, :] + y
|
| 101 |
+
|
| 102 |
+
for f in self.up_blocks:
|
| 103 |
+
x = f(x)
|
| 104 |
+
|
| 105 |
+
if self.has_linear_output:
|
| 106 |
+
x = self.x2zrh_fc(x)
|
| 107 |
+
|
| 108 |
+
if no_repeat:
|
| 109 |
+
return x
|
| 110 |
+
x = tile_1d(x, self.final_tile)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class GRU(pax.Module):
|
| 115 |
+
"""
|
| 116 |
+
A customized GRU module.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
input_dim: int
|
| 120 |
+
hidden_dim: int
|
| 121 |
+
|
| 122 |
+
def __init__(self, hidden_dim: int):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.hidden_dim = hidden_dim
|
| 125 |
+
self.h_zrh_fc = pax.Linear(
|
| 126 |
+
hidden_dim,
|
| 127 |
+
hidden_dim * 3,
|
| 128 |
+
w_init=jax.nn.initializers.variance_scaling(
|
| 129 |
+
1, "fan_out", "truncated_normal"
|
| 130 |
+
),
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def initial_state(self, batch_size: int) -> GRUState:
|
| 134 |
+
"""Create an all zeros initial state."""
|
| 135 |
+
return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))
|
| 136 |
+
|
| 137 |
+
def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
|
| 138 |
+
hidden = state.hidden
|
| 139 |
+
x_zrh = x
|
| 140 |
+
h_zrh = self.h_zrh_fc(hidden)
|
| 141 |
+
x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1)
|
| 142 |
+
h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1)
|
| 143 |
+
|
| 144 |
+
zr = x_zr + h_zr
|
| 145 |
+
zr = jax.nn.sigmoid(zr)
|
| 146 |
+
z, r = jnp.split(zr, 2, axis=-1)
|
| 147 |
+
|
| 148 |
+
h_hat = x_h + r * h_h
|
| 149 |
+
h_hat = jnp.tanh(h_hat)
|
| 150 |
+
|
| 151 |
+
h = (1 - z) * hidden + z * h_hat
|
| 152 |
+
return GRUState(h), h
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class Pruner(pax.Module):
|
| 156 |
+
"""
|
| 157 |
+
Base class for pruners
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
def compute_sparsity(self, step):
|
| 161 |
+
t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
|
| 162 |
+
z = 0.95 * jnp.clip(1.0 - t, a_min=0, a_max=1)
|
| 163 |
+
return z
|
| 164 |
+
|
| 165 |
+
def prune(self, step, weights):
|
| 166 |
+
"""
|
| 167 |
+
Return a mask
|
| 168 |
+
"""
|
| 169 |
+
z = self.compute_sparsity(step)
|
| 170 |
+
x = weights
|
| 171 |
+
H, W = x.shape
|
| 172 |
+
x = x.reshape(H // 4, 4, W // 4, 4)
|
| 173 |
+
x = jnp.abs(x)
|
| 174 |
+
x = jnp.sum(x, axis=(1, 3), keepdims=True)
|
| 175 |
+
q = jnp.quantile(jnp.reshape(x, (-1,)), z)
|
| 176 |
+
x = x >= q
|
| 177 |
+
x = jnp.tile(x, (1, 4, 1, 4))
|
| 178 |
+
x = jnp.reshape(x, (H, W))
|
| 179 |
+
return x
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class GRUPruner(Pruner):
|
| 183 |
+
def __init__(self, gru):
|
| 184 |
+
super().__init__()
|
| 185 |
+
self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1
|
| 186 |
+
|
| 187 |
+
def __call__(self, gru: pax.GRU):
|
| 188 |
+
"""
|
| 189 |
+
Apply mask after an optimization step
|
| 190 |
+
"""
|
| 191 |
+
zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
|
| 192 |
+
gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
|
| 193 |
+
return gru
|
| 194 |
+
|
| 195 |
+
def update_mask(self, step, gru: pax.GRU):
|
| 196 |
+
"""
|
| 197 |
+
Update internal masks
|
| 198 |
+
"""
|
| 199 |
+
z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
|
| 200 |
+
z_mask = self.prune(step, z_weight)
|
| 201 |
+
r_mask = self.prune(step, r_weight)
|
| 202 |
+
h_mask = self.prune(step, h_weight)
|
| 203 |
+
self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class LinearPruner(Pruner):
|
| 207 |
+
def __init__(self, linear):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.mask = jnp.ones_like(linear.weight) == 1
|
| 210 |
+
|
| 211 |
+
def __call__(self, linear: pax.Linear):
|
| 212 |
+
"""
|
| 213 |
+
Apply mask after an optimization step
|
| 214 |
+
"""
|
| 215 |
+
return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))
|
| 216 |
+
|
| 217 |
+
def update_mask(self, step, linear: pax.Linear):
|
| 218 |
+
"""
|
| 219 |
+
Update internal masks
|
| 220 |
+
"""
|
| 221 |
+
self.mask *= self.prune(step, linear.weight)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.

    Pipeline: sample embedding + upsampled mel conditioning -> GRU
    -> O1 -> relu -> O2 -> 256-way logits over quantized sample values.
    """

    def __init__(
        self,
        mel_dim=80,
        rnn_dim=1024,
        upsample_factors=(5, 3, 20),
        has_linear_output=False,
    ):
        super().__init__()
        # 256 quantized sample values, each embedded straight into the
        # three stacked (z, r, h) GRU input slots.
        self.embed = pax.Embed(256, 3 * rnn_dim)
        self.upsample = Upsample(
            input_dim=mel_dim,
            hidden_dim=512,
            rnn_dim=rnn_dim,
            upsample_factors=upsample_factors,
            has_linear_output=has_linear_output,
        )
        self.rnn = GRU(rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        # Pruners sparsify the GRU and the two output projections.
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        """Project GRU features to 256 logits: O1 -> relu -> O2."""
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    def inference(self, mel, no_gru=False, seed=42):
        """
        Generate a waveform from a melspectrogram.

        Arguments:
          mel: melspectrogram conditioning, batch size 1.
          no_gru: if True, return the upsampled conditioning only
                  (skips the autoregressive sampling loop).
          seed: PRNG seed for categorical sampling.
        """

        @jax.jit
        def step(rnn_state, mel, rng_key, x):
            # One autoregressive step: embed previous sample, add
            # conditioning, advance the GRU, then sample the next value.
            x = self.embed(x)
            x = x + mel
            rnn_state, x = self.rnn(rnn_state, x)
            x = self.output(x)
            rng_key, next_rng_key = jax.random.split(rng_key, 2)
            x = jax.random.categorical(rng_key, x, axis=-1)
            return rnn_state, next_rng_key, x

        y = self.upsample(mel, no_repeat=no_gru)
        if no_gru:
            return y
        # Seed the loop with the mid level (127 of 0..255) ~ silence.
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in tqdm(range(y.shape[1])):
            rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        """Teacher-forced training pass: return logits for each sample."""
        x = self.embed(x)
        y = self.upsample(mel)
        # Center-trim the sample sequence to the upsampled length.
        pad_left = (x.shape[1] - y.shape[1]) // 2
        pad_right = x.shape[1] - y.shape[1] - pad_left
        # BUG FIX: the original `x[:, pad_left:-pad_right]` produced an
        # EMPTY slice when pad_right == 0 (lengths already equal);
        # an explicit end index is correct for all pad_right >= 0.
        x = x[:, pad_left : x.shape[1] - pad_right]
        x = x + y
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x
wavegru.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## dsp
|
| 2 |
+
sample_rate : 24000
|
| 3 |
+
window_length: 50.0 # ms
|
| 4 |
+
hop_length: 12.5 # ms
|
| 5 |
+
mel_min: 1.0e-5 ## keep the ".0" so YAML parses this as a float, not a string
|
| 6 |
+
mel_dim: 80
|
| 7 |
+
n_fft: 2048
|
| 8 |
+
|
| 9 |
+
## wavegru
|
| 10 |
+
embed_dim: 32
|
| 11 |
+
rnn_dim: 1024
|
| 12 |
+
frames_per_sequence: 67
|
| 13 |
+
num_pad_frames: 62
|
| 14 |
+
upsample_factors: [5, 3, 20]
|
wavegru_cpp.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from wavegru_mod import WaveGRU
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def extract_weight_mask(net):
    """Collect the trained weights, biases and pruning masks from ``net``
    into a flat dict keyed by the names the C++ loader expects."""
    return {
        "embed_weight": net.embed.weight,
        "gru_h_zrh_weight": net.rnn.h_zrh_fc.weight,
        "gru_h_zrh_mask": net.gru_pruner.h_zrh_fc_mask,
        "gru_h_zrh_bias": net.rnn.h_zrh_fc.bias,
        "o1_weight": net.o1.weight,
        "o1_mask": net.o1_pruner.mask,
        "o1_bias": net.o1.bias,
        "o2_weight": net.o2.weight,
        "o2_mask": net.o2_pruner.mask,
        "o2_bias": net.o2.bias,
    }
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_wavegru_cpp(data, repeat_factor):
    """Build a C++ WaveGRU object and load the exported weights into it."""

    def transposed(key):
        # The C++ kernels want the transposed layout as one contiguous
        # buffer, hence the copy.
        return np.ascontiguousarray(data[key].T)

    # The fused GRU bias stacks z|r|h, so the state size is a third of it.
    rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3
    net = WaveGRU(rnn_dim, repeat_factor)
    net.load_embed(data["embed_weight"])
    net.load_weights(
        transposed("gru_h_zrh_weight"),
        transposed("gru_h_zrh_mask"),
        data["gru_h_zrh_bias"],
        transposed("o1_weight"),
        transposed("o1_mask"),
        data["o1_bias"],
        transposed("o2_weight"),
        transposed("o2_mask"),
        data["o2_bias"],
    )
    return net
wavegru_mod.cc
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
WaveGRU:
|
| 3 |
+
> Embed > GRU > O1 > O2 > Sampling > ...
|
| 4 |
+
*/
|
| 5 |
+
|
| 6 |
+
#include <pybind11/numpy.h>
|
| 7 |
+
#include <pybind11/pybind11.h>
|
| 8 |
+
#include <pybind11/stl.h>
|
| 9 |
+
|
| 10 |
+
#include <iostream>
|
| 11 |
+
#include <random>
|
| 12 |
+
#include <vector>
|
| 13 |
+
|
| 14 |
+
#include "sparse_matmul/sparse_matmul.h"
|
| 15 |
+
namespace py = pybind11;
|
| 16 |
+
using namespace std;
|
| 17 |
+
|
| 18 |
+
using fvec = std::vector<float>;
|
| 19 |
+
using ivec = std::vector<int>;
|
| 20 |
+
using fndarray = py::array_t<float>;
|
| 21 |
+
using indarray = py::array_t<int>;
|
| 22 |
+
using mat = csrblocksparse::CsrBlockSparseMatrix<float, float, int16_t>;
|
| 23 |
+
using vec = csrblocksparse::CacheAlignedVector<float>;
|
| 24 |
+
using masked_mat = csrblocksparse::MaskedSparseMatrix<float>;
|
| 25 |
+
|
| 26 |
+
// Build a placeholder block-sparse matrix of shape (w, h): 90% sparse,
// 4x4 blocks. The real weights are installed later via load_weights().
mat create_mat(int h, int w) {
  masked_mat dense_with_mask(w, h, 0.90, 4, 4, 0.0, true);
  return mat(dense_with_mask);
}
| 31 |
+
|
| 32 |
+
// Sparse WaveGRU inference engine.
// Runs the autoregressive loop sample-by-sample:
//   embed[value] + conditioning -> GRU -> O1 -> relu -> O2 -> Sample.
struct WaveGRU {
  int hidden_dim;     // GRU state size
  int repeat_factor;  // samples generated per conditioning frame
  mat m;              // fused GRU weight matrix (z|r|h stacked)
  vec b;              // fused GRU bias (z|r|h stacked), length 3*hidden_dim
  vec z, r, hh, zrh;  // scratch: update gate, reset gate, candidate, fused GRU output
  vec fco1, fco2;     // scratch: O1 activation, O2 logits
  vec o1b, o2b;       // O1 / O2 biases
  vec t;              // scratch: embed[value] + conditioning, length 3*hidden_dim
  vec h;              // GRU hidden state
  vec logits;         // scratch for the (currently disabled) logit clipping below
  mat o1, o2;         // output projection matrices
  std::vector<vec> embed;  // 256 sample-value embeddings, each 3*hidden_dim

  // NOTE: members are initialized in declaration order, not in the order
  // written here; all sizes come straight from the arguments, so this is safe.
  WaveGRU(int hidden_dim, int repeat_factor)
      : hidden_dim(hidden_dim),
        repeat_factor(repeat_factor),
        b(3*hidden_dim),
        t(3*hidden_dim),
        zrh(3*hidden_dim),
        z(hidden_dim),
        r(hidden_dim),
        hh(hidden_dim),
        fco1(hidden_dim),
        fco2(256),
        h(hidden_dim),
        o1b(hidden_dim),
        o2b(256),
        logits(256) {
    m = create_mat(hidden_dim, 3*hidden_dim);
    o1 = create_mat(hidden_dim, hidden_dim);
    o2 = create_mat(hidden_dim, 256);
    // Random placeholder embeddings; load_embed() overwrites them.
    embed = std::vector<vec>();
    for (int i = 0; i < 256; i++) {
      embed.emplace_back(hidden_dim * 3);
      embed[i].FillRandom();
    }
  }

  // Copy the (256, 3*hidden_dim) embedding table from a numpy array.
  void load_embed(fndarray embed_weights) {
    auto a_embed = embed_weights.unchecked<2>();
    for (int i = 0; i < 256; i++) {
      for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j);
    }
  }

  // Load one masked dense weight matrix + bias; returns the compressed
  // block-sparse matrix and writes the bias into `bias`.
  mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) {
    auto w_ptr = static_cast<float*>(w.request().ptr);
    auto mask_ptr = static_cast<int*>(mask.request().ptr);
    auto rb = b.unchecked<1>();
    // load bias, scale by 1/4
    // NOTE(review): the 1/4 scaling presumably matches a scaling convention
    // inside the sparse_matmul SpMM kernels — confirm against that library.
    for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4;
    // load weights
    masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr);
    mat mmm(mm);
    return mmm;
  }

  // Install all trained weights: fused GRU matrix, then O1 and O2.
  void load_weights(fndarray m, indarray m_mask, fndarray b,
                    fndarray o1, indarray o1_mask,
                    fndarray o1b, fndarray o2,
                    indarray o2_mask, fndarray o2b) {
    this->m = load_linear(this->b, m, m_mask, b);
    this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
    this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
  }

  // Generate repeat_factor samples per conditioning frame.
  // `ft` is expected to be (num_frames, 3*hidden_dim) — the upsampled
  // conditioning already projected to the fused z|r|h layout.
  std::vector<int> inference(fndarray ft, float temperature) {
    auto rft = ft.unchecked<2>();
    int value = 127;  // mid level of the 0..255 range as the seed sample
    std::vector<int> signal(rft.shape(0) * repeat_factor);
    h.FillZero();
    for (int index = 0; index < signal.size(); index++) {
      // Recurrent contribution: W_zrh @ h + b, all three gates at once.
      m.SpMM_bias(h, b, &zrh, false);

      // Input contribution: previous-sample embedding + frame conditioning.
      for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i);
      for (int i = 0; i < hidden_dim; i++) {
        z[i] = zrh[i] + t[i];
        r[i] = zrh[hidden_dim + i] + t[hidden_dim + i];
      }

      z.Sigmoid();
      r.Sigmoid();

      // Candidate state: reset gate applies only to the recurrent part.
      for (int i = 0; i < hidden_dim; i++) {
        hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i];
      }
      hh.Tanh();
      // Standard GRU state blend: h = (1-z)*h + z*hh.
      for (int i = 0; i < hidden_dim; i++) {
        h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
      }
      o1.SpMM_bias(h, o1b, &fco1, true);   // O1 with relu
      o2.SpMM_bias(fco1, o2b, &fco2, false);
      // Disabled: clip away low-probability logits before sampling.
      // auto max_logit = fco2[0];
      // for (int i = 1; i <= 255; ++i) {
      //   max_logit = max(max_logit, fco2[i]);
      // }
      // float total = 0.0;
      // for (int i = 0; i <= 255; ++i) {
      //   logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
      //   total += logits[i];
      // }
      // for (int i = 0; i <= 255; ++i) {
      //   if (logits[i] < total / 1024.0) fco2[i] = -1e9;
      // }
      value = fco2.Sample(temperature);
      signal[index] = value;
    }
    return signal;
  }
};
| 143 |
+
|
| 144 |
+
// Expose the WaveGRU inference engine to Python as module `wavegru_mod`.
PYBIND11_MODULE(wavegru_mod, m) {
  py::class_<WaveGRU>(m, "WaveGRU")
      .def(py::init<int, int>())  // (hidden_dim, repeat_factor)
      .def("load_embed", &WaveGRU::load_embed)
      .def("load_weights", &WaveGRU::load_weights)
      .def("inference", &WaveGRU::inference);
}
wavegru_vocoder_1024_v4_1320000.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:052a90bd510607f5cd2a6e6ce9d1ae4138db25cb69cc8504c98f2d33eac13375
|
| 3 |
+
size 69717674
|