wenjin_lee committed on
Commit · d9aaf5e
1 Parent(s): 9ae8623

Use torch.inference_mode() and disable gradient checkpointing

Browse files:
- config.json +4 -1
- modeling_zeranker.py +11 -3
config.json CHANGED

@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
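The added auto_map entry points transformers at this repo's own modeling file for the custom configuration class when remote code is trusted. A minimal sketch of how a caller would pick that up (it assumes ZEConfig is defined in modeling_zeranker.py, as the mapping implies; the snippet is illustrative and not part of the commit):

from transformers import AutoConfig

# With the auto_map entry in config.json, AutoConfig dynamically loads
# "ZEConfig" from modeling_zeranker.py inside the repo instead of a
# built-in config class; trust_remote_code=True is required for that.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # expected to be "ZEConfig" once this commit is applied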
modeling_zeranker.py CHANGED

@@ -20,11 +20,16 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
 from transformers.tokenization_utils_base import BatchEncoding
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
+import logging
+
+logger = logging.getLogger(__name__)
+print("Running code of HF Model")
+
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -126,9 +131,11 @@ def predict(
     query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
     if not hasattr(self, "inner_model"):
+        logger.info(f"Memory reserved [Within Model File] Before Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_tokenizer, self.inner_model = load_model(global_device)
-
+        logger.info(f"Memory reserved [Within Model File] After Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
         self.inner_yes_token_id = self.inner_tokenizer.encode(
             "Yes", add_special_tokens=False
         )[0]
@@ -172,7 +179,8 @@ def predict(
     batch_inputs = batch_inputs.to(global_device)
 
     try:
-
+        with torch.inference_mode():
+            outputs = model(**batch_inputs, use_cache=False)
     except torch.OutOfMemoryError:
         print(f"GPU OOM! {torch.cuda.memory_reserved()}")
         torch.cuda.empty_cache()
|