Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def train_readout(X, y, Xt, yt, PRINTSTUFF, reg_fact=0):
    """
    Train and test a linear readout fitted by (optionally ridge-regularized)
    least squares. Assumes labels are integers (0., ..., N_classes-1).

    A constant column of ones is prepended to the data so the first weight
    acts as a bias. Predictions are the linear outputs rounded to the nearest
    integer and clipped to the range [0, largest training label].

    :param X: training data, 2-D array-like (one row per sample)
    :param y: training labels, array-like
    :param Xt: test data, 2-D array-like with the same number of columns as X
    :param yt: test labels, array-like
    :param PRINTSTUFF: if true, print the intermediate arrays and the errors
    :param reg_fact: regularization constant, default: 0 (plain least squares).
        Note that the penalty is applied to all weights, bias included.
    :returns: tuple of training and test error (fraction of wrong labels)
    :raises ValueError: if the training set is empty
    """
    # Accept plain lists as well as arrays (y.max() below needs an ndarray).
    X = np.asarray(X)
    Xt = np.asarray(Xt)
    y = np.asarray(y)
    yt = np.asarray(yt)
    if y.size == 0:
        raise ValueError("train_readout needs at least one training sample")

    # add constant component to states for bias
    X = np.hstack((np.ones((X.shape[0], 1)), X))
    Xt = np.hstack((np.ones((Xt.shape[0], 1)), Xt))

    # train
    if reg_fact == 0:
        # rcond=None selects the machine-precision cutoff and avoids the
        # FutureWarning raised by NumPy releases that still default to -1.
        w = np.linalg.lstsq(X, y, rcond=None)[0]
    else:
        # Solve (X^T X + reg * I) w = X^T y directly; this is more accurate
        # and cheaper than forming the explicit inverse.
        gram = np.dot(X.T, X) + reg_fact * np.eye(X.shape[1])
        w = np.linalg.solve(gram, np.dot(X.T, y))

    # compute predictions
    max_label = y.max()
    y_train = np.clip(np.dot(X, w).round(), 0., max_label).astype(int)
    y_test = np.clip(np.dot(Xt, w).round(), 0., max_label).astype(int)

    # compute errors; float() keeps the ratio fractional even where "/" on
    # integers would truncate (Python 2).
    train_err = (y_train != y).sum() / float(len(y))
    test_err = (y_test != yt).sum() / float(len(yt))

    if PRINTSTUFF:
        # (label, value, whether to also print value.shape)
        report = (
            ("X", X, True),
            ("Xt", Xt, True),
            ("reg_fact", reg_fact, False),
            ("w", w, True),
            ("y_train", y_train, True),
            ("y_test", y_test, True),
            ("train_err", train_err, False),
            ("test_err", test_err, False),
        )
        for label, value, with_shape in report:
            print(label)
            print(value)
            if with_shape:
                print(value.shape)

    return train_err, test_err
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement