File size: 1,185 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
from my_utils.misc import dump_config
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_only

class ConfigSnapshotCallback(Callback):
    """Lightning callback that writes the parsed run configuration to disk.

    During ``setup`` the target directory is resolved as ``<exp_dir>/config``.
    ``exp_dir`` is read from the LightningModule's ``hparams``. When fitting
    starts, the stored config is written to ``parsed.yaml`` inside that
    directory through ``dump_config``. The write happens on rank zero only.
    """

    def __init__(self, config):
        super().__init__()
        # Object passed unchanged to ``dump_config``. Its concrete type depends
        # on the caller (presumably a parsed dict/OmegaConf-like object — confirm).
        self.config = config

    def setup(self, trainer, pl_module, stage) -> None:
        # Resolve the snapshot directory once the module and its hparams exist.
        experiment_root = pl_module.hparams.exp_dir
        self.savedir = os.path.join(experiment_root, 'config')

    @rank_zero_only
    def save_config_snapshot(self):
        """Create ``self.savedir`` if missing and dump the config to ``parsed.yaml``.

        Wrapped in ``rank_zero_only`` so that in multi-process runs only the
        rank-zero process writes to the filesystem.
        """
        target_dir = self.savedir
        os.makedirs(target_dir, exist_ok=True)
        snapshot_path = os.path.join(target_dir, 'parsed.yaml')
        dump_config(snapshot_path, self.config)

    def on_fit_start(self, trainer, pl_module):
        # ``self.savedir`` is only assigned in ``setup``, so this depends on
        # ``setup`` having run first.
        self.save_config_snapshot()


class GPUCacheCleanCallback(Callback):
    """Lightning callback that calls ``torch.cuda.empty_cache()`` before every batch.

    The call fires at the start of each train, validation, test and predict
    batch. All hook arguments are accepted and ignored.
    """

    @staticmethod
    def _release_cached_memory():
        # Single place that performs the cache release for every batch hook.
        torch.cuda.empty_cache()

    def on_train_batch_start(self, *args, **kwargs):
        self._release_cached_memory()

    def on_validation_batch_start(self, *args, **kwargs):
        self._release_cached_memory()

    def on_test_batch_start(self, *args, **kwargs):
        self._release_cached_memory()

    def on_predict_batch_start(self, *args, **kwargs):
        self._release_cached_memory()