def get_loader_segment()

in aiops/ContraAD/data_factory/data_loader.py [0:0]


def get_loader_segment(index, data_path, batch_size, win_size=100, step=100, mode='train', dataset='KDD'):
    """Build a ``DataLoader`` over the segment dataset selected by ``dataset``.

    Args:
        index: Forwarded as the first argument to the ``UCR``, ``UCR_AUG`` and
            ``SMD_Ori`` loaders; not used for any other dataset.
        data_path: Forwarded to the selected segment loader.
        batch_size: Batch size handed to the ``DataLoader``.
        win_size: Forwarded to the selected segment loader.
        step: Accepted for interface compatibility only. The value is always
            replaced by 1 before the segment loader is constructed.
        mode: Forwarded to the selected segment loader. Batches are shuffled
            only when ``mode == 'train'``.
        dataset: Name of the dataset to load. Supported names are ``SMD``,
            ``MSL``, ``SMAP``, ``PSM``, ``SWAT``, ``UCR``, ``UCR_AUG``,
            ``NIPS_TS_Water``, ``NIPS_TS_Swan``, ``NIPS_TS_CCard`` and
            ``SMD_Ori``. The default ``'KDD'`` has no loader here, so callers
            must pass a supported name. A non-string value is handed to the
            ``DataLoader`` unchanged.

    Returns:
        A ``DataLoader`` with 8 workers and ``drop_last=False`` that uses the
        ``seed_worker`` and ``g`` names from the enclosing module as
        ``worker_init_fn`` and ``generator``.

    Raises:
        ValueError: If ``dataset`` is a string that names no supported dataset.
    """
    # NOTE(review): the caller-supplied ``step`` is deliberately ignored here;
    # every segment loader is built with a stride of 1. Confirm this is
    # intended before exposing ``step`` to callers again.
    step = 1

    # Lambdas defer construction so that only the selected loader is built.
    factories = {
        'SMD': lambda: SMDSegLoader(data_path, win_size, step, mode),
        'MSL': lambda: MSLSegLoader(data_path, win_size, step, mode),
        'SMAP': lambda: SMAPSegLoader(data_path, win_size, step, mode),
        'PSM': lambda: PSMSegLoader(data_path, win_size, step, mode),
        'SWAT': lambda: SWATSegLoader(data_path, win_size, step, mode),
        'UCR': lambda: UCRSegLoader(index, data_path, win_size, step, mode),
        'UCR_AUG': lambda: UCRAUGSegLoader(index, data_path, win_size, step, mode),
        'NIPS_TS_Water': lambda: NIPS_TS_WaterSegLoader(data_path, win_size, step, mode),
        'NIPS_TS_Swan': lambda: NIPS_TS_SwanSegLoader(data_path, win_size, step, mode),
        'NIPS_TS_CCard': lambda: NIPS_TS_CCardSegLoader(data_path, win_size, step, mode),
        'SMD_Ori': lambda: SMD_OriSegLoader(index, data_path, win_size, step, mode),
    }

    if isinstance(dataset, str):
        if dataset not in factories:
            # Previously an unknown name fell through and the raw string was
            # wrapped by the DataLoader; fail loudly instead.
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of: "
                f"{', '.join(factories)}"
            )
        segment_dataset = factories[dataset]()
    else:
        # Non-string values matched no branch before and were passed through
        # to the DataLoader as-is; keep that behavior.
        segment_dataset = dataset

    data_loader = DataLoader(dataset=segment_dataset,
                             batch_size=batch_size,
                             shuffle=(mode == 'train'),
                             num_workers=8,
                             drop_last=False,
                             worker_init_fn=seed_worker,
                             generator=g,
                             )
    return data_loader