# create_dataset() — from aiops/ContraLSP/switchstate/switchgenerator.py

def create_dataset(count, signal_len):
    """Generate a simulated switch-state dataset and persist train/test splits.

    Draws `count` signals of length `signal_len` via `create_signal`, splits
    them 80/20 into train/test, and pickles inputs, labels, importance
    scores, label logits, and states under ``simulated_data_l2x/``.

    Parameters
    ----------
    count : int
        Number of signals to generate.
    signal_len : int
        Length (number of time steps) of each signal.

    Returns
    -------
    tuple
        ``(dataset, labels, states, label_logits)`` as numpy arrays, where
        ``dataset`` keeps the pre-transpose (N, features, time) layout.
    """
    dataset = []
    labels = []
    importance_score = []
    states = []
    label_logits = []
    mean, cov = init_distribution_params()
    # NOTE(review): low == high here, so every lengthscale is exactly 0.2.
    # Looks like a degenerate uniform draw — confirm this is intentional.
    gp_lengthscale = np.random.uniform(0.2, 0.2, SIG_NUM)
    for num in range(count):
        sig, y, state, importance, y_logits = create_signal(
            signal_len, gp_params=gp_lengthscale, mean=mean, cov=cov)
        dataset.append(sig)
        labels.append(y)
        # Appending `importance` directly is equivalent to the original's
        # append of `importance.T` followed by `.transpose(0, 2, 1)` on the
        # stacked array — the two transposes cancel.
        importance_score.append(importance)
        states.append(state)
        label_logits.append(y_logits)

        if num % 50 == 0:
            print(num, count)
    dataset = np.array(dataset)
    labels = np.array(labels)
    importance_score = np.array(importance_score)
    states = np.array(states)
    label_logits = np.array(label_logits)

    # 80/20 train/test split along the sample axis.
    n_train = int(len(dataset) * 0.8)
    # Model code expects (N, time, features); generation yields (N, features, time).
    train_data_n = dataset[:n_train].transpose(0, 2, 1)
    test_data_n = dataset[n_train:].transpose(0, 2, 1)

    out_dir = 'simulated_data_l2x'
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race of the original.
    os.makedirs(out_dir, exist_ok=True)

    # Data-driven dump replaces eight copy-pasted with/open stanzas; the
    # file names produced are byte-identical to the original's.
    splits = {
        'x': (train_data_n, test_data_n),
        'y': (labels[:n_train], labels[n_train:]),
        'importance': (importance_score[:n_train], importance_score[n_train:]),
        'logits': (label_logits[:n_train], label_logits[n_train:]),
        'states': (states[:n_train], states[n_train:]),
    }
    for key, (train_part, test_part) in splits.items():
        for split_name, part in (('train', train_part), ('test', test_part)):
            path = '%s/state_dataset_%s_%s.pkl' % (out_dir, key, split_name)
            with open(path, 'wb') as f:
                pickle.dump(part, f)

    print(train_data_n.shape)
    print(labels.shape)
    print(importance_score.shape)
    print(label_logits.shape)
    print(states.shape)

    return dataset, labels, states, label_logits