wenjin_lee committed
Commit d9aaf5e · 1 Parent(s): 9ae8623

Use torch.inference_mode() and disable gradient checkpointing

Files changed (2):
  1. config.json (+4 -1)
  2. modeling_zeranker.py (+11 -3)
config.json CHANGED
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
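
The added auto_map entry is what lets AutoConfig resolve the repository's custom configuration class from modeling_zeranker.py when remote code is enabled. A minimal sketch of how this mapping is typically consumed, assuming the repo id matches the MODEL_PATH constant ("zeroentropy/zerank-2") used in the modeling file:

from transformers import AutoConfig

# trust_remote_code=True allows transformers to import modeling_zeranker.ZEConfig
# from the repository, following the auto_map entry added in this commit.
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)  # expected to be ZEConfig if the mapping resolves
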
modeling_zeranker.py CHANGED
@@ -20,11 +20,16 @@ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
 from transformers.tokenization_utils_base import BatchEncoding
 from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
+import logging
+
+logger = logging.getLogger(__name__)
+print("Running code of HF Model")
+
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -126,9 +131,11 @@ def predict(
     query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
     if not hasattr(self, "inner_model"):
+        logger.info(f"Memory reserved [Within Model File] Before Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.gradient_checkpointing_enable()
+        logger.info(f"Memory reserved [Within Model File] After Loading Model: {torch.cuda.memory_reserved()}")
         self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
         self.inner_yes_token_id = self.inner_tokenizer.encode(
             "Yes", add_special_tokens=False
         )[0]
@@ -172,7 +179,8 @@ def predict(
         batch_inputs = batch_inputs.to(global_device)
 
         try:
-            outputs = model(**batch_inputs, use_cache=False)
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
             print(f"GPU OOM! {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
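
Together, the modeling changes wrap the forward pass in torch.inference_mode(), which skips autograd bookkeeping entirely so no activation graph is retained, and switch gradient checkpointing off, since checkpointing only pays for itself during a backward pass and would just add recomputation overhead at inference. A rough sketch of the resulting scoring pattern; score_batch and its arguments are illustrative placeholders, not the actual predict() signature:

import torch

def score_batch(model, batch_inputs, yes_token_id):
    # No gradients are needed here, so inference_mode() keeps peak GPU memory
    # close to what the forward pass alone requires.
    with torch.inference_mode():
        outputs = model(**batch_inputs, use_cache=False)
    # The diff encodes the "Yes" token id; presumably the relevance score is the
    # logit assigned to that token at the final position.
    return outputs.logits[:, -1, yes_token_id]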