from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Optional, List, Mapping, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.mistral.modeling_mistral import MistralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from pydantic import Field


class CustomLLMMistral(LLM):
    # Declared as pydantic fields so LangChain's base class accepts them.
    model: MistralForCausalLM = Field(...)
    tokenizer: LlamaTokenizerFast = Field(...)

    def __init__(self):
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"
        # Load the 7B model in 4-bit via bitsandbytes so it fits on a
        # consumer GPU; compute runs in bfloat16.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Pydantic sets self.model and self.tokenizer here, so no further
        # assignments are needed.
        super().__init__(model=model, tokenizer=tokenizer)

    @property
    def _llm_type(self) -> str:
        # LangChain declares _llm_type as a property, so the decorator is required.
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]
        # Wrap the prompt in Mistral's [INST] ... [/INST] chat template.
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)
        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=512,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=4,
            temperature=0.7,
        )
        decoded = self.tokenizer.batch_decode(generated_ids)
        # Keep only the text after the instruction block and drop the EOS token.
        output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()

        # Truncate at the first stop sequence, if any were requested.
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()

        # Downstream parsing expects the answer to close with a ``` fence;
        # if generation stopped mid-fence, pad backticks until it is complete.
        while not output.endswith("```"):
            output += "`"

        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        # Also a property in the LangChain base class.
        return {"model": self.model}