How to use the MNIST dataset in the latest sklearn
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784')
In [6]:
X, y = mnist['data'], mnist['target']
X.shape
Out[6]:
(70000, 784)
In [10]:
y.shape
X.index
Out[10]:
RangeIndex(start=0, stop=70000, step=1)
In [30]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
some_digit = X.loc[36011].values
print(type(some_digit))
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary, interpolation="nearest")
plt.axis("off")
plt.show()
<class 'numpy.ndarray'>
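The cell above works because `X` is a pandas DataFrame: `X.loc[...]` selects one row by index label, and `.values` turns it into a flat array of 784 pixels that can be reshaped to 28 x 28. A minimal sketch of the same steps on random data, assuming a hypothetical stand-in frame `X_demo` (not the real MNIST data):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for X: 3 "images" of 784 random pixel values each.
rng = np.random.default_rng(0)
X_demo = pd.DataFrame(rng.integers(0, 256, size=(3, 784)))

# Select one row by index label and convert it to a NumPy array.
row = X_demo.loc[1].to_numpy()

# 784 = 28 * 28, so the flat row can be viewed as a square image.
image = row.reshape(28, 28)

print(type(row))
print(image.shape)
```

`to_numpy()` is used here in place of `.values`; both return a NumPy array for a row like this one.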

In [32]:
y[36011]
Out[32]:
'5'
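Note that the label is the string '5', not the integer 5: the target column holds digit strings stored as a categorical. If integer labels are preferred, one possible conversion is sketched below on a small hand-made Series (`y_demo` is a hypothetical stand-in for `y`):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the MNIST target: digit strings stored as a categorical.
y_demo = pd.Series(['5', '0', '4', '1', '9'], dtype='category')

# Pull the labels out as Python strings, then let NumPy parse them as small integers.
y_int = y_demo.to_numpy(dtype=object).astype(np.uint8)

print(y_int.tolist())  # [5, 0, 4, 1, 9]
print(y_int.dtype)     # uint8
```

The rest of this walkthrough keeps the string labels, which is why later comparisons are written as `y_train == '5'`.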
In [45]:
X_train, X_test, y_train, y_test = X[:60000].values, X[60000:].values, y[:60000].values, y[60000:].values
In [46]:
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
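The shuffle only works because the same index array is applied to both the features and the labels, so every row keeps its own label. A small sketch on toy arrays (the names `X_demo`, `y_demo`, and the seeded generator are assumptions made for illustration):

```python
import numpy as np

# Toy data: row i is [2*i, 2*i + 1] and its label is i.
X_demo = np.arange(10).reshape(5, 2)
y_demo = np.arange(5)

# A seeded generator makes the shuffle reproducible.
rng = np.random.default_rng(42)
perm = rng.permutation(len(X_demo))

# Index both arrays with the same permutation.
X_shuf, y_shuf = X_demo[perm], y_demo[perm]

# Each shuffled row still matches its label.
print(np.array_equal(X_shuf[:, 0], 2 * y_shuf))  # True
```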
In [53]:
X_train.shape
y_train
Out[53]:
['4', '4', '1', '8', '4', ..., '8', '5', '4', '7', '5']
Length: 60000
Categories (10, object): ['0', '1', '2', '3', ..., '6', '7', '8', '9']
In [54]:
y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')
y_train_5
Out[54]:
array([False, False, False, ..., False, False, True])
In [56]:
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)
Out[56]:
SGDClassifier(random_state=42)
In [61]:
sgd_clf.predict([X.loc[36011], X.loc[36010]])
Out[61]:
array([ True, False])
In [63]:
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
Out[63]:
array([0.96975, 0.96335, 0.9521 ])
In [65]:
from sklearn.base import BaseEstimator
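class Never5Classifier(BaseEstimator):
    """Baseline that predicts "not 5" for every sample."""
    def fit(self, X, y=None):
        # Nothing to learn; return self to follow the scikit-learn convention.
        return self
    def predict(self, X):
        # One False per sample.
        return np.zeros(len(X), dtype=bool)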
In [67]:
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
Out[67]:
array([0.91025, 0.90975, 0.90895])
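A classifier that never predicts "5" still scores about 91% accuracy, simply because only around 10% of the images are 5s. Accuracy is therefore a weak metric on a skewed dataset like this one. The sketch below reproduces the effect on hypothetical toy labels and shows that recall exposes the problem:

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Hypothetical toy labels: 10 positives out of 100, close to the share of 5s in MNIST.
y_true = np.array([True] * 10 + [False] * 90)

# A "never positive" predictor, mirroring Never5Classifier.
y_pred = np.zeros(len(y_true), dtype=bool)

acc = accuracy_score(y_true, y_pred)
rec = recall_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)

print(acc)  # 0.9
print(rec)  # 0.0
print(cm)   # rows: actual False / True, columns: predicted False / True
```

The confusion matrix has 90 true negatives and 10 false negatives: every positive is missed, yet accuracy stays at 0.9.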
In [70]:
y_scores = sgd_clf.decision_function([X.loc[36011], X.loc[36010]])
y_scores
Out[70]:
array([ 3211.94412614, -4886.85783489])
In [72]:
from sklearn.model_selection import cross_val_predict
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
                             method="decision_function")
y_scores
Out[72]:
array([-10306.99957365, -6898.31615637, -7299.4630608 , ...,
-5958.06340458, -18084.1859439 , 6322.98986326])
In [74]:
y_scores.shape
Out[74]:
(60000,)
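With one decision score per training sample, the threshold becomes a free choice: `predict` on a binary `SGDClassifier` corresponds to a threshold of 0, and raising it usually trades recall for precision. A minimal sketch with hypothetical scores and labels (the values and the helper `metrics_at` are invented for illustration, not taken from the run above):

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score

# Hypothetical labels and decision scores, sorted by score for readability.
y_true = np.array([False, False, True, False, True, False, True, True])
scores = np.array([-9000.0, -6000.0, -2000.0, -500.0, 1000.0, 3000.0, 5000.0, 8000.0])

def metrics_at(threshold):
    """Return (precision, recall) when predicting positive for score > threshold."""
    y_pred = scores > threshold
    return precision_score(y_true, y_pred), recall_score(y_true, y_pred)

p0, r0 = metrics_at(0.0)        # default-like threshold
p1, r1 = metrics_at(4000.0)     # stricter threshold

print(p0, r0)  # 0.75 0.75
print(p1, r1)  # 1.0 0.5
```

For the full trade-off curve on real scores, `sklearn.metrics.precision_recall_curve(y_train_5, y_scores)` returns the precision and recall for every candidate threshold.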
In [76]:
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
                                    method="predict_proba")
In [78]:
y_probas_forest[:4]
Out[78]:
array([[0.99, 0.01],
[1. , 0. ],
[1. , 0. ],
[0.96, 0.04]])
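Each row of the `predict_proba` output holds one probability per class, in the order of the estimator's `classes_` attribute; for boolean labels that is "not 5" first and "5" second. To compare the forest with the score-based SGD results, the second column can serve as the score. A sketch with hypothetical probabilities and labels (only the array layout follows the output above):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical predict_proba output: columns are P(not 5), P(5).
y_probas = np.array([
    [0.99, 0.01],
    [1.00, 0.00],
    [0.20, 0.80],
    [0.96, 0.04],
    [0.45, 0.55],
    [0.40, 0.60],
])
# Hypothetical true labels for the six rows.
y_true = np.array([False, False, True, False, True, False])

# Use the positive-class probability as the score.
y_scores_forest = y_probas[:, 1]

auc = roc_auc_score(y_true, y_scores_forest)
print(auc)  # 0.875
```

Of the 8 positive/negative pairs, 7 are ranked correctly (the negative with score 0.60 outranks the positive with 0.55), which gives 7/8 = 0.875.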




