Channel: Active questions tagged python - Stack Overflow

Cross-Validation Visualization Malfunctions

Inspired by scikit-learn's cross-validation visualization guide, I am trying to visualize the distribution of training and testing indices in each CV split:

import matplotlib.pyplot as plt
import numpy as np

cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm


def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object."""
    # Generate the training/testing visualizations for each CV split
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # Fill in indices with the training/test groups
        indices = np.array([np.nan] * len(X))
        indices[tt] = 1
        indices[tr] = 0

        # Visualize the results
        ax.scatter(
            range(len(indices)),
            [ii + 0.5] * len(indices),
            c=indices,
            marker="_",
            lw=lw,
            cmap=cmap_cv,
            vmin=-0.2,
            vmax=1.2,
        )

    # Plot the data classes and groups at the end
    ax.scatter(
        range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
    )
    ax.scatter(
        range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
    )

    # Formatting
    yticklabels = list(range(n_splits)) + ["class", "group"]
    ax.set(
        yticks=np.arange(n_splits + 2) + 0.5,
        yticklabels=yticklabels,
        xlabel="Sample index",
        ylabel="CV iteration",
        ylim=[n_splits + 2.2, -0.2],
        xlim=[0, 100],
    )
    ax.set_title("{}".format(type(cv).__name__), fontsize=15)
    plt.show()

from sklearn.datasets import make_classification
from sklearn.model_selection import TimeSeriesSplit, KFold

fig, ax = plt.subplots(figsize=(12, 5))

X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=42,
    shuffle=False,
)

plot_cv_indices(
    TimeSeriesSplit(n_splits=5, gap=10),
    X=X,
    y=y,
    group=None,
    ax=ax,
    n_splits=5,
)

The above code gives me:

[image]

I am looking for:

[image]

The idea of the expected plot is that it clearly visualizes the gap between the training and testing sets in each split. In addition, I ran a test case with a regular KFold as well, and the plot_cv_indices function does not seem to work properly even then.
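
As a sanity check that does not depend on the plot, the split boundaries can be printed directly. The following is a minimal sketch, assuming the same TimeSeriesSplit(n_splits=5, gap=10) and the same 1000 samples as above; the placeholder array X and the names boundaries and skipped are illustrative and not part of the original code.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

n_samples = 1000  # same sample count as the make_classification call above
X = np.zeros((n_samples, 1))  # placeholder features; only the length matters for splitting

cv = TimeSeriesSplit(n_splits=5, gap=10)

# Collect (last train index, first test index, last test index) for each split
boundaries = []
for fold, (train_idx, test_idx) in enumerate(cv.split(X)):
    boundaries.append((train_idx[-1], test_idx[0], test_idx[-1]))
    # Number of samples skipped between the end of train and the start of test
    skipped = test_idx[0] - train_idx[-1] - 1
    print(
        f"split {fold}: train 0..{train_idx[-1]}, "
        f"test {test_idx[0]}..{test_idx[-1]}, gap {skipped}"
    )
```

If the printed test ranges all start beyond sample 100, then the hard-coded xlim=[0, 100] in plot_cv_indices (taken over from a 100-sample example) would crop the test sets and gaps out of view, which may explain why the plot looks wrong for both TimeSeriesSplit and KFold. Setting the limit from the data, for example xlim=[0, len(X)], is one thing to try.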

