florianhoenicke committed on
Commit
4d69be8
·
1 Parent(s): 4586ed4

fix: dtype

Browse files
Files changed (1) hide show
  1. blocks_jvlm.py +2 -2
blocks_jvlm.py CHANGED
@@ -944,9 +944,9 @@ class MHSDPA(nn.Module):
944
  q, k, v = qkv.split(self.fused_dims, dim=-1)
945
  else:
946
  assert xk is not None
947
- q = f.linear(xq, self.q_w.weight, self.q_b)
948
  kv_b = torch.cat((self.k_b, self.v_b))
949
- kv = f.linear(xk, self.kv_w.weight, kv_b)
950
  if self.clip_qkv is not None:
951
  q.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
952
  kv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
 
944
  q, k, v = qkv.split(self.fused_dims, dim=-1)
945
  else:
946
  assert xk is not None
947
+ q = f.linear(xq.to(self.q_w.weight.dtype), self.q_w.weight, self.q_b)
948
  kv_b = torch.cat((self.k_b, self.v_b))
949
+ kv = f.linear(xk.to(self.kv_w.weight.dtype), self.kv_w.weight, kv_b)
950
  if self.clip_qkv is not None:
951
  q.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
952
  kv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)