altpuppet
committed on
Commit
·
52beecd
1
Parent(s):
0a3e8eb
Fix syntax error in gated_deltaproduct.py and add matplotlib dependency
Browse files
requirements.txt
CHANGED
|
@@ -18,3 +18,4 @@ python-dateutil>=2.8.0
|
|
| 18 |
pytz>=2021.1
|
| 19 |
PyYAML>=5.4.1
|
| 20 |
flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main
|
|
|
|
|
|
| 18 |
pytz>=2021.1
|
| 19 |
PyYAML>=5.4.1
|
| 20 |
flash-linear-attention @ git+https://github.com/fla-org/flash-linear-attention@main
|
| 21 |
+
matplotlib>=3.5.0
|
src/models/gated_deltaproduct/gated_deltaproduct.py
CHANGED
|
@@ -74,11 +74,9 @@ class GatedDeltaProduct(nn.Module):
|
|
| 74 |
# Consistency check: Ensure expand_v produces integer values
|
| 75 |
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
|
| 76 |
raise ValueError(
|
| 77 |
-
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
"which is invalid for nn.Linear."
|
| 81 |
-
)
|
| 82 |
)
|
| 83 |
if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
|
| 84 |
raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
|
|
|
|
| 74 |
# Consistency check: Ensure expand_v produces integer values
|
| 75 |
if not math.isclose(self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5):
|
| 76 |
raise ValueError(
|
| 77 |
+
f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
|
| 78 |
+
f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, "
|
| 79 |
+
"which is invalid for nn.Linear."
|
|
|
|
|
|
|
| 80 |
)
|
| 81 |
if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
|
| 82 |
raise ValueError(f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}.")
|