Fix unbound local error when pad_tokens=False
Browse files- flux_emphasis.py +9 -6
flux_emphasis.py
CHANGED
|
@@ -203,24 +203,27 @@ def group_tokens_and_weights(
|
|
| 203 |
, weights = token_weight_list
|
| 204 |
)
|
| 205 |
"""
|
|
|
|
|
|
|
|
|
|
| 206 |
max_len = max_length - 2 if max_length < 77 else max_length
|
| 207 |
# this will be a 2d list
|
| 208 |
new_token_ids = []
|
| 209 |
new_weights = []
|
| 210 |
while len(token_ids) >= max_len:
|
| 211 |
# get the first 75 tokens
|
| 212 |
-
|
| 213 |
-
|
| 214 |
|
| 215 |
# extract token ids and weights
|
| 216 |
|
| 217 |
if pad_tokens:
|
| 218 |
if bos is not None:
|
| 219 |
-
temp_77_token_ids = [bos] +
|
| 220 |
-
temp_77_weights = [1.0] +
|
| 221 |
else:
|
| 222 |
-
temp_77_token_ids =
|
| 223 |
-
temp_77_weights =
|
| 224 |
|
| 225 |
# add 77 token and weights chunk to the holder list
|
| 226 |
new_token_ids.append(temp_77_token_ids)
|
|
|
|
| 203 |
, weights = token_weight_list
|
| 204 |
)
|
| 205 |
"""
|
| 206 |
+
# TODO: Possibly need to fix this, since this doesn't seem correct.
|
| 207 |
+
# Ignoring for now since I don't know what the consequences might be
|
| 208 |
+
# if changed to <= instead of <.
|
| 209 |
max_len = max_length - 2 if max_length < 77 else max_length
|
| 210 |
# this will be a 2d list
|
| 211 |
new_token_ids = []
|
| 212 |
new_weights = []
|
| 213 |
while len(token_ids) >= max_len:
|
| 214 |
# get the first 75 tokens
|
| 215 |
+
temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)]
|
| 216 |
+
temp_77_weights = [weights.pop(0) for _ in range(max_len)]
|
| 217 |
|
| 218 |
# extract token ids and weights
|
| 219 |
|
| 220 |
if pad_tokens:
|
| 221 |
if bos is not None:
|
| 222 |
+
temp_77_token_ids = [bos] + temp_77_token_ids + [eos]
|
| 223 |
+
temp_77_weights = [1.0] + temp_77_weights + [1.0]
|
| 224 |
else:
|
| 225 |
+
temp_77_token_ids = temp_77_token_ids + [eos]
|
| 226 |
+
temp_77_weights = temp_77_weights + [1.0]
|
| 227 |
|
| 228 |
# add 77 token and weights chunk to the holder list
|
| 229 |
new_token_ids.append(temp_77_token_ids)
|