Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
"""Fit an SGD logistic-regression classifier on tab-separated train/test files.

Usage::

    python <this_script> TRAIN_FILE TEST_FILE

Both files are tab-separated tables that must contain the feature columns
``x000`` .. ``x029`` and a label column ``y``.

Task (translated from the original Russian TODO): choose the parameters with
GridSearchCV using the ``f1_weighted`` scoring function, and validate them with
``cross_val_score`` using ``cv=10``.

Requires scikit-learn >= 1.2 (``sklearn.model_selection``, ``loss="log_loss"``,
``penalty=None``). The ``sklearn.grid_search`` / ``sklearn.cross_validation``
modules and the ``SGDClassifier(n_iter=...)`` argument no longer exist.
"""
import sys

import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score

# Feature columns expected in both input files: x000, x001, ..., x029.
FEATURE_COLS = [f"x{i:03d}" for i in range(30)]
LABEL_COL = "y"
# Aim for roughly this many single-sample SGD updates in total; the number of
# epochs (passes over the training set) is derived from it.
TARGET_UPDATES = 10**6
CV_FOLDS = 10
SCORING = "f1_weighted"


def load_xy(path):
    """Read a tab-separated file and return ``(X, y)`` as NumPy arrays.

    Raises ``KeyError`` if a feature column or the label column is missing.
    """
    df = pd.read_csv(path, sep="\t")
    # Use plain arrays for both train and test data. The model is then always
    # fitted and queried with the same input type, which avoids scikit-learn's
    # feature-name mismatch warnings.
    return df[FEATURE_COLS].to_numpy(), df[LABEL_COL].to_numpy()


def epochs_for(n_samples, target_updates=TARGET_UPDATES):
    """Return the whole number of epochs giving about ``target_updates`` SGD steps.

    Raises ``ValueError`` when ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError("training data is empty")
    # max_iter must be an int; np.ceil on its own returns a float.
    return max(1, int(np.ceil(target_updates / n_samples)))


def build_model(n_samples):
    """Return an unregularised, intercept-free SGD logistic regression.

    ``tol=None`` disables the loss-based stopping criterion, so exactly
    ``max_iter`` epochs run. This matches the fixed epoch count that the
    removed ``n_iter`` argument used to give.
    """
    return SGDClassifier(
        loss="log_loss",
        penalty=None,
        fit_intercept=False,
        max_iter=epochs_for(n_samples),
        tol=None,
    )


def report(label, predicted, actual):
    """Print predictions, true labels and the misclassification rate."""
    print(f"{label} predicted:", predicted)
    print(f"{label} real:", actual)
    print(f"{label} error:", np.mean(predicted != actual))


def main(argv):
    """Grid-search, evaluate on test and train data, then cross-validate.

    Returns the process exit status.
    """
    if len(argv) < 3:
        # Usage errors go to stderr with a non-zero status so callers notice.
        sys.stderr.write("You must set train and test files as arguments!\n")
        return 1
    train_path, test_path = argv[1], argv[2]

    X_train, y_train = load_xy(train_path)
    clf = build_model(len(y_train))

    # Hyper-parameter candidates go here, e.g.
    # {"alpha": [1e-3, 1e-4, 1e-5, 1e-6]}. With penalty=None, alpha only
    # influences the step size of the default learning_rate="optimal"
    # schedule. An empty grid evaluates just the base settings.
    param_grid = {}
    grid = GridSearchCV(clf, param_grid, cv=CV_FOLDS, scoring=SCORING)
    print(grid)
    grid.fit(X_train, y_train)
    best_params = grid.best_estimator_.get_params()
    print("best params:", best_params)

    X_test, y_test = load_xy(test_path)
    # grid.predict uses the best estimator, refitted on the full training set.
    report("test", grid.predict(X_test), y_test)
    report("train", grid.predict(X_train), y_train)

    # Re-check the chosen settings with the same metric the grid search used.
    # Without scoring=..., cross_val_score would report accuracy instead.
    scores = cross_val_score(
        SGDClassifier(**best_params), X_train, y_train, cv=CV_FOLDS, scoring=SCORING
    )
    # Report mean +/- 2 standard deviations of the fold scores.
    print("Best %s: %0.3f (+/- %0.2f)" % (SCORING, scores.mean(), scores.std() * 2))
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
Advertisement
Add Comment
Please, Sign In to add comment