Commit efbcc96 · Parent: b7dc123
added decision tree visualization

Files changed:
- app.py (+103, -23)
- visualise.py (+328, -0)
app.py
CHANGED
@@ -6,6 +6,45 @@ import gradio as gr
 import time
 import smtplib
 from email.message import EmailMessage
+from visualise import render_from_json
+from pathlib import Path
+import time
+import os, json, time, tempfile
+from huggingface_hub import HfApi, HfFileSystem, create_repo
+
+REPO = "CausalNLP/cais-demo-cache"   # dataset repo id
+TOKEN = os.environ["HF_WRITE_TOKEN"] # set as Space secret
+api = HfApi(token=TOKEN)
+fs = HfFileSystem(token=TOKEN)
+
+# 1) ensure repo exists
+create_repo(REPO, repo_type="dataset", private=True, exist_ok=True, token=TOKEN)
+
+def cache_run(query, payload, artifacts=None):
+    ts = time.strftime("%Y-%m-%dT%H:%M:%S")
+    row = {"timestamp": ts, "query": query, "payload": payload, "artifacts": artifacts or {}}
+
+    hub_path = f"datasets/{REPO}/logs.jsonl"
+    # 2) download existing (if any), append, and push in one commit
+    with tempfile.TemporaryDirectory() as td:
+        local = os.path.join(td, "logs.jsonl")
+        try:
+            with fs.open(hub_path, "rb") as fsrc, open(local, "wb") as fdst:
+                fdst.write(fsrc.read())
+        except FileNotFoundError:
+            open(local, "w").close()
+
+        with open(local, "a", encoding="utf-8") as f:
+            f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+        api.upload_file(
+            path_or_fileobj=local,
+            path_in_repo="logs.jsonl",
+            repo_id=REPO,
+            repo_type="dataset",
+            commit_message=f"append log {ts}"
+        )
+
 
 # Make your repo importable (expecting a folder named causal-agent at repo root)
 sys.path.append(str(Path(__file__).parent / "causal-agent"))
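For context, cache_run is invoked further down in this diff as cache_run(query, result) inside run_agent. A minimal sketch of a direct call (the query, payload and artifact values are illustrative only; it assumes the HF_WRITE_TOKEN secret is set so the module-level HfApi/HfFileSystem clients can authenticate):

    # Illustrative only: appends one row to logs.jsonl in the CausalNLP/cais-demo-cache dataset repo.
    # Requires HF_WRITE_TOKEN in the environment and write access to that repo.
    cache_run(
        query="What is the effect of treatment T on outcome Y?",
        payload={"method_used": "linear_regression", "effect_estimate": 0.12},
        artifacts={"decision_tree_png": "artifacts/decision_tree_20250101-000000.png"},
    )

Note that the helper does a read-append-upload cycle on a single logs.jsonl, so two runs logging at the same moment can race and the last upload wins; that is fine for a demo cache, but worth revisiting if the log becomes important.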
@@ -120,18 +159,33 @@ def _ok_html(text):
     return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
 
 # --- Email support ---
-
-
-    host = os.getenv("SMTP_HOST")
-    port = int(os.getenv("SMTP_PORT", "587"))
-    user = os.getenv("SMTP_USER")
-    pwd = os.getenv("SMTP_PASS")
-    from_addr = os.getenv("EMAIL_FROM")
+import base64, json, requests
+from email.message import EmailMessage
 
-
-
+def _gmail_access_token() -> str:
+    token_url = "https://oauth2.googleapis.com/token"
+    data = {
+        "client_id": os.getenv("GMAIL_CLIENT_ID"),
+        "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
+        "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
+        "grant_type": "refresh_token",
+    }
+    r = requests.post(token_url, data=data, timeout=20)
+    r.raise_for_status()
+    return r.json()["access_token"]
+
+def send_email(recipient: str, subject: str, body_text: str,
+               attachment_name: str = None, attachment_json: dict = None) -> str:
+    """
+    Sends via Gmail API. Returns '' on success, or an error string.
+    """
+    from_addr = os.getenv("EMAIL_FROM")
+    if not all([os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
+                os.getenv("GMAIL_REFRESH_TOKEN"), from_addr]):
+        return "Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN, EMAIL_FROM)."
 
     try:
+        # Build MIME message
         msg = EmailMessage()
         msg["From"] = from_addr
         msg["To"] = recipient
@@ -142,10 +196,17 @@ def send_email(recipient: str, subject: str, body_text: str, attachment_name: st
         payload = json.dumps(attachment_json, indent=2).encode("utf-8")
         msg.add_attachment(payload, maintype="application", subtype="json", filename=attachment_name)
 
-
-
-
-
+        # Base64url encode the raw RFC822 message
+        raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
+
+        # Get access token and send
+        access_token = _gmail_access_token()
+        api_url = "https://gmail.googleapis.com/gmail/v1/users/me/messages/send"
+        headers = {"Authorization": f"Bearer {access_token}", "Content-Type": "application/json"}
+        r = requests.post(api_url, headers=headers, json={"raw": raw}, timeout=20)
+
+        if r.status_code >= 400:
+            return f"Gmail API error {r.status_code}: {r.text[:300]}"
         return ""
     except Exception as e:
         return f"Email send failed: {e}"
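A quick sketch of how the reworked send_email is driven; the recipient and attachment below are placeholders, and it assumes the GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN and EMAIL_FROM secrets are configured (otherwise the function returns its "not configured" message rather than raising):

    # Illustrative only: send_email returns "" on success, or an error string.
    err = send_email(
        recipient="user@example.com",
        subject="CAIS results",
        body_text="Analysis finished; the full payload is attached as JSON.",
        attachment_name="results.json",
        attachment_json={"method_used": "propensity_score_weighting"},
    )
    if err:
        print(err)  # e.g. "Gmail API not configured (...)" or "Gmail API error 401: ..."

The refresh-token grant against https://oauth2.googleapis.com/token lets the headless Space mint a short-lived access token on every send, so no interactive OAuth consent is needed at runtime.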
@@ -154,35 +215,43 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
     start = time.time()
 
     processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
-    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
+    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."}, None, None)
 
     if not os.getenv("OPENAI_API_KEY"):
-        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
+        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {}, None, None)
         return
     if not csv_path:
-        yield (_warn_html("Please upload a CSV dataset."), "", "", {})
+        yield (_warn_html("Please upload a CSV dataset."), "", "", {}, None, None)
         return
 
     try:
         step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
-        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
+        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."}, None, None)
 
         result = run_causal_analysis(
             query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
             dataset_path=csv_path,
             dataset_description=(dataset_description or "").strip(),
         )
-
+        cache_run(query, result)
         llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
-        yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
+        yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}, None, None)
 
     except Exception as e:
-        yield (_err_html(str(e)), "", "", {})
+        yield (_err_html(str(e)), "", "", {}, None, None)
         return
 
     try:
         payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
+
         method = payload.get("method_used", "N/A")
+        # --- Decision tree render ---
+        artifacts_dir = Path("artifacts")
+        artifacts_dir.mkdir(exist_ok=True)
+        ts = time.strftime("%Y%m%d-%H%M%S")
+        out_stem = str(artifacts_dir / f"decision_tree_{ts}")
+
+        # This creates: out_stem.dot, out_stem.svg, out_stem.png
 
         method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
 
@@ -199,7 +268,7 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
         explanation_html = _warn_html(f"LLM summary failed: {e}")
 
     except Exception as e:
-        yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
+        yield (_err_html(f"Failed to parse results: {e}"), "", "", {}, "", None)
         return
 
     # Optional email send (best-effort)
@@ -225,8 +294,12 @@ def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
             explanation_html += _warn_html(email_err)
         else:
             explanation_html += _ok_html(f"Results emailed to {email.strip()}")
+    render_from_json(result, out_stem)
 
+    tree_png = f"{out_stem}.png"
+    tree_svg = f"{out_stem}.svg"
+    tree_dot = f"{out_stem}.dot"
+    yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {}, tree_png, [tree_svg, tree_dot, tree_png])
 
 with gr.Blocks() as demo:
     gr.Markdown("# Causal AI Scientist")
@@ -310,16 +383,23 @@ with gr.Blocks() as demo:
     with gr.Row():
         explanation_out = gr.HTML(label="Detailed Explanation")
 
+    with gr.Row():
+        tree_img = gr.Image(label="Decision Tree", type="filepath")
+    with gr.Row():
+        tree_files = gr.Files(label="Download decision tree artifacts (.svg / .dot / .png)")
+
     with gr.Accordion("Raw Results (Advanced)", open=False):
         raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
 
     run_btn.click(
         fn=run_agent,
         inputs=[query, csv_file, dataset_description, email],
-        outputs=[method_out, effects_out, explanation_out, raw_results],
+        outputs=[method_out, effects_out, explanation_out, raw_results, tree_img, tree_files],
        show_progress=True
     )
 
+
+
 
 
 if __name__ == "__main__":
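The six-element yield tuples in run_agent line up positionally with the six components in outputs=[method_out, effects_out, explanation_out, raw_results, tree_img, tree_files]: Gradio streams each yield from a generator handler into those components in order, which is why the intermediate yields pass None for the image and file slots. A self-contained toy version of that pattern, not part of this commit and with made-up component names:

    import gradio as gr

    def stepper(name):
        # each yield updates the components listed in .click() outputs, in order
        yield "working...", None                    # status text only; image left empty
        yield f"done, {name}", "tree.png"           # final update fills both slots

    with gr.Blocks() as demo:
        txt = gr.Textbox()
        status = gr.HTML()
        img = gr.Image(type="filepath")
        gr.Button("Run").click(stepper, inputs=[txt], outputs=[status, img])
    # demo.launch()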
visualise.py
ADDED
@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
Render a JSON-aware visualization of CAIS's rule-based method selector.
- Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green).
- The actually selected method receives a thicker border.
- The traversed decision path edges are colored.

Usage:
    render_from_json(payload_dict, out_stem="artifacts/decision_tree")

(Optional) CLI:
    python visualise.py payload.json
"""

from graphviz import Digraph
import json, sys
from typing import Dict, Any, List, Set, Tuple, Optional

from auto_causal.components.decision_tree import (
    DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
    INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
    GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
)

LABEL = {
    DIFF_IN_MEANS: "Diff-in-Means (RCT)",
    LINEAR_REGRESSION: "Linear Regression",
    DIFF_IN_DIFF: "Difference-in-Differences",
    REGRESSION_DISCONTINUITY: "Regression Discontinuity",
    INSTRUMENTAL_VARIABLE: "Instrumental Variables",
    PROPENSITY_SCORE_MATCHING: "PS Matching",
    PROPENSITY_SCORE_WEIGHTING: "PS Weighting",
    GENERALIZED_PROPENSITY_SCORE: "Generalized PS (continuous T)",
    BACKDOOR_ADJUSTMENT: "Backdoor Adjustment",
    FRONTDOOR_ADJUSTMENT: "Frontdoor Adjustment",
}

# -------- Heuristic extractors from payload -------- #

def _get(d: Dict, path: List[str], default=None):
    cur = d
    for k in path:
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur

def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    vars_ = _get(p, ["results", "variables"], {}) or _get(p, ["variables"], {}) or {}
    da = _get(p, ["results", "dataset_analysis"], {}) or _get(p, ["dataset_analysis"], {}) or {}

    treatment = vars_.get("treatment_variable")
    t_type = vars_.get("treatment_variable_type")  # "binary"/"continuous"
    is_rct = bool(vars_.get("is_rct", False))

    # Temporal / panel
    temporal_detected = bool(da.get("temporal_structure_detected", False))
    time_var = vars_.get("time_variable")
    group_var = vars_.get("group_variable")
    has_temporal = temporal_detected or bool(time_var) or bool(group_var)

    # RDD
    running_variable = vars_.get("running_variable")
    cutoff_value = vars_.get("cutoff_value")
    rdd_ready = running_variable is not None and cutoff_value is not None
    # (Some detectors raise 'discontinuities_detected', but we still require running var + cutoff.)
    # If you want permissive behavior, flip rdd_ready to also consider da.get("discontinuities_detected").

    # Instruments
    instrument = vars_.get("instrument_variable")
    pot_instr = da.get("potential_instruments") or []
    # Consider an instrument valid only if it exists and is NOT the treatment itself
    has_valid_instrument = (
        instrument is not None and instrument != treatment
    ) or any(pi and pi != treatment for pi in pot_instr)

    covariates = vars_.get("covariates") or []
    has_covariates = len(covariates) > 0

    # Frontdoor: only mark if explicitly provided (else too speculative)
    frontdoor_ok = bool(_get(p, ["results", "dataset_analysis", "frontdoor_satisfied"], False))

    # Overlap: if explicitly known, use it; else unknown → both PS variants remain plausible.
    overlap_assessment = da.get("overlap_assessment")
    strong_overlap = None
    if isinstance(overlap_assessment, dict):
        # accept typical keys like {"strong_overlap": true}
        strong_overlap = overlap_assessment.get("strong_overlap")

    return dict(
        treatment=treatment,
        t_type=t_type,
        is_rct=is_rct,
        has_temporal=has_temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=has_valid_instrument,
        has_covariates=has_covariates,
        frontdoor_ok=frontdoor_ok,
        strong_overlap=strong_overlap,
    )

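To make the expected payload shape concrete, here is a small made-up example of the fields extract_signals reads; the key paths mirror the ones used above (results.variables and results.dataset_analysis), and every value is illustrative:

    example_payload = {
        "results": {
            "variables": {
                "treatment_variable": "training_program",
                "treatment_variable_type": "binary",
                "is_rct": False,
                "covariates": ["age", "income"],
                "running_variable": None,
                "cutoff_value": None,
                "instrument_variable": None,
            },
            "dataset_analysis": {
                "temporal_structure_detected": False,
                "potential_instruments": [],
                "overlap_assessment": {"strong_overlap": True},
            },
        }
    }

    signals = extract_signals(example_payload)
    # -> is_rct=False, has_temporal=False, rdd_ready=False, has_valid_instrument=False,
    #    has_covariates=True, frontdoor_ok=False, strong_overlap=True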
# -------- Candidate inference (green leaves) -------- #

def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    cands: Set[str] = set()
    is_rct = signals["is_rct"]

    # RCT branch: both Diff-in-Means and LR are valid analyses; IV only if a valid instrument exists (e.g., randomized encouragement)
    if is_rct:
        cands.add(DIFF_IN_MEANS)
        if signals["has_covariates"]:
            cands.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            cands.add(INSTRUMENTAL_VARIABLE)
        return cands  # stop here; the observational tree is not needed

    # Observational branch
    if signals["has_temporal"]:
        cands.add(DIFF_IN_DIFF)
    if signals["rdd_ready"]:
        cands.add(REGRESSION_DISCONTINUITY)
    if signals["has_valid_instrument"]:
        cands.add(INSTRUMENTAL_VARIABLE)
    if signals["frontdoor_ok"]:
        cands.add(FRONTDOOR_ADJUSTMENT)

    # Treatment type
    if str(signals["t_type"]).lower() == "continuous":
        cands.add(GENERALIZED_PROPENSITY_SCORE)

    # Backdoor / PS (need covariates)
    if signals["has_covariates"]:
        # If overlap is known, choose one; if unknown, mark both as plausible.
        if signals["strong_overlap"] is True:
            cands.add(PROPENSITY_SCORE_MATCHING)
        elif signals["strong_overlap"] is False:
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        else:
            cands.add(PROPENSITY_SCORE_MATCHING)
            cands.add(PROPENSITY_SCORE_WEIGHTING)
        cands.add(BACKDOOR_ADJUSTMENT)

    return cands

# -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- #

def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    path: List[Tuple[str, str]] = []
    # Start → is_rct
    path.append(("start", "is_rct"))

    if signals["is_rct"]:
        path.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            path.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
        else:
            path.append(("has_instr_rct", "has_cov_rct"))
            if signals["has_covariates"]:
                path.append(("has_cov_rct", LINEAR_REGRESSION))
            else:
                path.append(("has_cov_rct", DIFF_IN_MEANS))
        return path

    # Observational
    path.append(("is_rct", "has_temporal"))
    if signals["has_temporal"]:
        path.append(("has_temporal", DIFF_IN_DIFF))
        return path
    else:
        path.append(("has_temporal", "has_rv"))

    if signals["rdd_ready"]:
        path.append(("has_rv", REGRESSION_DISCONTINUITY))
        return path
    else:
        path.append(("has_rv", "has_instr"))

    if signals["has_valid_instrument"]:
        path.append(("has_instr", INSTRUMENTAL_VARIABLE))
        return path
    else:
        path.append(("has_instr", "frontdoor"))

    if signals["frontdoor_ok"]:
        path.append(("frontdoor", FRONTDOOR_ADJUSTMENT))
        return path
    else:
        path.append(("frontdoor", "t_cont"))

    if str(signals["t_type"]).lower() == "continuous":
        path.append(("t_cont", GENERALIZED_PROPENSITY_SCORE))
        return path
    else:
        path.append(("t_cont", "has_cov"))

    if signals["has_covariates"]:
        path.append(("has_cov", "overlap"))
        # If overlap known, pick the branch; else default to weighting.
        if signals["strong_overlap"] is True:
            path.append(("overlap", PROPENSITY_SCORE_MATCHING))
        else:
            path.append(("overlap", PROPENSITY_SCORE_WEIGHTING))
    else:
        path.append(("has_cov", BACKDOOR_ADJUSTMENT))  # keep original topology
    return path

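Continuing the illustrative payload sketched after extract_signals, the two helpers would produce:

    candidates = infer_candidate_methods(signals)
    # -> {PROPENSITY_SCORE_MATCHING, BACKDOOR_ADJUSTMENT}
    #    (observational, no temporal/RDD/IV/frontdoor signal, covariates present, strong overlap)

    path = infer_decision_path(signals, PROPENSITY_SCORE_MATCHING)
    # -> [("start", "is_rct"), ("is_rct", "has_temporal"), ("has_temporal", "has_rv"),
    #     ("has_rv", "has_instr"), ("has_instr", "frontdoor"), ("frontdoor", "t_cont"),
    #     ("t_cont", "has_cov"), ("has_cov", "overlap"), ("overlap", PROPENSITY_SCORE_MATCHING)]

Note that infer_decision_path only traces the single realized branch used for edge coloring; the green leaves come from infer_candidate_methods, which is deliberately more permissive.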
# -------- Graph building -------- #

def build_graph(payload: Dict[str, Any]) -> Digraph:
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decisions
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    # Leaves
    def leaf(name_const, fill=None, bold=False):
        attrs = {"shape": "box", "style": "rounded"}
        if fill:
            attrs.update(style="rounded,filled", fillcolor=fill)
        if bold:
            attrs.update(penwidth="2")
        g.node(name_const, LABEL[name_const], **attrs)

    # Compute signals, candidates, path
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)

    selected_method_str = _get(payload, ["results", "results", "method_used"]) \
        or _get(payload, ["results", "method_used"]) \
        or _get(payload, ["method"])
    selected_method = {
        "linear_regression": LINEAR_REGRESSION,
        "diff_in_means": DIFF_IN_MEANS,
        "difference_in_differences": DIFF_IN_DIFF,
        "regression_discontinuity": REGRESSION_DISCONTINUITY,
        "instrumental_variable": INSTRUMENTAL_VARIABLE,
        "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
        "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
        "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
        "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
        "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
    }.get(str(selected_method_str or "").lower())

    # Add leaves with coloring
    for m in [
        DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY,
        INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING,
        GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT
    ]:
        leaf(m,
             fill=("palegreen" if m in candidates else None),
             bold=(m == selected_method))

    # Edges with optional path highlighting
    path_edges = set(infer_decision_path(signals, selected_method))
    def e(u, v, label=None):
        attrs = {}
        if (u, v) in path_edges:
            attrs.update(color="forestgreen", penwidth="2")
        g.edge(u, v, **({} if label is None else {"label": label}) | attrs)

    # Topology (unchanged)
    e("start", "is_rct")

    # RCT branch
    e("is_rct", "has_instr_rct", label="Yes")
    e("has_instr_rct", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr_rct", "has_cov_rct", label="No")
    e("has_cov_rct", LINEAR_REGRESSION, label="Yes")
    e("has_cov_rct", DIFF_IN_MEANS, label="No")

    # Observational branch
    e("is_rct", "has_temporal", label="No")
    e("has_temporal", DIFF_IN_DIFF, label="Yes")
    e("has_temporal", "has_rv", label="No")

    e("has_rv", REGRESSION_DISCONTINUITY, label="Yes")
    e("has_rv", "has_instr", label="No")

    e("has_instr", INSTRUMENTAL_VARIABLE, label="Yes")
    e("has_instr", "frontdoor", label="No")

    e("frontdoor", FRONTDOOR_ADJUSTMENT, label="Yes")
    e("frontdoor", "t_cont", label="No")

    e("t_cont", GENERALIZED_PROPENSITY_SCORE, label="Yes")
    e("t_cont", "has_cov", label="No")

    e("has_cov", "overlap", label="Yes")
    e("has_cov", BACKDOOR_ADJUSTMENT, label="No")

    e("overlap", PROPENSITY_SCORE_MATCHING, label="Yes")
    e("overlap", PROPENSITY_SCORE_WEIGHTING, label="No")

    # Optional legend
    g.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    g.edge("legend", "start", style="dashed", arrowhead="none")

    return g

def render_from_json(payload: Dict[str, Any], out_stem: str = "artifacts/decision_tree"):
    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    g.render(filename=out_stem, cleanup=True)  # SVG
    g.format = "png"
    g.render(filename=out_stem, cleanup=True)  # PNG

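render_from_json needs the Graphviz dot binary on the PATH (the graphviz Python package only emits the source text); on a Space that typically means an apt-level graphviz package alongside the pip dependency, which is an assumption about the deployment rather than something this commit adds. A minimal local sketch:

    import json
    from visualise import render_from_json

    # Illustrative: re-render the tree for a previously saved run payload.
    with open("sample_output.json", "r") as f:
        payload = json.load(f)

    render_from_json(payload, out_stem="artifacts/decision_tree")
    # writes artifacts/decision_tree.dot, artifacts/decision_tree.svg and artifacts/decision_tree.png

Using a relative out_stem assumes the artifacts/ directory already exists, which app.py guarantees by creating it before each run.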
def main():
    # Load the payload from a CLI argument if given, otherwise fall back to sample_output.json
    payload_path = sys.argv[1] if len(sys.argv) >= 2 else "sample_output.json"
    with open(payload_path, "r") as f:
        payload = json.load(f)
    render_from_json(payload)

if __name__ == "__main__":
    main()