BranchSBM / networks /utils.py
sophiat44
model upload
5a87d8d
raw
history blame contribute delete
323 Bytes
import sys
sys.path.append("./BranchSBM")
import torch
class flow_model_torch_wrapper(torch.nn.Module):
    """Adapter exposing a ``model(t, x)`` callable as a torchdyn vector field.

    The ODE solver may call the vector field with additional positional or
    keyword arguments. This wrapper accepts and discards them, so the wrapped
    model only ever receives the time ``t`` and the state ``x``.

    Args:
        model: Callable invoked as ``model(t, x)``. When it is a
            ``torch.nn.Module``, it is registered as the submodule ``model``.
            Its parameters then appear in ``self.parameters()``, and its
            state-dict entries are prefixed with ``model.``.
    """

    def __init__(self, model):
        super(flow_model_torch_wrapper, self).__init__()
        # The attribute name is part of the checkpoint format ("model.*" keys).
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        ``args`` and ``kwargs`` are intentionally ignored. They exist only so
        that calls carrying extra solver-supplied arguments do not fail.

        Returns:
            Whatever ``self.model(t, x)`` returns, unchanged.
        """
        del args, kwargs  # accepted for solver compatibility, never used
        wrapped = self.model
        return wrapped(t, x)