File size: 4,035 Bytes
5de2f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import sys
import copy
import importlib

# Absolute path of the directory that contains this file.
# NOTE(review): a module-level ``__dir__`` is also the PEP 562 hook that
# ``dir(module)`` calls; binding it to a string looks like it would make
# ``dir()`` on this module raise TypeError — consider renaming after checking
# that nothing else reads this name.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Append the directory two levels above this file to sys.path — presumably the
# repository root, so that the absolute ``tools.data.*`` module paths used
# further down can be imported; confirm against the project layout.
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

from torch.utils.data import DataLoader, DistributedSampler

# Registry of supported Dataset classes: class name -> dotted path of the
# module that defines it. Insertion order is preserved because the keys are
# echoed, in order, by the "Unsupported dataset" error in build_dataloader.
DATASET_MODULES = dict(
    SimpleDataSet='tools.data.simple_dataset',
    LMDBDataSet='tools.data.lmdb_dataset',
    TextLMDBDataSet='tools.data.text_lmdb_dataset',
    # MultiScaleDataSet is looked up in the same module as SimpleDataSet.
    MultiScaleDataSet='tools.data.simple_dataset',
    STRLMDBDataSet='tools.data.strlmdb_dataset',
    LMDBDataSetTest='tools.data.lmdb_dataset_test',
    RatioDataSet='tools.data.ratio_dataset',
    RatioDataSetTest='tools.data.ratio_dataset_test',
    RatioDataSetTVResize='tools.data.ratio_dataset_tvresize',
    RatioDataSetTVResizeTest='tools.data.ratio_dataset_tvresize_test',
)

# Registry of supported Sampler classes: class name -> dotted path of the
# module that defines it. Key order is kept because it is echoed by the
# "Unsupported sampler" error in build_dataloader.
SAMPLER_MODULES = dict(
    MultiScaleSampler='tools.data.multi_scale_sampler',
    RatioSampler='tools.data.ratio_sampler',
)

# Public API of this module: only the loader factory is exported.
__all__ = ['build_dataloader']


def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'):
    """Build a ``torch.utils.data.DataLoader`` for one stage of a config.

    The dataset class (and, optionally, a batch-sampler class) is resolved by
    name through ``DATASET_MODULES`` / ``SAMPLER_MODULES`` and imported lazily
    with ``importlib``.

    Args:
        config: Nested mapping. Reads ``config[mode]['dataset']['name']``,
            ``config[mode]['loader']`` (``batch_size_per_card``, ``drop_last``,
            ``shuffle``, ``num_workers``, optional ``pin_memory`` and
            ``collate_fn``), the optional ``config[mode]['sampler']`` and
            ``config['Global']['distributed']``. It is deep-copied first, so
            the caller's object is never mutated.
        mode: Stage name; normalised with ``str.capitalize`` before it is used
            as a config key (e.g. ``'train'`` -> ``'Train'``).
        logger: Passed to the dataset constructor; its ``error`` method is
            called when the resulting loader is empty.
        seed: Forwarded unchanged to the dataset constructor.
        epoch: Forwarded to the dataset constructor as ``epoch=``.
        task: Forwarded to the dataset constructor as ``task=``.

    Returns:
        The configured ``DataLoader``.

    Raises:
        ValueError: If the dataset or sampler name is not registered.
        KeyError: If a required config entry is missing.
        SystemExit: With status 1 if the loader yields zero batches.
    """
    config = copy.deepcopy(config)
    # Normalise to the capitalised form used as config keys (Train/Eval/Test).
    mode = mode.capitalize()

    # Resolve and instantiate the dataset class named in the config.
    dataset_config = config[mode]['dataset']
    module_name = dataset_config['name']
    if module_name not in DATASET_MODULES:
        raise ValueError(
            f'Unsupported dataset: {module_name}. Supported datasets: {list(DATASET_MODULES.keys())}'
        )

    dataset_module = importlib.import_module(DATASET_MODULES[module_name])
    dataset_class = getattr(dataset_module, module_name)
    dataset = dataset_class(config, mode, logger, seed, epoch=epoch, task=task)

    # DataLoader settings.
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    pin_memory = loader_config.get('pin_memory', False)

    sampler = None
    batch_sampler = None
    if 'sampler' in config[mode]:
        # A configured sampler is used as a *batch* sampler; every key other
        # than 'name' is forwarded to its constructor as a keyword argument.
        sampler_config = config[mode]['sampler']
        sampler_name = sampler_config.pop('name')

        if sampler_name not in SAMPLER_MODULES:
            raise ValueError(
                f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
            )

        sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
        sampler_class = getattr(sampler_module, sampler_name)
        batch_sampler = sampler_class(dataset, **sampler_config)
    elif config['Global'].get('distributed', False) and mode == 'Train':
        # A missing 'distributed' flag is treated as single-process training.
        sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)

    if 'collate_fn' in loader_config:
        # Imported under a distinct local name so the module is not shadowed
        # by the collate callable built from it.
        from . import collate_fn as collate_module
        collate_fn = getattr(collate_module, loader_config['collate_fn'])()
    else:
        collate_fn = None

    if batch_sampler is None:
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler,
            # Honour the configured shuffle flag when no sampler is in charge
            # of ordering. DataLoader rejects shuffle=True together with an
            # explicit sampler, and the DistributedSampler already shuffles.
            shuffle=shuffle if sampler is None else False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
            batch_size=batch_size,
            drop_last=drop_last,
        )
    else:
        # batch_sampler is mutually exclusive with batch_size, shuffle,
        # sampler and drop_last, so none of those are passed here.
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        )

    # Abort when the loader would yield no batches at all.
    if len(data_loader) == 0:
        logger.error(
            f'No Images in {mode.lower()} dataloader. Please check:\n'
            '\t1. The images num in the train label_file_list should be >= batch size.\n'
            '\t2. The annotation file and path in the configuration are correct.\n'
            '\t3. The BatchSize is not larger than the number of images.')
        # Exit with a non-zero status so the failure is visible to callers
        # and shell scripts; a bare sys.exit() would report success.
        sys.exit(1)

    return data_loader