Use torch.inference_mode() and disable gradient checkpointing

#4 by prathamj31 - opened
Files changed (2)
  1. config.json +4 -1
  2. modeling_zeranker.py +20 -9
config.json CHANGED
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
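
For context on the new auto_map block: it is the hook that lets the transformers Auto classes resolve the custom config class shipped in modeling_zeranker.py when the repo is loaded with remote code enabled. A minimal sketch of how it would be picked up, assuming the repo defines ZEConfig as the diff implies:

from transformers import AutoConfig

# With the auto_map entry above, AutoConfig dispatches to the custom class
# in modeling_zeranker.py; trust_remote_code=True is required for that.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # should report ZEConfig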
modeling_zeranker.py CHANGED
@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -125,13 +125,7 @@ def predict(
         raise ValueError("query_documents or sentences must be provided")
         query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
-    if not hasattr(self, "inner_model"):
-        self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.gradient_checkpointing_enable()
-        self.inner_model.eval()
-        self.inner_yes_token_id = self.inner_tokenizer.encode(
-            "Yes", add_special_tokens=False
-        )[0]
+
 
     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -172,7 +166,8 @@ def predict(
         batch_inputs = batch_inputs.to(global_device)
 
         try:
-            outputs = model(**batch_inputs, use_cache=False)
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
             print(f"GPU OOM! {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
@@ -207,6 +202,22 @@ def to_device(self: _CE, new_device: torch.device) -> None:
     global_device = new_device
 
 
+
+original_init = _CE.__init__
+
+
+def _new_init(self: _CE, *args: Any, **kwargs: Any) -> None:
+    original_init(self, *args, **kwargs)
+    self.inner_tokenizer, self.inner_model = load_model(global_device)
+    self.inner_model.eval()
+    self.inner_model.gradient_checkpointing_disable()
+    self.inner_yes_token_id = self.inner_tokenizer.encode(
+        "Yes", add_special_tokens=False
+    )[0]
+
+
+_CE.__init__ = _new_init
+
 _CE.predict = predict
 
 from transformers import Qwen3Config
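
For anyone who wants to try the pattern outside this repo, here is a minimal standalone sketch of the same idea; the model id and prompt below are placeholders, not the zerank-2 scoring setup. Load the model once, keep gradient checkpointing off (it only trades compute for memory during training and just adds overhead at inference), and wrap the forward pass in torch.inference_mode(), which disables autograd tracking even more aggressively than no_grad():

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model id; the PR applies this pattern to zeroentropy/zerank-2.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").to(device)
model.eval()
model.gradient_checkpointing_disable()  # checkpointing is a training-time trade-off

inputs = tokenizer("Query: ... Document: ... Relevant?", return_tensors="pt").to(device)
with torch.inference_mode():  # no autograd bookkeeping during the forward pass
    outputs = model(**inputs, use_cache=False)
last_token_logits = outputs.logits[:, -1, :]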