import sys

# Make the "BranchSBM" checkout importable. The entry is a relative path, so
# Python resolves it against the current working directory at import time,
# not against the location of this file.
sys.path.append("./BranchSBM")

import torch


class flow_model_torch_wrapper(torch.nn.Module):
    """Adapt a ``model(t, x)`` callable to a torchdyn-compatible vector field.

    ``forward`` accepts arbitrary extra positional and keyword arguments and
    discards them, so the wrapped model is only ever called as ``model(t, x)``.
    """

    def __init__(self, model):
        super(flow_model_torch_wrapper, self).__init__()
        # Plain attribute assignment on an nn.Module: if ``model`` is itself an
        # nn.Module it is registered as a submodule (so its parameters, device
        # moves and state_dict entries follow the wrapper).
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        Any additional positional or keyword arguments are accepted only for
        signature compatibility and are ignored.
        """
        wrapped = self.model
        return wrapped(t, x)