view scripts/util_dataset.py @ 6:a35bd818d8e9 branch-tests

notebook to test music segments
author Maria Panteli <m.x.panteli@gmail.com>
date Mon, 11 Sep 2017 14:22:17 +0100
parents e50c63cf96be
children
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 15 23:04:24 2017

@author: mariapanteli
"""

import numpy as np
from sklearn.model_selection import train_test_split


def get_train_val_test_idx(X, Y, seed=None):
    """ Split data into stratified train, validation and test sets.

    The data is first split into a train part (train_size=0.6) and a
    held-out remainder; the remainder is then halved (train_size=0.5)
    into validation and test parts, i.e. roughly a 60/20/20 split.
    Both splits are stratified on the class labels and use the same
    ``random_state``.

    Parameters
    ----------
    X : np.array
        Data or indices.
    Y : np.array
        Class labels for data in X.
    seed : int, optional
        Random seed, passed as ``random_state`` to both splits.

    Returns
    -------
    (X_train, Y_train) : tuple
        Data X and labels y for the train set
    (X_val, Y_val) : tuple
        Data X and labels y for the validation set
    (X_test, Y_test) : tuple
        Data X and labels y for the test set

    """
    # Stage 1: keep 60% for training, hold out the rest.
    first_split = train_test_split(X, Y, train_size=0.6,
                                   random_state=seed, stratify=Y)
    X_train, X_rest, Y_train, Y_rest = first_split

    # Stage 2: halve the held-out part into validation and test.
    second_split = train_test_split(X_rest, Y_rest, train_size=0.5,
                                    random_state=seed, stratify=Y_rest)
    X_val, X_test, Y_val, Y_test = second_split

    train_set = (X_train, Y_train)
    val_set = (X_val, Y_val)
    test_set = (X_test, Y_test)
    return train_set, val_set, test_set


def subset_labels(Y, N_min=10, N_max=100, seed=None):
    """ Subset dataset to contain minimum N_min and maximum N_max instances
        per class. Return indices for this subset.

    Classes with at least ``N_max`` instances are randomly subsampled
    (without replacement) down to exactly ``N_max`` instances. Classes
    with at least ``N_min`` but fewer than ``N_max`` instances are kept
    whole. Classes with fewer than ``N_min`` instances are dropped.

    Parameters
    ----------
    Y : array-like
        Class labels (converted with ``np.asarray``).
    N_min : int
        Minimum instances per class.
    N_max : int
        Maximum instances per class.
    seed : int, optional
        Random seed for subsampling the large classes. If None, the
        global ``np.random`` state is used (so ``np.random.seed`` set by
        the caller still applies).

    Returns
    -------
    subset_idx : np.array
        Indices into Y for a subset with classes of size bounded by
        N_min, N_max, grouped by label in ``np.unique`` order. An empty
        integer array if no class has at least N_min instances.

    """
    # Accept lists too: with a plain list, `Y == label` would be a single
    # bool rather than an element-wise mask.
    Y = np.asarray(Y)
    # Previously `seed` was ignored. Use a dedicated generator when a seed
    # is given; otherwise keep the original global-RNG behaviour.
    rng = np.random if seed is None else np.random.RandomState(seed)

    subset_idx = []
    for label in np.unique(Y):
        label_idx = np.where(Y == label)[0]
        counts = len(label_idx)
        if counts >= N_max:
            subset_idx.append(rng.choice(label_idx, N_max, replace=False))
        elif counts >= N_min:
            subset_idx.append(label_idx)
        # else: not enough samples for this class, skip it

    if not subset_idx:
        # Keep the return type consistent (np.array, as documented).
        return np.array([], dtype=np.intp)
    return np.concatenate(subset_idx, axis=0)