Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import datasets | |
| import faiss | |
| import os | |
| from transformers import pipeline | |
# Hugging Face access token used for the authenticated Hub downloads below;
# ``None`` when the CLARIN_KNEXT environment variable is not set.
auth_token = os.getenv("CLARIN_KNEXT")

# Example Polish query pre-filled in the UI. The mention to link appears to be
# delimited by the [unused0] ... [unused1] marker tokens (confirm against the
# encoder's training format).
sample_text = "".join([
    "Europejscy astronomowie odkryli planetę ",
    "pozasłoneczną pochodzącą spoza naszej galaktyki, czyli ",
    "[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali ",
    "2,2-metrowym teleskopem MPG/ESO.",
])

# Multi-line input widget for the Gradio interface, seeded with the example.
textbox = gr.Textbox(lines=10, label="Type your query here.", value=sample_text)
def load_index(
    index_data: str = "clarin-knext/entity-linking-index",
    index_path: str = "./encoder.faissindex",
):
    """Load the entity catalogue and its FAISS index.

    Args:
        index_data: Hugging Face Hub dataset id whose ``train`` split holds
            the ``entities`` and ``texts`` columns.
        index_path: Path to the serialized FAISS index; it is opened with
            ``faiss.IO_FLAG_MMAP`` (memory-mapped) rather than read fully.

    Returns:
        A ``(entities, faiss_index)`` pair. ``entities`` maps a row position
        in the dataset to its ``(entity_id, entity_text)`` tuple. Presumably
        these positions are the ids FAISS returns from ``search``, i.e. both
        artefacts were built from the same rows in the same order (TODO
        confirm).
    """
    # NOTE(review): recent ``datasets`` releases deprecate ``use_auth_token``
    # in favour of ``token``; kept here for compatibility with the version
    # this Space was presumably written against.
    ds = datasets.load_dataset(index_data, use_auth_token=auth_token)["train"]
    # A dict (not a list) so an unknown id raises instead of silently
    # wrapping around the way a negative list index would.
    entities = dict(enumerate(zip(ds["entities"], ds["texts"])))
    faiss_index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    return entities, faiss_index
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    """Build the ``transformers`` feature-extraction pipeline for the encoder.

    Args:
        model_name: Hub id of the encoder checkpoint to load.

    Returns:
        The pipeline object; calling it on text yields token embeddings.
    """
    encoder = pipeline(
        task="feature-extraction",
        model=model_name,
        use_auth_token=auth_token,
    )
    return encoder
# Load the encoder first, then the entity index, once at import time so that
# every call to ``predict`` reuses the same objects.
model, index = load_model(), load_index()
def predict(text: str = sample_text, top_k: int = 3):
    """Return the ``top_k`` catalogue entries closest to ``text``.

    The query is embedded with the module-level ``model`` pipeline and
    searched against the FAISS index held in the module-level ``index``.

    Args:
        text: Query text to link.
        top_k: Number of nearest neighbours to request from FAISS.

    Returns:
        One line per hit, formatted as ``"<(entity_id, entity_text)>: score"``
        (the tuple's ``repr`` followed by the FAISS score). Result slots that
        FAISS leaves unfilled (id ``-1``, e.g. when the index holds fewer than
        ``top_k`` vectors) are skipped instead of raising ``KeyError``.
    """
    entities, faiss_index = index
    # Takes only the [CLS] (first-token) embedding, for now. ``.cpu()`` keeps
    # ``.numpy()`` working if the pipeline was placed on a GPU.
    embeddings = model(text, return_tensors="pt")
    query = embeddings[0][0].cpu().numpy().reshape(1, -1)
    scores, indices = faiss_index.search(query, top_k)
    # Exactly one query row was searched, so only row 0 of each result matters.
    return "\n".join(
        f"{entities[idx]}: {score}"
        for idx, score in zip(indices[0].tolist(), scores[0].tolist())
        if idx != -1
    )
# Keep a handle on the Interface itself; chaining ``.launch()`` onto the
# constructor would bind ``demo`` to launch()'s return value instead.
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()