# Spaces: Sleeping
import logging
import os

import gradio as gr
import matplotlib.pyplot as plt
import monai as mn
import numpy as np
import osail_utils
import pandas as pd
import skimage
import torch
from mediffusion import DiffusionModule
# Build the diffusion model from its YAML config, load the checkpoint, and
# switch it to half precision on the GPU in evaluation mode.
model = DiffusionModule("./diffusion_configs.yaml")
model.load_ckpt("./data/model.ckpt")
model.cuda().half()
model.eval()

# Seed NumPy and PyTorch (CPU and every CUDA device) so the baseline noise
# drawn below is reproducible from one start-up to the next.
seed = 3407
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

# A single fixed noise sample; make_predictions repeats it along the batch
# dimension so every prediction starts from the same noise.
BASELINE_NOISE = torch.randn(1, 1, 256, 256).cuda().half()
# Model helper functions
def create_ds(img_paths):
    """Build a MONAI dataset that loads and preprocesses the given images.

    Args:
        img_paths: A single image path (``str``) or an iterable of paths.

    Returns:
        A ``mn.data.Dataset`` with one item per path. After the transform
        pipeline each item is a dict that keeps only the ``"img"`` key.
    """
    # isinstance() also treats str subclasses as one path; an exact type
    # comparison would fall through and iterate such a value per character.
    if isinstance(img_paths, str):
        img_paths = [img_paths]
    data_list = [{"img": img_path} for img_path in img_paths]

    # Load -> add a channel axis -> resize to 256x256 -> rescale to [0, 1]
    # -> convert to tensor -> drop every key except "img".
    transforms = mn.transforms.Compose(
        [
            osail_utils.io.LoadImageD(keys=["img"], transpose=True, normalize=True),
            mn.transforms.EnsureChannelFirstD(
                keys=["img"], channel_dim="no_channel"
            ),
            mn.transforms.ResizeD(
                keys=["img"],
                spatial_size=(256, 256),
                mode=["bicubic"],
            ),
            mn.transforms.ScaleIntensityD(keys=["img"], minv=0, maxv=1),
            mn.transforms.ToTensorD(keys=["img"], track_meta=None),
            mn.transforms.SelectItemsD(keys=["img"]),
        ]
    )
    return mn.data.Dataset(data_list, transform=transforms)
def make_predictions(img_path, angles=None, cls_batch=None, rotate_to_standard=False, sampler="DDIM100"):
    """Run the diffusion model on an image and return histogram-matched outputs.

    Args:
        img_path: Path of the input image. When ``cls_batch`` is None this
            may also be a list of paths (it is passed straight to
            ``create_ds``).
        angles: Three rotation angles placed after the leading flag in the
            condition vector. Ignored when ``cls_batch`` is given or when
            ``rotate_to_standard`` is true.
        cls_batch: Optional pre-built condition tensor with one row per
            requested prediction. The input image is repeated once per row.
        rotate_to_standard: If true (or if ``angles`` is None), the
            condition uses flag 2 with the placeholder angles
            ``[1000, 1000, 1000]`` instead of flag 1 with real angles.
        sampler: Forwarded to ``model.predict`` as ``inference_protocol``.

    Returns:
        A list with one NumPy array per prediction, each squeezed and
        histogram-matched to its (preprocessed) input image.
    """
    # With an explicit condition batch, the same image is paired with every
    # condition row.
    if cls_batch is not None:
        ds = create_ds([img_path] * len(cls_batch))
    else:
        ds = create_ds(img_path)
    dl = mn.data.DataLoader(
        ds,
        batch_size=len(ds),
        num_workers=0 if len(ds) == 1 else 4,
        shuffle=False,
    )
    input_batch = next(iter(dl))
    imgs = input_batch["img"]
    batch_size = imgs.shape[0]
    original_imgs = imgs.detach().cpu().numpy()

    # Create the classifier condition if not provided:
    # [flag, angle_x, angle_y, angle_z, 768 zeros].
    # NOTE(review): the meaning of the trailing 768 zeros is not visible in
    # this file - confirm against the model's training code.
    if cls_batch is None:
        if rotate_to_standard or angles is None:
            flag, angles = 2, [1000, 1000, 1000]
        else:
            flag = 1
        head = torch.tensor([flag, *angles], dtype=torch.float32)
        cls_value = torch.cat((head, torch.zeros(768)))
        cls_batch = cls_value.unsqueeze(0).repeat(batch_size, 1)

    # Put the condition on the GPU in half precision whether it was built
    # here or supplied by the caller, so both paths match the other inputs
    # handed to the model below.
    cls_batch = cls_batch.cuda().half()

    # Every prediction starts from the same fixed noise sample.
    noise = BASELINE_NOISE.repeat(batch_size, 1, 1, 1)
    model_kwargs = {
        "cls": cls_batch,
        "concat": imgs.cuda().half(),
    }

    # Make predictions
    preds = model.predict(
        noise, model_kwargs=model_kwargs, classifier_cond_scale=4, inference_protocol=sampler
    )

    # Match each output's intensity histogram to its input image.
    return [
        skimage.exposure.match_histograms(
            pred.detach().cpu().numpy().squeeze(), original_img.squeeze()
        )
        for pred, original_img in zip(preds, original_imgs)
    ]
# Gradio helper functions

# Module-level state shared between the Gradio callbacks.
current_img = None  # most recent colour-mapped output of a rotate callback
live_preds = None  # angle -> prediction mapping filled by make_live_btn_fn
def rotate_btn_fn(img_path, xt, yt, zt, add_bone_cmap=False):
    """Rotate the image by the requested angles and return the prediction.

    Args:
        img_path: Path of the input image.
        xt, yt, zt: Rotation angles for the x, y and z axes; each is
            converted with ``float()``.
        add_bone_cmap: If true, colour the prediction with matplotlib's
            ``'bone'`` colormap and return it as a uint8 RGB array (also
            stored in the module-level ``current_img``).

    Returns:
        The first prediction from ``make_predictions``, either as returned
        or colour-mapped when ``add_bone_cmap`` is true.
    """
    global current_img
    angles = [float(xt), float(yt), float(zt)]
    out_img = make_predictions(img_path, angles)[0]
    if not add_bone_cmap:
        return out_img
    cmap = plt.get_cmap('bone')
    # The colormap yields RGBA; keep RGB and scale to 0-255.
    out_img = (cmap(out_img)[..., :3] * 255).astype(np.uint8)
    current_img = out_img
    return out_img
def rotate_to_standard_btn_fn(img_path, add_bone_cmap=False):
    """Return the model's "standard view" prediction for an image.

    Calls ``make_predictions`` with ``rotate_to_standard=True`` and takes the
    first result. When ``add_bone_cmap`` is true the result is coloured with
    matplotlib's ``'bone'`` colormap, converted to a uint8 RGB array, and
    also stored in the module-level ``current_img``.
    """
    global current_img
    prediction = make_predictions(img_path, rotate_to_standard=True)[0]
    if add_bone_cmap:
        rgba = plt.get_cmap('bone')(prediction)
        # Drop the alpha channel and scale to 0-255.
        rgb = (rgba[..., :3] * 255).astype(np.uint8)
        current_img = rgb
        return rgb
    return prediction
def use_current_btn_fn(input_img):
    """Return the given image unchanged.

    Wired to the "Use the current output as the new input!" button, which
    copies the output image component's value into the input image component.
    """
    return input_img
def make_live_btn_fn(img_path, axis, add_bone_cmap=False):
    """Pre-compute predictions for every whole-degree angle on one axis.

    Runs the model once for the angles -20..20 (step 1) about the chosen
    axis and stores the results in the module-level ``live_preds`` dict,
    keyed by the angle as a float.

    Args:
        img_path: Path of the input image.
        axis: ``"Axis X"``, ``"Axis Y"`` or ``"Axis Z"`` (case-insensitive).
        add_bone_cmap: Accepted for interface compatibility; not used.

    Returns:
        ``img_path`` unchanged.

    Raises:
        ValueError: If ``axis`` is not one of the three supported values.
    """
    global live_preds

    # Column offset of the selected axis among the three angle columns.
    axis_key = axis.lower() if isinstance(axis, str) else None
    axis_index = {"axis x": 0, "axis y": 1, "axis z": 2}.get(axis_key)
    if axis_index is None:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(f"Unknown rotation axis: {axis!r}")

    base_angles = [float(i) for i in range(-20, 21)]

    # One condition row per angle: [1, angle_x, angle_y, angle_z, 768 zeros],
    # with only the selected axis' angle non-zero.
    cls_batch = torch.zeros(len(base_angles), 4 + 768)
    cls_batch[:, 0] = 1
    cls_batch[:, 1 + axis_index] = torch.tensor(base_angles)

    preds = make_predictions(img_path, cls_batch=cls_batch)
    live_preds = dict(zip(base_angles, preds))
    return img_path
def rotate_live_img_fn(angle, add_bone_cmap=False):
    """Return the pre-computed prediction for the given slider angle.

    Args:
        angle: Slider value; looked up in ``live_preds`` as ``float(angle)``.
        add_bone_cmap: Accepted for interface compatibility; not used.

    Returns:
        The cached prediction for ``angle``, or a no-op ``gr.update()`` when
        no predictions have been computed yet (``live_preds`` is None).

    Raises:
        KeyError: If ``angle`` is not one of the cached angles.
    """
    # The module-level name `live_img` is bound to the gr.Image component in
    # the UI section, so it cannot serve as a "has an image" flag and must
    # not be returned as an image. Only the prediction cache is consulted;
    # angle 0 is served from the cache like every other angle.
    if live_preds is None:
        # Slider moved before "Make Live!" was pressed: leave the image as is.
        return gr.update()
    return live_preds[float(angle)]
css_style = "./style.css"
callback = gr.CSVLogger()

# List the example directory once. os.listdir() returns entries in arbitrary
# order, so sort them to keep the example galleries stable between runs.
examples_dir = "./data/examples"
example_files = sorted(os.listdir(examples_dir))
xray_examples = [os.path.join(examples_dir, f) for f in example_files if "xr" in f]
drr_examples = [os.path.join(examples_dir, f) for f in example_files if "drr" in f]

# NOTE(review): the Row/Column nesting below was reconstructed from a
# flattened copy of this file - confirm the layout against the running app.
with gr.Blocks(css=css_style) as app:
    gr.HTML("VCNet: A Deep Learning Solution for Rotating RadioGraphs in 3D Space", elem_classes="title")
    gr.HTML("Developed by the Orthopedics Surgery Artificial Intelligence Lab (OSAIL)", elem_classes="note")
    gr.HTML("Note: This is a proof-of-concept demo of an AI tool that is not yet finalized. Please interpret with care!", elem_classes="note")

    # Tab 1: one prediction per button press.
    with gr.TabItem("Single Rotation"):
        with gr.Row():
            input_img = gr.Image(type='filepath', label='Input image', sources='upload', interactive=False, elem_classes='imgs')
            output_img = gr.Image(type='pil', label='Output image', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(
                examples=xray_examples,
                inputs=[input_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=drr_examples,
                inputs=[input_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, choose your rotation angles, and press Rotate!', elem_classes='text')
        with gr.Row():
            with gr.Column(scale=1):
                xt = gr.Slider(label='Rotation angle in x axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                yt = gr.Slider(label='Rotation angle in y axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
            with gr.Column(scale=1):
                zt = gr.Slider(label='Rotation angle in z axis:', elem_classes='angle', value=0, minimum=-20, maximum=20, step=1)
        with gr.Row():
            rotate_btn = gr.Button("Rotate!", elem_classes='rotate_button')
        with gr.Row():
            rotate_to_standard_btn = gr.Button("Rotate to standard view!", elem_classes='rotate_to_standard_button')
        with gr.Row():
            use_current_btn = gr.Button("Use the current output as the new input!", elem_classes='use_current_button')

        rotate_btn.click(fn=rotate_btn_fn, inputs=[input_img, xt, yt, zt], outputs=output_img)
        rotate_to_standard_btn.click(fn=rotate_to_standard_btn_fn, inputs=[input_img], outputs=output_img)
        use_current_btn.click(fn=use_current_btn_fn, inputs=[output_img], outputs=input_img)

    # Tab 2: predictions for a whole angle range are computed up front, then
    # the slider picks one of them.
    with gr.TabItem("Live Rotation"):
        with gr.Row():
            live_img = gr.Image(type='filepath', label='Live Image', sources='upload', interactive=False, elem_classes='imgs')
        with gr.Row():
            gr.Examples(
                examples=xray_examples,
                inputs=[live_img],
                label="Xray Examples",
                elem_id='examples',
            )
            gr.Examples(
                examples=drr_examples,
                inputs=[live_img],
                label="DRR Examples",
                elem_id='examples',
            )
        with gr.Row():
            gr.Markdown('Please select an example image, an axis, and then press Make Live!', elem_classes='text')
        with gr.Row():
            axis = gr.Dropdown(choices=['Axis X', 'Axis Y', 'Axis Z'], show_label=False, elem_classes='angle', value='Axis X')
            live_btn = gr.Button("Make Live!", elem_classes='make_live_button')
        with gr.Row():
            gr.Markdown('You can now rotate the radiograph in your selected axis using the scaler.', elem_classes='text')
        with gr.Row():
            slider = gr.Slider(show_label=False, minimum=-20, maximum=20, step=1, value=0, elem_classes='slider', interactive=True)

        live_btn.click(fn=make_live_btn_fn, inputs=[live_img, axis], outputs=live_img)
        slider.change(fn=rotate_live_img_fn, inputs=[slider], outputs=live_img)
# Best-effort shutdown of anything still running before launching
# (presumably left over from a previous run in the same interpreter).
# A failure here must not prevent the launch, so it is only logged.
try:
    app.close()
    gr.close_all()
except Exception:
    # Narrower than a bare `except:` so KeyboardInterrupt / SystemExit still
    # propagate.
    logging.getLogger(__name__).debug(
        "Ignoring error while closing previous Gradio apps", exc_info=True
    )

demo = app.launch(
    max_threads=4,
    share=True,
    inline=False,
    show_api=False,
    show_error=True,
    server_port=1902,
    server_name="0.0.0.0",
)