python - Cross-validation: finding row indices for a test set that aren't part of a training set


What I need is to randomly pick (with replacement) 50 rows from a numpy matrix for the purpose of training a linear separator.

Then I need to test the linear separator using the rows I did not pick.

For the first part, where a is the full data matrix, I do:

    a_train = a[np.random.randint(a.shape[0], size=50), :]

but I have no effective way to find:

    a_test = ...

where a_test contains no rows that are also in a_train. How can I do this?

The key problem is that a is an n x m matrix, not a 1-dimensional array...
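One way around the 2-D issue is to work with row indices rather than rows, since a 1-D index array selects whole rows. A minimal sketch using a boolean mask over the first axis (this mask approach is an alternative to the answer's setdiff1d method; the random matrix `a` and the name `train_rows` are hypothetical stand-ins):

```python
import numpy as np

# Hypothetical stand-in for the question's data matrix `a` (n x m).
rng = np.random.RandomState(0)
a = rng.rand(100, 5)

# Row indices drawn with replacement, as in the question.
train_rows = rng.randint(a.shape[0], size=50)
a_train = a[train_rows, :]

# Start with every row eligible for testing, then switch off the training rows.
# Repeated indices in train_rows are harmless: the same entry is set to False again.
test_mask = np.ones(a.shape[0], dtype=bool)
test_mask[train_rows] = False
a_test = a[test_mask, :]

print(a_train.shape, a_test.shape)
```

Because the mask indexes only the first axis, each selected row keeps all m columns.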

You can use np.setdiff1d to find the row indices that are not included in the training set:

    import numpy as np

    gen = np.random.RandomState(0)

    n_total = 1000
    n_train = 800

    train_idx = gen.choice(n_total, size=n_train)
    test_idx = np.setdiff1d(np.arange(n_total), train_idx)
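The indices can then be used to pull whole rows out of a 2-D matrix. A small sketch, assuming a hypothetical matrix `a` built with np.arange (so it does not consume any random numbers and each row's first column encodes its row number):

```python
import numpy as np

gen = np.random.RandomState(0)
n_total = 1000
n_train = 800

# Hypothetical data matrix: row i is [3*i, 3*i + 1, 3*i + 2].
a = np.arange(n_total * 3).reshape(n_total, 3)

train_idx = gen.choice(n_total, size=n_train)  # with replacement (the default)
test_idx = np.setdiff1d(np.arange(n_total), train_idx)

# Indexing only the first axis selects complete rows.
a_train = a[train_idx]
a_test = a[test_idx]

# No row index appears in both sets.
print(np.intersect1d(train_idx, test_idx).size)
```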

One consequence of sampling with replacement is that the number of examples eligible for inclusion in the test set will vary according to the number of repeated examples in the training set:

    print(test_idx.size)  # 439
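The typical size can be estimated: each draw misses a given row with probability 1 - 1/N, so after n independent draws a row is still unused with probability (1 - 1/N)^n, and the expected test-set size is N(1 - 1/N)^n, roughly N·exp(-n/N). For N = 1000 and n = 800 that is about 449, so 439 is one fairly ordinary outcome. A quick sketch comparing the formula against simulation (the seed and the 200 repetitions are arbitrary choices):

```python
import numpy as np

n_total = 1000
n_train = 800

# Probability a given row is never drawn in n_train independent draws, times n_total.
expected_test_size = n_total * (1 - 1 / n_total) ** n_train
approx = n_total * np.exp(-n_train / n_total)  # large-N approximation

# Empirical check over many hypothetical training draws.
gen = np.random.RandomState(1)
sizes = [
    np.setdiff1d(np.arange(n_total), gen.choice(n_total, size=n_train)).size
    for _ in range(200)
]
print(expected_test_size, approx, np.mean(sizes))
```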

If you want to ensure that the size of the test set is consistent, you could resample with replacement from the set of indices that aren't in the training set:

    n_test = 200
    test_idx2 = gen.choice(test_idx, size=n_test)
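Since choice samples with replacement by default, test_idx2 can contain the same row more than once. If you would rather have a fixed number of distinct test rows, passing replace=False is one option, provided test_idx has at least n_test entries. A sketch of both (test_idx3 is a hypothetical name for the variant):

```python
import numpy as np

gen = np.random.RandomState(0)
n_total, n_train, n_test = 1000, 800, 200

train_idx = gen.choice(n_total, size=n_train)
test_idx = np.setdiff1d(np.arange(n_total), train_idx)

# As in the answer: fixed size, drawn with replacement (duplicates possible).
test_idx2 = gen.choice(test_idx, size=n_test)

# Variant: fixed size with no repeated rows; needs test_idx.size >= n_test.
test_idx3 = gen.choice(test_idx, size=n_test, replace=False)
```

Either way, every index still comes from test_idx, so neither set overlaps the training rows.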

If you don't care about sampling with replacement, a simpler option would be to generate a random permutation of the indices, then take the first n as training examples and the rest as test examples:

    idx = gen.permutation(n_total)
    train_idx, test_idx = idx[:n_train], idx[n_train:]
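With a permutation, the two index sets partition the rows exactly: every row lands in one set, and none is repeated. A short sketch applying it to a hypothetical n_total x 3 matrix `a`:

```python
import numpy as np

gen = np.random.RandomState(0)
n_total, n_train = 1000, 800

# Hypothetical data matrix with n_total rows.
a = np.arange(n_total * 3).reshape(n_total, 3)

idx = gen.permutation(n_total)  # each row index appears exactly once
train_idx, test_idx = idx[:n_train], idx[n_train:]
a_train, a_test = a[train_idx], a[test_idx]
```

Here the test set always has exactly n_total - n_train rows, unlike the with-replacement version.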

Or you could just shuffle the rows of your array in place using np.random.shuffle.
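For a 2-D array, shuffle only reorders along the first axis, so each row's contents stay together. A sketch using the seeded generator's shuffle method on a copy (copying is a choice here, to keep the original order available; the matrix `a` is hypothetical):

```python
import numpy as np

gen = np.random.RandomState(0)
n_train = 800

# Hypothetical data matrix: row i is [3*i, 3*i + 1, 3*i + 2].
a = np.arange(1000 * 3).reshape(1000, 3)

shuffled = a.copy()    # leave `a` in its original order
gen.shuffle(shuffled)  # in place, along the first axis only
a_train, a_test = shuffled[:n_train], shuffled[n_train:]
```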


I should also point out that scikit-learn has various convenience methods for partitioning data into training and test sets for the purposes of cross-validation.
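For example, sklearn.model_selection provides train_test_split for a single split and KFold for k-fold cross-validation. Both sample without replacement, so they correspond to the permutation approach rather than the with-replacement draw in the question. A sketch, assuming scikit-learn is installed and using a hypothetical matrix `a`:

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Hypothetical data matrix with 1000 rows.
a = np.arange(1000 * 3).reshape(1000, 3)

# One random 80/20 split of the rows.
a_train, a_test = train_test_split(a, test_size=0.2, random_state=0)

# Five folds: each yields (train indices, test indices) into the rows of `a`.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
folds = list(kf.split(a))
```

Each test fold in `folds` can then be used as `a[test_index]`, with the matching `a[train_index]` for training.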

