# Benchmark script: measures LLAMA2_WRAPPER initialization time and average
# generation time, speed (tokens/sec), and peak memory usage.
# Standard library
import argparse
import os
import time
from distutils.util import strtobool  # NOTE(review): distutils is removed in Python 3.12

# Third-party
from dotenv import load_dotenv
from memory_profiler import memory_usage
from tqdm import tqdm

# Local
from llama2_wrapper import LLAMA2_WRAPPER
def run_iteration(
    llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
):
    """Run one timed generation and measure its speed and peak memory.

    Args:
        llama2_wrapper: Wrapper exposing ``run`` (a generator that streams
            progressively longer responses) and ``get_token_length``.
        prompt_example: Prompt string to generate from.
        DEFAULT_SYSTEM_PROMPT: System prompt forwarded to ``run``.
        DEFAULT_MAX_NEW_TOKENS: Generation length cap forwarded to ``run``.

    Returns:
        Tuple of (generation_time_seconds, tokens_per_second,
        peak_memory_mib, final_model_response).  ``final_model_response`` is
        ``None`` if the generator yields nothing.
    """

    def generation():
        generator = llama2_wrapper.run(
            prompt_example,
            [],
            DEFAULT_SYSTEM_PROMPT,
            DEFAULT_MAX_NEW_TOKENS,
            1,      # temperature — presumably; verify against LLAMA2_WRAPPER.run
            0.95,   # top_p — presumably; verify against LLAMA2_WRAPPER.run
            50,     # top_k — presumably; verify against LLAMA2_WRAPPER.run
        )
        # Exhaust the stream, keeping only the last (complete) response.
        # The previous version consumed the first item with next() and then
        # discarded it, so a generator that yielded exactly once left
        # model_response as None and lost the only real response.
        model_response = None
        for model_response in generator:
            pass
        return llama2_wrapper.get_token_length(model_response), model_response

    tic = time.perf_counter()
    # memory_usage executes generation() and reports its peak memory in MiB,
    # returning the function's result alongside via retval=True.
    mem_usage, (output_token_length, model_response) = memory_usage(
        (generation,), max_usage=True, retval=True
    )
    toc = time.perf_counter()

    generation_time = toc - tic
    tokens_per_second = output_token_length / generation_time
    return generation_time, tokens_per_second, mem_usage, model_response
def main():
    """Benchmark LLAMA2_WRAPPER: report initialization time plus average
    generation time, speed, and peak memory over ``--iter`` timed runs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--iter", type=int, default=5, help="Number of iterations")
    parser.add_argument("--model_path", type=str, default="", help="model path")
    parser.add_argument(
        "--backend_type",
        type=str,
        default="",
        help="Backend options: llama.cpp, gptq, transformers",
    )
    parser.add_argument(
        "--load_in_8bit",
        # argparse's type=bool treats EVERY non-empty string (including
        # "False") as True; parse the value explicitly instead.
        type=lambda v: bool(strtobool(v)),
        default=False,
        help="Whether to use bitsandbytes 8 bit.",
    )
    args = parser.parse_args()

    load_dotenv()
    DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "")
    MAX_MAX_NEW_TOKENS = int(os.getenv("MAX_MAX_NEW_TOKENS", 2048))  # read for parity with the app config; unused here
    DEFAULT_MAX_NEW_TOKENS = int(os.getenv("DEFAULT_MAX_NEW_TOKENS", 1024))
    MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", 4000))
    MODEL_PATH = os.getenv("MODEL_PATH")
    BACKEND_TYPE = os.getenv("BACKEND_TYPE")
    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))

    # Command-line arguments override the .env configuration.
    if args.model_path != "":
        MODEL_PATH = args.model_path
    if args.backend_type != "":
        BACKEND_TYPE = args.backend_type
    if args.load_in_8bit:
        LOAD_IN_8BIT = True

    # Validate AFTER the CLI overrides (the old asserts ran before them, so a
    # CLI-only invocation without a .env file always failed) and raise instead
    # of assert so the check survives ``python -O``.
    if not MODEL_PATH:
        raise ValueError("MODEL_PATH is required: set it in .env or pass --model_path.")
    if not BACKEND_TYPE:
        raise ValueError("BACKEND_TYPE is required: set it in .env or pass --backend_type.")

    # Initialization
    init_tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(
        model_path=MODEL_PATH,
        backend_type=BACKEND_TYPE,
        max_tokens=MAX_INPUT_TOKEN_LENGTH,
        load_in_8bit=LOAD_IN_8BIT,
        # verbose=True,
    )
    init_toc = time.perf_counter()
    initialization_time = init_toc - init_tic

    total_time = 0.0
    total_tokens_per_second = 0.0
    total_memory_gen = 0.0
    prompt_example = (
        "Can you explain briefly to me what is the Python programming language?"
    )

    # Cold run: excluded from the timed measurements so lazy loading /
    # cache warm-up does not skew the averages.
    print("Performing cold run...")
    run_iteration(
        llama2_wrapper, prompt_example, DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS
    )

    # Timed runs
    print(f"Performing {args.iter} timed runs...")
    completed = 0
    model_response = None
    for _ in tqdm(range(args.iter)):
        try:
            gen_time, tokens_per_sec, mem_gen, model_response = run_iteration(
                llama2_wrapper,
                prompt_example,
                DEFAULT_SYSTEM_PROMPT,
                DEFAULT_MAX_NEW_TOKENS,
            )
        except Exception as exc:  # was a bare except that hid the failure reason
            print(f"Iteration {completed + 1} failed: {exc}")
            break
        total_time += gen_time
        total_tokens_per_second += tokens_per_sec
        total_memory_gen += mem_gen
        completed += 1

    # Average over the number of SUCCESSFUL runs.  The old code divided by
    # i + 1 (which counted a failed iteration) and raised NameError when
    # --iter was 0 or the very first timed run failed.
    if completed == 0:
        print("No timed iterations completed; nothing to report.")
        return

    avg_time = total_time / completed
    avg_tokens_per_second = total_tokens_per_second / completed
    avg_memory_gen = total_memory_gen / completed

    print(f"Last model response: {model_response}")
    print(f"Initialization time: {initialization_time:0.4f} seconds.")
    print(
        f"Average generation time over {completed} iterations: {avg_time:0.4f} seconds."
    )
    print(
        f"Average speed over {completed} iterations: {avg_tokens_per_second:0.4f} tokens/sec."
    )
    print(f"Average memory usage during generation: {avg_memory_gen:.2f} MiB")


if __name__ == "__main__":
    main()