Use torch.inference_mode() and disable gradient checkpointing
#4 by prathamj31 - opened

- config.json (+4 -1)
- modeling_zeranker.py (+20 -9)
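torch.inference_mode() turns off autograd tracking entirely (it is stricter and slightly cheaper than torch.no_grad()), while gradient checkpointing only trades compute for memory when a later backward pass recomputes activations, so it has nothing to offer at inference time. A minimal sketch of the distinction, using a toy module rather than the actual zerank-2 forward path:

```python
import torch
from torch import nn


# Toy stand-in for a scoring model; not the zerank-2 architecture.
class TinyScorer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyScorer().eval()
x = torch.randn(4, 16)

# Under inference_mode no autograd graph is built and no activations are
# retained for backward, which is exactly what a reranker forward pass wants.
with torch.inference_mode():
    scores = model(x)

print(scores.requires_grad)  # False: the output is an inference tensor
```

Gradient checkpointing, by contrast, re-runs forward segments during backward to save activation memory; with no backward pass it is pure overhead, hence the gradient_checkpointing_disable() call in the modeling change below.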
config.json CHANGED

@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
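The new auto_map entry points the transformers Auto classes at the config class shipped in this repository's modeling_zeranker.py. A rough sketch of how it gets picked up, assuming the repo id zeroentropy/zerank-2 used in the modeling file and that ZEConfig is defined there:

```python
from transformers import AutoConfig

# With "auto_map": {"AutoConfig": "modeling_zeranker.ZEConfig"} in config.json,
# AutoConfig resolves the custom class from the repo's own modeling file.
# trust_remote_code=True is required because that code ships with the model.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # expected: ZEConfig
```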
modeling_zeranker.py CHANGED

@@ -24,7 +24,7 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -125,13 +125,7 @@ def predict(
         raise ValueError("query_documents or sentences must be provided")
         query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
-
-    self.inner_tokenizer, self.inner_model = load_model(global_device)
-    self.inner_model.gradient_checkpointing_enable()
-    self.inner_model.eval()
-    self.inner_yes_token_id = self.inner_tokenizer.encode(
-        "Yes", add_special_tokens=False
-    )[0]
+
 
     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -172,7 +166,8 @@ def predict(
         batch_inputs = batch_inputs.to(global_device)
 
         try:
-            outputs = model(**batch_inputs, use_cache=False)
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
             print(f"GPU OOM! {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
@@ -207,6 +202,22 @@ def to_device(self: _CE, new_device: torch.device) -> None:
     global_device = new_device
 
 
+
+original_init = _CE.__init__
+
+
+def _new_init(self: _CE, *args: Any, **kwargs: Any) -> None:
+    original_init(self, *args, **kwargs)
+    self.inner_tokenizer, self.inner_model = load_model(global_device)
+    self.inner_model.eval()
+    self.inner_model.gradient_checkpointing_disable()
+    self.inner_yes_token_id = self.inner_tokenizer.encode(
+        "Yes", add_special_tokens=False
+    )[0]
+
+
+_CE.__init__ = _new_init
+
 _CE.predict = predict
 
 from transformers import Qwen3Config
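Besides wrapping the forward call in inference_mode, the diff moves the one-time setup (load_model, eval(), gradient_checkpointing_disable(), and the "Yes" token lookup) out of predict() and into a wrapped __init__, so it runs once per instance instead of on every call. A stripped-down sketch of that constructor-wrapping pattern, with placeholder names (Reranker, load_model) standing in for the real _CE class and loader:

```python
from typing import Any


def load_model() -> tuple[str, str]:
    # Placeholder for the real load_model(global_device); returns dummies here.
    return "tokenizer", "model"


class Reranker:
    def __init__(self, name: str) -> None:
        self.name = name


# Keep a handle on the original constructor, then install a wrapper that calls
# it first and performs the expensive setup exactly once, the same pattern the
# diff applies to _CE.__init__.
original_init = Reranker.__init__


def _new_init(self: Reranker, *args: Any, **kwargs: Any) -> None:
    original_init(self, *args, **kwargs)
    self.inner_tokenizer, self.inner_model = load_model()


Reranker.__init__ = _new_init

r = Reranker("demo")
print(r.inner_model)  # "model": loaded at construction, not inside predict()
```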