def train_svm_low_shot()

in tools/svm/train_svm_low_shot.py [0:0]


def train_svm_low_shot(opts):
    """Train and pickle one LinearSVC per (class, cost) pair on low-shot data.

    Features and targets are loaded via ``svm_helper.load_input_data`` and
    the features are normalized with ``svm_helper.normalize_features``.
    For every class returned by ``svm_helper.get_low_shot_svm_classes`` and
    every cost parsed from ``opts.costs_list``, a model is trained and
    pickled to the path given by ``svm_helper.get_low_shot_output_file``.
    Models whose output file already exists are skipped, so re-running the
    same command resumes an interrupted run.

    Args:
        opts: parsed options. Attributes read here: ``data_file``,
            ``targets_data_file``, ``output_path``, ``costs_list`` and
            ``dataset``; ``opts`` is also forwarded unchanged to
            ``svm_helper.get_low_shot_output_file``.

    Raises:
        AssertionError: if ``opts.data_file`` does not exist.
    """
    assert os.path.exists(opts.data_file), "Data file not found. Abort!"
    if not os.path.exists(opts.output_path):
        os.makedirs(opts.output_path)

    features, targets = svm_helper.load_input_data(
        opts.data_file, opts.targets_data_file
    )
    # normalize the features: N x 9216 (example shape)
    features = svm_helper.normalize_features(features)

    # parse the cost values for training the SVM on
    costs_list = svm_helper.parse_cost_list(opts.costs_list)
    logger.info('Training SVM for costs: {}'.format(costs_list))

    # classes for which SVMs should be trained
    _num_classes, cls_list = svm_helper.get_low_shot_svm_classes(
        targets, opts.dataset
    )

    # The output-file suffix depends only on the targets file name, so derive
    # it once: the last two '_'-separated tokens of the basename's part before
    # the first '.', e.g. '.../train_targets_sample1_k4.npy' -> 'sample1_k4'.
    suffix = '_'.join(
        opts.targets_data_file.split('/')[-1].split('.')[0].split('_')[-2:]
    )

    for cls in cls_list:
        for cost in costs_list:
            out_file = svm_helper.get_low_shot_output_file(
                opts, cls, cost, suffix
            )
            if os.path.exists(out_file):
                logger.info('SVM model exists: {}'.format(out_file))
                continue

            logger.info('SVM model not found: {}'.format(out_file))
            logger.info('Training model with the cost: {}'.format(cost))
            clf = LinearSVC(
                C=cost, class_weight={1: 2, -1: 1}, intercept_scaling=1.0,
                verbose=1, penalty='l2', loss='squared_hinge', tol=0.0001,
                dual=True, max_iter=2000,
            )
            train_feats, train_cls_labels = svm_helper.get_cls_feats_labels(
                cls, features, targets, opts.dataset
            )
            num_positives = int(np.count_nonzero(train_cls_labels == 1))
            num_negatives = int(np.count_nonzero(train_cls_labels == -1))
            # Guard the ratio: with no negatives a plain division would raise
            # ZeroDivisionError here, masking the error LinearSVC raises
            # itself when the training labels contain a single class.
            ratio = (
                float(num_positives) / num_negatives
                if num_negatives else float('inf')
            )
            logger.info('cls: {} has +ve: {} -ve: {} ratio: {}'.format(
                cls, num_positives, num_negatives, ratio)
            )
            logger.info('features: {} cls_labels: {}'.format(
                train_feats.shape, train_cls_labels.shape))
            clf.fit(train_feats, train_cls_labels)
            logger.info('Saving SVM model to: {}'.format(out_file))
            # Dump to a temporary file and rename it into place, so a crash
            # mid-dump cannot leave a truncated file that the exists-check
            # above would later treat as a finished model.
            tmp_file = out_file + '.tmp'
            with open(tmp_file, 'wb') as fwrite:
                pickle.dump(clf, fwrite)
            os.rename(tmp_file, out_file)
    logger.info('All done!')