| import importlib | |
| from .optim import get_optimizer, get_scheduler | |
def get_trainer(name: str):
    """
    Return the ``OurTrainer`` attribute of the ``src.trainer.<name>`` module.

    When no module exists for ``name``, a notice is printed and the
    ``src.trainer.default`` module is used instead.

    Args:
        name: Module name to look up under the ``src.trainer`` package.

    Returns:
        The ``OurTrainer`` attribute of the resolved module.

    Raises:
        ModuleNotFoundError: If the requested module exists but one of its
            own imports cannot be found, or if the default module is
            missing as well.
        AttributeError: If the resolved module does not define
            ``OurTrainer``.
    """
    module_name = f'src.trainer.{name}'
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as e:
        # Fall back only when the requested trainer itself (or one of its
        # parent packages) is what could not be found. A missing dependency
        # *inside* an existing trainer module must surface to the caller
        # rather than being silently replaced by the default trainer.
        missing = e.name
        is_requested = missing is not None and (
            module_name == missing or module_name.startswith(f'{missing}.')
        )
        if not is_requested:
            raise
        print(e)
        print('-> Using default trainer')
        module = importlib.import_module('src.trainer.default')
    return getattr(module, 'OurTrainer')