Commit
·
4d69be8
1
Parent(s):
4586ed4
fix: dtype
Browse files- blocks_jvlm.py +2 -2
blocks_jvlm.py
CHANGED
|
@@ -944,9 +944,9 @@ class MHSDPA(nn.Module):
|
|
| 944 |
q, k, v = qkv.split(self.fused_dims, dim=-1)
|
| 945 |
else:
|
| 946 |
assert xk is not None
|
| 947 |
-
q = f.linear(xq, self.q_w.weight, self.q_b)
|
| 948 |
kv_b = torch.cat((self.k_b, self.v_b))
|
| 949 |
-
kv = f.linear(xk, self.kv_w.weight, kv_b)
|
| 950 |
if self.clip_qkv is not None:
|
| 951 |
q.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 952 |
kv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
|
|
|
| 944 |
q, k, v = qkv.split(self.fused_dims, dim=-1)
|
| 945 |
else:
|
| 946 |
assert xk is not None
|
| 947 |
+
q = f.linear(xq.to(self.q_w.weight.dtype), self.q_w.weight, self.q_b)
|
| 948 |
kv_b = torch.cat((self.k_b, self.v_b))
|
| 949 |
+
kv = f.linear(xk.to(self.kv_w.weight.dtype), self.kv_w.weight, kv_b)
|
| 950 |
if self.clip_qkv is not None:
|
| 951 |
q.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 952 |
kv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|