Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from gradio import * | |
| from run import * | |
# Label of the editable-table QA tab; it also keys that tab's defaults below.
tableqa_ = "数据表问答(编辑数据)"

# CSV under `main_path` whose contents seed the demo table.
_df1_csv_path = os.path.join(main_path, "data/df1.csv")
szse_summary_df = pd.read_csv(_df1_csv_path)

# Initial widget values, keyed by tab label.
default_val_dict = {
    tableqa_: dict(
        tqa_question="EPS大于0且周涨跌大于5的平均市值是多少?",
        tqa_header=szse_summary_df.columns.tolist(),
        tqa_rows=szse_summary_df.values.tolist(),
        tqa_data_path=_df1_csv_path,
        # Value displayed in the answer widget before any query is run
        # (presumably the prediction for the default question -- TODO confirm).
        tqa_answer=dict(
            sql_query="SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            cnt_num=2,
            conclusion=[57.645],
        ),
    ),
}
def tableqa_layer(post_data):
    """Answer a natural-language question over a JSON-encoded table.

    Args:
        post_data: Mapping with three string entries:
            ``"question"`` (the question text), ``"table_rows"`` (a JSON
            array of rows, each row itself an array of cells) and
            ``"table_header"`` (a JSON array of column names).

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the
        DataFrame built from the decoded rows and header.

    Raises:
        KeyError: If one of the three entries is missing.
        TypeError: If an entry is not a ``str``, or the decoded rows or
            header is not a JSON array.
        ValueError: If the JSON is malformed, or the header length differs
            from the length of the first row.
    """
    question = post_data["question"]
    table_rows = post_data["table_rows"]
    table_header = post_data["table_header"]

    # Explicit raises instead of `assert`: assertions vanish under `python -O`,
    # which would let malformed payloads through to the model.
    for field, value in (
        ("question", question),
        ("table_rows", table_rows),
        ("table_header", table_header),
    ):
        if not isinstance(value, str):
            raise TypeError(
                f"post_data[{field!r}] must be a str, got {type(value).__name__}"
            )

    table_rows = json.loads(table_rows)
    table_header = json.loads(table_header)
    for field, value in (("table_rows", table_rows), ("table_header", table_header)):
        if not isinstance(value, list):
            raise TypeError(
                f"post_data[{field!r}] must decode to a JSON array, "
                f"got {type(value).__name__}"
            )

    # Only the first row is compared against the header, and only when both
    # are non-empty; an empty table is passed through unchanged.
    if table_rows and table_header and len(table_header) != len(table_rows[0]):
        raise ValueError(
            f"table_header has {len(table_header)} columns but the first row "
            f"has {len(table_rows[0])} cells"
        )

    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)
def run_tableqa(*input):
    """Gradio click handler: answer a question over the table shown in the UI.

    Expects exactly two positional arguments: the question string and a table
    object exposing ``columns`` and ``values`` (each with ``tolist()``).
    The table is serialised to JSON strings, handed to ``tableqa_layer`` and
    the response is returned as plain JSON-compatible Python data.
    """
    question, data = input

    header = data.columns.tolist()
    # Discard rows in which every cell is falsy (e.g. blank rows of the grid).
    rows = [
        row
        for row in data.values.tolist()
        if any(bool(cell) for cell in row)
    ]
    assert all(type(obj) is list for obj in (header, rows))

    # `tableqa_layer` takes the table as JSON text, not as Python lists.
    header = json.dumps(header)
    rows = json.dumps(rows)
    assert all(type(obj) is str for obj in (question, header, rows))

    payload = {
        "question": question,
        "table_header": header,
        "table_rows": rows,
    }
    resp = tableqa_layer(payload)

    # Values offering `tolist()` (array-like objects) are turned into plain
    # lists/scalars so the response can be JSON-encoded.
    for key in ("cnt_num", "conclusion"):
        if key in resp and hasattr(resp[key], "tolist"):
            resp[key] = resp[key].tolist()

    # A JSON round trip leaves only JSON-compatible types in the result.
    return json.loads(json.dumps(resp))
# Gradio UI for the TableQA demo: a question box, an editable table, a JSON
# answer panel, and a button that runs `run_tableqa` on the first two.
demo = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
with demo:
    gr.Markdown("")
    gr.Markdown("This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)</h4></b>\n")
    with gr.Tabs():
        #### tableqa
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    tqa_question = gr.Textbox(
                        value=default_val_dict[tableqa_]["tqa_question"],
                        label="问句:(输入)",
                    )
                    tqa_data = gr.Dataframe(
                        headers=default_val_dict[tableqa_]["tqa_header"],
                        value=default_val_dict[tableqa_]["tqa_rows"],
                        # One row more than the default data.
                        row_count=len(default_val_dict[tableqa_]["tqa_rows"]) + 1,
                    )
                    # `gr.JSON` rather than the star-imported bare `JSON`, to
                    # match the `gr.`-qualified components used everywhere else.
                    tqa_answer = gr.JSON(
                        value=default_val_dict[tableqa_]["tqa_answer"],
                        label="问句:(输出)",
                    )
                    tqa_button = gr.Button("得到答案")
                    tqa_button.click(
                        run_tableqa,
                        inputs=[tqa_question, tqa_data],
                        outputs=tqa_answer,
                    )

# "0.0.0.0" makes the server listen on all network interfaces.
demo.launch(server_name="0.0.0.0")