# Parse the csv files into numpy arrays.
# The source files are comma-separated and fully populated, with no missing values.
with open(X_train_fpath) as f:
    next(f)
    X_train = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
with open(Y_train_fpath) as f:
    next(f)
    Y_train = np.array([line.strip('\n').split(',')[1] for line in f], dtype = float)
with open(X_test_fpath) as f:
    next(f)
    X_test = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
# Repeat training over multiple epochs.
step = 1 # cumulative number of parameter updates, used by the learning-rate schedule
for epoch in range(max_iter):
    # Shuffle the data at the beginning of each epoch.
    X_train, Y_train = _shuffle(X_train, Y_train)

    # Mini-batch training.
    for idx in range(int(np.floor(train_size / batch_size))):
        X = X_train[idx*batch_size:(idx+1)*batch_size]
        Y = Y_train[idx*batch_size:(idx+1)*batch_size]
        # Compute the gradient.
        w_grad, b_grad = _gradient(X, Y, w, b)

        # Update the parameters with a decaying learning rate. The schedule is
        # deliberately simple: the base learning rate divided by the square root
        # of the number of updates taken so far.
        w = w - learning_rate/np.sqrt(step) * w_grad
        b = b - learning_rate/np.sqrt(step) * b_grad
        step = step + 1
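The helpers _shuffle and _gradient are called above but not defined in this excerpt. Below is a minimal sketch consistent with how they are used, assuming a standard logistic-regression setup with cross-entropy loss; the _sigmoid and _f names are illustrative:

# Sketch of the helpers assumed above.
def _shuffle(X, Y):
    # Permute X and Y with the same random order so pairs stay aligned.
    randomize = np.arange(len(X))
    np.random.shuffle(randomize)
    return X[randomize], Y[randomize]

def _sigmoid(z):
    # Clip the output to avoid overflow/underflow in later computations.
    return np.clip(1 / (1.0 + np.exp(-z)), 1e-8, 1 - 1e-8)

def _f(X, w, b):
    # Forward pass: predicted probability of class 1 for each row of X.
    return _sigmoid(np.matmul(X, w) + b)

def _gradient(X, Y_label, w, b):
    # Gradient of the cross-entropy loss with respect to w and b.
    y_pred = _f(X, w, b)
    pred_error = Y_label - y_pred
    w_grad = -np.sum(pred_error * X.T, 1)
    b_grad = -np.sum(pred_error)
    return w_grad, b_grad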
predictions = _predict(X_test, w, b)
with open(output_fpath.format('logistic'), 'w') as f:
    f.write('id,label\n')
    for i, label in enumerate(predictions):
        f.write('{},{}\n'.format(i, label))
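_predict, used above to produce the submission labels, can be sketched on top of the same _f forward pass: rounding the predicted probability gives a hard 0/1 label.

def _predict(X, w, b):
    # Round the predicted probability of class 1 to a hard 0/1 label.
    return np.round(_f(X, w, b)).astype(int)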
# Print out the most significant weights,
# i.e. the ten weights with the largest absolute value.
ind = np.argsort(np.abs(w))[::-1]
with open(X_test_fpath) as f:
    content = f.readline().strip('\n').split(',')
    features = np.array(content)
for i in ind[0:10]:
    print(features[i], w[i])
Unemployed full-time 1.1225676433808978
Not in universe -1.0573592082887484
Other Rel <18 never married RP of subfamily -0.912973403518828
Child 18+ ever marr Not in a subfamily -0.8705099085602619
1 0.7950300190669547
Spouse of householder -0.750112419980567
Other Rel <18 ever marr RP of subfamily -0.7491036728017868
Italy -0.7240219980088511
capital losses 0.6611812623046104
id 0.5751429906332644
# Parse csv files to numpy array
with open(X_train_fpath) as f:
    next(f)
    X_train = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
with open(Y_train_fpath) as f:
    next(f)
    Y_train = np.array([line.strip('\n').split(',')[1] for line in f], dtype = float)
with open(X_test_fpath) as f:
    next(f)
    X_test = np.array([line.strip('\n').split(',')[1:] for line in f], dtype = float)
# Normalize training and testing data
X_train, X_mean, X_std = _normalize(X_train, train = True)
X_test, _, _ = _normalize(X_test, train = False, specified_column = None, X_mean = X_mean, X_std = X_std)
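_normalize is not shown in this excerpt. A minimal sketch matching the call signature used above: it computes column-wise statistics on the training set and reuses them on the test set, so both are scaled identically.

def _normalize(X, train = True, specified_column = None, X_mean = None, X_std = None):
    # Normalize the chosen columns to zero mean and unit variance.
    # On training data, compute the statistics; on test data, reuse the
    # training statistics passed in via X_mean and X_std.
    if specified_column is None:
        specified_column = np.arange(X.shape[1])
    if train:
        X_mean = np.mean(X[:, specified_column], 0).reshape(1, -1)
        X_std = np.std(X[:, specified_column], 0).reshape(1, -1)
    X[:, specified_column] = (X[:, specified_column] - X_mean) / (X_std + 1e-8)
    return X, X_mean, X_std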
# Split the training data by class, then compute the in-class means.
X_train_0 = np.array([x for x, y in zip(X_train, Y_train) if y == 0])
X_train_1 = np.array([x for x, y in zip(X_train, Y_train) if y == 1])

mean_0 = np.mean(X_train_0, axis = 0)
mean_1 = np.mean(X_train_1, axis = 0)
# Compute the in-class covariance matrices.
cov_0 = np.zeros((X_train.shape[1], X_train.shape[1]))
cov_1 = np.zeros((X_train.shape[1], X_train.shape[1]))
for x in X_train_0:
    cov_0 += np.dot(np.transpose([x - mean_0]), [x - mean_0]) / X_train_0.shape[0]
for x in X_train_1:
    cov_1 += np.dot(np.transpose([x - mean_1]), [x - mean_1]) / X_train_1.shape[0]
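The per-sample loops above are easy to read but slow when the feature dimension is large. The same biased (divide-by-N) covariance estimates can be computed in a single vectorized step, equivalent up to floating-point summation order:

# Vectorized equivalent of the covariance loops above.
X0_centered = X_train_0 - mean_0
X1_centered = X_train_1 - mean_1
cov_0 = np.dot(X0_centered.T, X0_centered) / X_train_0.shape[0]
cov_1 = np.dot(X1_centered.T, X1_centered) / X_train_1.shape[0]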
# Shared covariance is taken as a weighted average of the individual in-class covariances.
cov = (cov_0 * X_train_0.shape[0] + cov_1 * X_train_1.shape[0]) / (X_train_0.shape[0] + X_train_1.shape[0])
# Compute the inverse of the covariance matrix.
# Since the covariance matrix may be nearly singular, np.linalg.inv() may give a large numerical error.
# Via SVD decomposition, one can get the matrix inverse efficiently and accurately.
u, s, v = np.linalg.svd(cov, full_matrices=False)
inv = np.matmul(v.T * 1 / s, u.T)
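Why this works: for the symmetric positive semi-definite cov, the SVD gives cov = U S Vᵀ, so cov⁻¹ = V S⁻¹ Uᵀ; the expression v.T * 1 / s scales the columns of V by 1/s before the product with u.T. A quick sanity check (not part of the original code):

# Sanity check: inv should reproduce the identity up to numerical error.
identity_error = np.max(np.abs(np.matmul(inv, cov) - np.eye(cov.shape[0])))
print('max |inv @ cov - I| = {}'.format(identity_error))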
# Directly compute the weights and bias.
w = np.dot(inv, mean_0 - mean_1)
b = (-0.5) * np.dot(mean_0, np.dot(inv, mean_0)) + 0.5 * np.dot(mean_1, np.dot(inv, mean_1))\
    + np.log(float(X_train_0.shape[0]) / X_train_1.shape[0])
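These expressions are the closed-form solution of the Gaussian generative model with a shared covariance: the posterior is P(C0 | x) = sigmoid(wᵀx + b) with w = Σ⁻¹(μ0 − μ1) and b = −½ μ0ᵀΣ⁻¹μ0 + ½ μ1ᵀΣ⁻¹μ1 + ln(N0 / N1), where N0 and N1 are the class sizes. The code evaluates these directly instead of running gradient descent.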
# Compute accuracy on the training set. Because w points from class 1 toward class 0
# (w = inv · (mean_0 - mean_1)), the sigmoid estimates P(class 0 | x),
# so the hard prediction is flipped with 1 - _predict.
Y_train_pred = 1 - _predict(X_train, w, b)
print('Training accuracy: {}'.format(_accuracy(Y_train_pred, Y_train)))
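_accuracy is also undefined in this excerpt; a one-line sketch for 0/1 labels:

def _accuracy(Y_pred, Y_label):
    # Fraction of predictions that match the ground-truth labels.
    return 1 - np.mean(np.abs(Y_pred - Y_label))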
Training accuracy: 0.873820406959599
# Predict testing labels
predictions = 1 - _predict(X_test, w, b)
with open(output_fpath.format('generative'), 'w') as f:
    f.write('id,label\n')
    for i, label in enumerate(predictions):
        f.write('{},{}\n'.format(i, label))
# Print out the most significant weights
ind = np.argsort(np.abs(w))[::-1]
with open(X_test_fpath) as f:
    content = f.readline().strip('\n').split(',')
    features = np.array(content)
for i in ind[0:10]:
    print(features[i], w[i])
Agriculture 7.5625
41 -7.5
Retail trade 6.828125
Forestry and fisheries 6.03125
29 -6.0
35 5.265625
34 -5.15625
Sales -5.1171875
Construction -5.111328125
37 -4.79296875