Manoj Kumar
committed on
Commit
·
7c39f2c
1
Parent(s):
2621d33
updated code
Browse files
README.md
CHANGED
|
@@ -5,7 +5,7 @@ colorFrom: red
|
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.11.0
|
| 8 |
-
app_file:
|
| 9 |
pinned: false
|
| 10 |
python: 3.9
|
| 11 |
---
|
|
|
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.11.0
|
| 8 |
+
app_file: t5.py
|
| 9 |
pinned: false
|
| 10 |
python: 3.9
|
| 11 |
---
|
db.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 2 |
+
|
| 3 |
+
# Example schema
|
| 4 |
+
# Maps each table name to its column names and an optional relationship note
# written as "<column> -> <table>.<column>" (None when the table has none).
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}
|
| 18 |
+
|
| 19 |
+
# Step 1: Generate context dynamically from schema
|
| 20 |
+
def generate_context(schema):
    """Render a schema mapping as plain-English sentences, one per line.

    Args:
        schema: Mapping of table name to a dict holding a "columns" list of
            column names and, optionally, a "relations" value describing a
            relationship (a missing, None, or empty value means no
            relationship sentence is emitted).

    Returns:
        A newline-joined string. Each table contributes a sentence listing
        its columns, followed by a relationship sentence when "relations"
        is truthy. An empty schema yields an empty string.
    """
    context_lines = []
    for table, details in schema.items():
        # List table columns
        columns = ", ".join(details["columns"])
        context_lines.append(f"The {table} table has the following columns: {columns}.")

        # Add relationships if present; .get() lets a table omit the key
        # entirely instead of raising KeyError.
        relations = details.get("relations")
        if relations:
            context_lines.append(
                f"The {table} table has the following relationship: {relations}."
            )

    return "\n".join(context_lines)
|
| 32 |
+
|
| 33 |
+
# Generate schema context: one sentence per table / relationship.
schema_context = generate_context(schema)

# Step 2: Load the T5-base-text-to-sql model
model_name = "mrm8488/t5-base-finetuned-wikiSQL"  # A model fine-tuned for SQL generation
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Step 3: Define a natural language query
user_query = "List all orders where the product price is greater than 50."

# Prepare the input for the model: instruction, then the schema description,
# then the question.
# NOTE(review): t5.py prompts this same checkpoint with the prefix
# "translate English to SQL: " — confirm this free-form prompt is intended.
input_text = f"Convert the following question into an SQL query:\nSchema:\n{schema_context}\n\nQuestion:\n{user_query}"
inputs = tokenizer.encode(input_text, return_tensors="pt")

# Step 4: Generate SQL query (beam search over 4 beams, at most 128 tokens)
outputs = model.generate(inputs, max_length=128, num_beams=4, early_stopping=True)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Step 5: Display the result
print("User Query:", user_query)
print("Generated SQL Query:", generated_sql)
|
t5.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer
|
| 2 |
+
|
| 3 |
+
# Hub checkpoint id shared by the tokenizer and the model.
MODEL_NAME = "mrm8488/t5-base-finetuned-wikiSQL"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# AutoModelWithLMHead is deprecated in transformers; T5 is an encoder-decoder
# model, so load it through the seq2seq auto class instead.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
| 5 |
+
|
| 6 |
+
def get_sql(query):
    """Generate SQL text for an English question using the module-level model.

    Args:
        query: The question to translate; it is interpolated into the
            "translate English to SQL: ..." prompt.

    Returns:
        The decoded text of the first generated sequence, with the
        tokenizer's special tokens removed.
    """
    # NOTE(review): the tokenizer may already append the end-of-sequence
    # marker itself — confirm the explicit " </s>" is still needed.
    input_text = f"translate English to SQL: {query} </s>"
    features = tokenizer([input_text], return_tensors="pt")

    output = model.generate(
        input_ids=features["input_ids"],
        attention_mask=features["attention_mask"],
    )

    # Skip special tokens so the caller gets bare SQL text rather than the
    # raw sequence including padding / end-of-sequence markers.
    return tokenizer.decode(output[0], skip_special_tokens=True)
|
| 14 |
+
|
| 15 |
+
# Example question used to exercise get_sql when this file is run.
query = "How many models were finetuned using BERT as base model?"

res = get_sql(query)

print(res)

# Expected output: 'SELECT COUNT Model fine tuned FROM table WHERE Base model = BERT'
|