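"""GemmaPaperQA: a Streamlit app for asking questions about an uploaded PDF.

The app extracts text from the PDF, splits it into chunks, indexes them in a
FAISS vector store, and answers questions with the google/gemma-2-2b model
through a LangChain question-answering chain.
"""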
import io
import os

import torch
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Global variables are no longer needed; state is kept in st.session_state.


# Load a PDF file and extract its text
def load_pdf(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() may return None for image-only pages, so fall back to ""
    text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return text

# Split the text into overlapping chunks
def split_text(text):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return text_splitter.split_text(text)

# Build a FAISS vector store from the text chunks
def create_knowledge_base(chunks):
    model_name = "sentence-transformers/all-mpnet-base-v2"  # embedding model
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_texts(chunks, embeddings)

# Load the Hugging Face model and wrap it in a text-generation pipeline
def load_model():
    model_name = "google/gemma-2-2b"  # Hugging Face model ID
    access_token = os.getenv("HF_TOKEN")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=access_token, clean_up_tokenization_spaces=False
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)

        # Use the GPU if one is available, otherwise fall back to the CPU
        device = 0 if torch.cuda.is_available() else -1

        # do_sample must be True for the temperature setting to take effect
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=150,
            temperature=0.1,
            do_sample=True,
            device=device,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Retrieve relevant chunks and generate an answer from the model
def get_response_from_model(prompt):
    try:
        if "knowledge_base" not in st.session_state:
            return "No PDF has been uploaded yet."
        if "qa_chain" not in st.session_state:
            return "QA chain is not initialized."

        docs = st.session_state.knowledge_base.similarity_search(prompt)
        print("docs:", docs)      # debug: retrieved documents
        print("prompt:", prompt)  # debug: user question

        # Use the chain's invoke() method, passing the retrieved docs as input_documents
        result = st.session_state.qa_chain.invoke({
            "input_documents": docs,
            "question": prompt,
        })
        # invoke() returns a dict; the generated answer is under "output_text"
        response = result.get("output_text", "")

        # Keep only the text after the "Helpful Answer:" marker, if present
        if "Helpful Answer:" in response:
            response = response.split("Helpful Answer:")[1].strip()
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Streamlit page UI
def main():
    st.title("Welcome to GemmaPaperQA")

    # PDF upload section
    with st.expander("Upload Your Paper", expanded=True):
        paper = st.file_uploader("Upload Here!", type="pdf", label_visibility="hidden")
        if paper:
            st.write(f"Upload complete! File name: {paper.name}")

            # Check the file size without moving the file pointer
            file_size = paper.size
            if file_size > 10 * 1024 * 1024:  # 10 MB limit
                st.error("File is too large! Please upload a file smaller than 10MB.")
                return

            # Extract the text and show a preview
            with st.spinner('Processing PDF...'):
                try:
                    paper.seek(0)
                    contents = paper.read()
                    pdf_file = io.BytesIO(contents)
                    text = load_pdf(pdf_file)

                    if len(text.strip()) == 0:
                        st.error("The PDF appears to have no extractable text. Please check the file and try again.")
                        return

                    st.text_area("Preview of extracted text", text[:1000], height=200)
                    st.write(f"Total characters extracted: {len(text)}")

                    if st.button("Create Knowledge Base"):
                        chunks = split_text(text)
                        st.session_state.knowledge_base = create_knowledge_base(chunks)
                        print("knowledge_base:", st.session_state.knowledge_base)  # debug
                        if st.session_state.knowledge_base is None:
                            st.error("Failed to create knowledge base.")
                            return

                        # load_model() catches its own exceptions and returns None on failure
                        pipe = load_model()
                        if pipe is None:
                            st.error("Error loading model.")
                            return

                        llm = HuggingFacePipeline(pipeline=pipe)
                        st.session_state.qa_chain = load_qa_chain(llm, chain_type="map_rerank")
                        st.success("Knowledge base created! You can now ask questions.")
                except Exception as e:
                    st.error(f"Failed to process the PDF: {str(e)}")

    # Question-answering section
    if "knowledge_base" in st.session_state and "qa_chain" in st.session_state:
        with st.expander("Ask Questions", expanded=True):
            prompt = st.text_input("Chat here!")
            if prompt:
                response = get_response_from_model(prompt)
                if response:
                    st.write(f"**Assistant**: {response}")

# Run the app
if __name__ == "__main__":
    main()
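
# To launch locally (assuming this script is saved as app.py):
#   streamlit run app.py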