Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from predict import PaddleOCR | |
| from pdf2image import convert_from_bytes | |
| import cv2 | |
| import PIL | |
| import numpy as np | |
| import os | |
| import tempfile | |
| import random | |
| import string | |
| from ultralyticsplus import YOLO | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| import io | |
| import re | |
| from dateutil.parser import parse | |
| import datetime | |
| from file_utils import ( | |
| get_img, | |
| save_excel_file, | |
| concat_csv, | |
| convert_pdf_to_image, | |
| filter_color, | |
| plot, | |
| delete_file, | |
| ) | |
| from process import ( | |
| filter_columns, | |
| extract_text_of_col, | |
| prepare_cols, | |
| process_cols, | |
| finalize_data, | |
| ) | |
# YOLO detectors loaded once when the script starts (presumably one for whole
# tables and one for the columns inside a table, judging by the weight files).
# NOTE(review): neither name is referenced in the visible script — page
# processing calls PaddleOCR.table_model / PaddleOCR.column_model instead.
table_model, column_model = YOLO("table.pt"), YOLO("columns.pt")
def remove_dots(string):
    """Normalise the dots in an OCR'd amount string so it parses as a float.

    Leading and trailing dots are stripped. Then every remaining dot except
    the last one is removed, so the last dot acts as the decimal separator
    (``"1.234.56"`` -> ``"1234.56"``, ``"1.234.567.89"`` -> ``"1234567.89"``).

    Args:
        string: text to clean. The parameter name is kept for caller
            compatibility even though it shadows the ``string`` module here.

    Returns:
        The cleaned string, containing at most one dot.
    """
    value = string.strip('.')
    surplus = value.count('.') - 1
    if surplus > 0:
        # str.replace with a count removes occurrences from the left, which
        # leaves only the right-most dot in place.
        value = value.replace('.', '', surplus)
    return value
def convert_df(df):
    """Serialise ``df`` as UTF-8 encoded CSV bytes, omitting the index column.

    The bytes are what ``st.download_button`` is given elsewhere in the script.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
def PIL_to_cv(pil_img):
    """Return ``pil_img`` as a numpy array with channels reordered RGB -> BGR (OpenCV order)."""
    rgb_array = np.array(pil_img)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
def cv_to_PIL(cv_img):
    """Convert an OpenCV-style BGR array into a PIL image with RGB channel order.

    Args:
        cv_img: image array in BGR channel order.

    Returns:
        A ``PIL.Image.Image`` built from the RGB-reordered array.
    """
    # `import PIL` alone does not load the PIL.Image submodule; import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    rgb_array = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
def visualize_ocr(pil_img, ocr_result):
    """Render OCR bounding boxes and their recognised text on top of an image.

    Args:
        pil_img: image to draw on (anything ``Axes.imshow`` accepts).
        ocr_result: iterable of dicts, each with a ``'bbox'`` laid out as
            ``(x0, y0, x1, y1)`` and a ``'text'`` entry.

    Returns:
        A PIL image of the rendered 10x10 inch figure, saved at 150 dpi with a
        tight bounding box.
    """
    # `import PIL` alone does not load the PIL.Image submodule.
    from PIL import Image

    # Use a dedicated figure rather than pyplot's implicit current figure, so
    # we never draw on top of a figure someone else left open.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(pil_img, interpolation='lanczos')
        for result in ocr_result:
            bbox = result['bbox']
            width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            ax.add_patch(
                patches.Rectangle(
                    (bbox[0], bbox[1]), width, height,
                    linewidth=2, edgecolor='red', facecolor='none', linestyle='-',
                )
            )
            ax.text(
                bbox[0], bbox[1], result['text'],
                horizontalalignment='left', verticalalignment='bottom',
                color='blue', fontsize=7,
            )
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        img_buf = io.BytesIO()
        fig.savefig(img_buf, bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even when drawing fails; pyplot otherwise keeps
        # it registered (and in memory) across calls.
        plt.close(fig)
    img_buf.seek(0)
    return Image.open(img_buf)
def filter_columns(columns: np.ndarray):
    """Merge horizontally overlapping detection boxes into single boxes.

    ``columns`` is expected to be a 2-D array of detection rows sorted by their
    left edge (the callers sort on ``row[0]``). The first four entries of each
    row are ``x0, y0, x1, y1``. Any further entries (presumably detector
    confidence/class) are carried over from the left-hand box of a merge.

    Two neighbouring boxes are merged when their overlap ``left.x1 - right.x0``
    exceeds half of their mean width. The merged box is the union of both and
    is compared with its new right-hand neighbour again, so a chain of
    duplicates collapses into one box.

    Args:
        columns: detection rows sorted by x0.

    Returns:
        A new array with the same dtype and row layout. The input array is
        not modified.
    """
    merged = np.array(columns, copy=True)
    idx = 0
    while idx < len(merged) - 1:
        left, right = merged[idx], merged[idx + 1]
        threshold = ((left[2] - left[0]) + (right[2] - right[0])) / 2
        if (left[2] - right[0]) > threshold * 0.5:
            top = min(left[1], right[1])
            right_edge = max(left[2], right[2])  # a contained box must not shrink the union
            bottom = max(left[3], right[3])
            merged[idx, 1:4] = (top, right_edge, bottom)
            merged = np.delete(merged, idx + 1, axis=0)
            # Do not advance: the merged box may also overlap the next one.
        else:
            idx += 1
    return merged
# Characters OCR tends to inject into the amount columns; they are deleted in a
# single pass before the value is parsed as a number.
_AMOUNT_NOISE = str.maketrans('', '', "iE:M?t+;g^m/#'w\"%r-v,· *~V")

# Column names assigned to the concatenated ``finalize_data`` output.
_RAW_COLUMNS = ['page', 'Date', 'Transaction_Details', 'Three', 'Deposit', 'Withdrawal', 'Balance']

# Repeated header rows to drop. 口期 is presumably an OCR misread of 日期.
_HEADER_PATTERN = 'Date|日期|口期'


def _load_pages(upload):
    """Return the uploaded file as a list of page images.

    PDFs are rasterised at 500 DPI with pdf2image. Any other accepted upload
    (PNG/JPG) is opened with PIL and treated as a single RGB page.
    """
    if upload.type == "application/pdf" or upload.name.lower().endswith(".pdf"):
        return convert_from_bytes(upload.read(), 500)
    # `import PIL` alone does not load the PIL.Image submodule.
    from PIL import Image

    return [Image.open(upload).convert("RGB")]


def _detect_columns(page_img, table):
    """Crop one detected table from the page, then detect and merge its columns.

    Returns ``(sub_img, cols_data)``. ``cols_data`` holds the column detections
    sorted left to right, with overlapping duplicates merged.
    """
    x0, y0, x1, y1 = (int(v) for v in table[:4])
    sub_img = page_img[y0:y1, x0:x1]
    detections = PaddleOCR.column_model(sub_img, conf=0.75)
    cols_data = detections[0].boxes.data.cpu().numpy()
    cols_data = np.array(sorted(cols_data, key=lambda x: x[0]), dtype=np.ndarray)
    return sub_img, filter_columns(cols_data)


def _extract_table_text(sub_img, cols_data, apply_filter, page_number):
    """OCR every detected column of a table and assemble the rows into a DataFrame.

    Raises:
        ValueError: if no columns were detected in the table.
    """
    if len(cols_data) == 0:
        raise ValueError("no columns detected in table")
    columns = cols_data[:, 0:4]
    # NOTE(review): every column is cropped to the vertical span of the
    # left-most column — confirm this is intended.
    top, bottom = int(columns[0][1]), int(columns[0][3])
    col_imgs = [sub_img[top:bottom, int(c[0]):int(c[2])] for c in columns]
    rows_per_col = []
    total_threshold = 0
    for image in col_imgs:
        if apply_filter:
            image = filter_color(image)  # keep only black color in the image
        # Extract the text of the column and its row-length threshold.
        res, threshold, _ocr_res = extract_text_of_col(image)
        total_threshold += threshold
        rows_per_col.append(prepare_cols(res, threshold * 0.6))
    mean_threshold = total_threshold / len(col_imgs)
    data = process_cols(rows_per_col, mean_threshold * 0.6)
    # Merge related rows together.
    return finalize_data(data, page_number)


def _extract_page_tables(page, page_number, tabs, apply_filter):
    """Show one page, detect its tables, and return the DataFrames extracted from them."""
    with tabs[0]:
        st.header('Pages : ' + str(page_number))
        st.image(page)
    page_img = np.asarray(page)
    # NOTE(review): uses PaddleOCR.table_model rather than the module-level
    # ``table_model`` — confirm which detector is intended.
    detections = PaddleOCR.table_model(page_img, conf=0.75)
    table_datas = detections[0].boxes.data.cpu().numpy()
    table_boxes = detections[0].boxes.xyxy.cpu().numpy()
    frames = []
    with tabs[1]:
        st.header('Table Detection Page :' + str(page_number))
        str_cols = st.columns(4)
        str_cols[0].subheader('Table image')
        str_cols[1].subheader('Columns')
        str_cols[2].subheader('Structure result')
        str_cols[3].subheader('Cells result')
        if len(table_boxes):
            # The plot shows every table on the page, so draw it once per page
            # rather than once per table.
            try:
                sorted_tables = np.array(
                    sorted(table_datas, key=lambda x: x[0]), dtype=np.ndarray
                )
                str_cols[0].image(plot(page_img, filter_columns(sorted_tables)), channels="RGB")
            except Exception as err:
                print(err)
                st.warning("No Detection")
        for table in table_boxes:
            try:
                sub_img, cols_data = _detect_columns(page_img, table)
                str_cols[1].image(plot(sub_img, cols_data), channels="RGB")
            except Exception as err:
                print(err)
                st.warning("No Detection")
                # Without fresh column data, extraction would run on stale
                # (or undefined) values from a previous table.
                continue
            try:
                data = _extract_table_text(sub_img, cols_data, apply_filter, page_number)
            except Exception as err:
                print(err)
                st.warning("Text Extraction Failed")
                continue
            with tabs[2]:
                st.header('Extracted Table(s)')
                st.dataframe(data)
            frames.append(data)
    return frames


def _parse_amount(column):
    """Strip OCR noise from an amount column and convert it to floats.

    Values that still cannot be read as numbers become NaN instead of aborting
    the whole export.
    """
    cleaned = column.astype(str).str.translate(_AMOUNT_NOISE).apply(remove_dots)
    return pd.to_numeric(cleaned, errors="coerce")


def _parse_date(text, year):
    """Parse an OCR'd statement date with ``year`` appended.

    Returns ``pd.NaT`` when dateutil cannot interpret the text, so one
    unreadable cell does not abort the export.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9 ]', '', text) + str(year)
    try:
        return parse(cleaned, fuzzy=True)
    except (ValueError, OverflowError):
        return pd.NaT


def _to_export_frame(statement, year):
    """Convert the renamed raw table into the export layout.

    Drops repeated header rows and parses dates (appending ``year``). Combines
    Withdrawal (negated) and Deposit into one ``*Amount``. Builds Description
    from Transaction_Details plus the optional ``Three`` column. Rows whose
    amount is zero, including unreadable amounts, are dropped.
    """
    frame = statement[~statement['Date'].str.contains(_HEADER_PATTERN)].copy()
    frame['Date'] = frame['Date'].apply(lambda text: _parse_date(text, year))
    frame['*Date'] = pd.to_datetime(frame['Date']).dt.strftime('%d-%m-%Y')
    withdrawal = _parse_amount(frame['Withdrawal']) * -1
    deposit = _parse_amount(frame['Deposit'])
    frame['*Amount'] = withdrawal.fillna(0) + deposit.fillna(0)
    frame['Payee'] = ''
    frame['Description'] = frame['Transaction_Details']
    has_extra = frame['Three'].notnull()
    frame.loc[has_extra, 'Description'] += " " + frame['Three']
    frame['Reference'] = ''
    frame['Check Number'] = ''
    export = frame[['*Date', '*Amount', 'Payee', 'Description', 'Reference', 'Check Number']]
    return export[export['*Amount'] != 0]


st.title("Extract Data from Bank Statements")
# NOTE(review): ``model`` is not used below; kept in case PaddleOCR()
# initialisation is needed elsewhere — confirm.
model = PaddleOCR()
uploaded = st.file_uploader(
    "Upload a bank statement pdf file",
    type=["png", "jpg", "jpeg", "PNG", "JPG", "JPEG", "pdf", "PDF"],
)
number = st.number_input('Select Year', value=2023, step=1)
# Named to avoid shadowing the built-in ``filter``.
apply_color_filter = st.checkbox("filter color")

if st.button('Analyze Uploaded File'):
    if uploaded is None:
        st.write('Please upload an image')
    else:
        tabs = st.tabs(
            ['Pages', 'Table Detection', 'Table Structure Recognition', 'Extracted Table(s)']
        )
        extracted = []
        for page_number, page in enumerate(_load_pages(uploaded), start=1):
            extracted.extend(
                _extract_page_tables(page, page_number, tabs, apply_color_filter)
            )
        with tabs[3]:
            if not extracted:
                # Renaming the columns of an empty frame would raise below.
                st.warning("No table data could be extracted from the uploaded file.")
            else:
                final_csv = pd.concat(extracted, ignore_index=True)
                st.dataframe(final_csv)
                st.download_button(
                    "rough-csv",
                    convert_df(final_csv),
                    "file.csv",
                    "text/csv",
                    key='rough-csv',
                )
                if len(final_csv.columns) != len(_RAW_COLUMNS):
                    st.error(
                        f"Expected {len(_RAW_COLUMNS)} columns but extracted "
                        f"{len(final_csv.columns)}; only the rough CSV is available."
                    )
                else:
                    final_csv.columns = _RAW_COLUMNS
                    final_csv['Date'] = final_csv['Date'].astype(str)
                    st.dataframe(final_csv)
                    df = _to_export_frame(final_csv, number)
                    st.dataframe(df)
                    st.download_button(
                        "Press to Download",
                        convert_df(df),
                        "file.csv",
                        "text/csv",
                        key='download-csv',
                    )