logo
Loading...

train_test_split內的超參數代表意義? - Cupoy

請問一下, 在train_test_split( )中的shuffle以及stratify代表什麼意...

train_test_split內的超參數代表意義?

2021/07/29 上午 09:32
訓練/測試集切分的概念
Yaoga
觀看數:942
回答數:2
收藏數:0

請問一下, 在train_test_split( )中的shuffle以及stratify代表什麼意思呢? shuffle可以=True or False ,對於結果有何影響? stratify又可以有怎麼樣的設定呢?對於結果也會有什麼影響嗎? 感謝了!

回答列表

  • 2021/07/29 上午 11:52
    Jeffrey
    贊同數:0
    不贊同數:0
    留言數:0

    以Python的sk-learn為例, shuffle參數: 對原始數據進行隨機抽樣,保證隨機性。 stratify參數: 想要達到分層隨機抽樣的目的。特別是在原始數據中樣本標籤分佈不均衡時非常有用,一些分類問題可能會在目標類的分佈中表現出很大的不平衡:例如,負樣本可能比正樣本多幾倍。在這種情況下,建議使用分層抽樣

  • 2021/07/31 下午 04:35
    Yilin
    贊同數:1
    不贊同數:0
    留言數:0

    在分類的問題中,我們更關心每個類別的資料分佈比例。當測試集的分佈盡可能與訓練相同情況下,模型才更有可能得到更準確的預測。假設我們有三個標籤的類別,這三個類別的分佈分別有 0.4、0.3、0.3。然而我們在切割資料的時候必須確保訓練集與測試集需要有相同的資料比例分佈。 通常我們都使用 Sklearn 的 `train_test_split` 進行資料切割。在此方法中 Sklearn 提供了一個 `stratify` 參數達到分層隨機抽樣的目的。特別是在原始數據中樣本標籤分佈不均衡時非常有用,一些分類問題可能會在目標類的分佈中表現出很大的不平衡:例如,負樣本與正樣本比例懸殊(信用卡倒刷預測、離職員工預測)。以下用紅酒分類預測來進行示範,首先我們不使用 `stratify` 隨機切割資料。 ```py from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split X, y = load_wine(return_X_y=True) # Look at the class weights before splitting pd.Series(y).value_counts(normalize=True) ``` ``` # 全部資料三種類別比例 1 0.398876 0 0.331461 2 0.269663 dtype: float64 ``` ```py # Generate unstratified split X_train, X_test, y_train, y_test = train_test_split(X, y) # Look at the class weights of train set pd.Series(y_train).value_counts(normalize=True) # Look at the class weights of the test set pd.Series(y_test).value_counts(normalize=True) ``` ``` # 訓練集三種類別比例 1 0.390977 0 0.330827 2 0.278195 dtype: float64 # 測試集三種類別比例 1 0.511111 0 0.266667 2 0.222222 dtype: float64 ``` 從上面切出來的訓練集與測試集可以發現三個類別的資料分佈比例都不同。因此我們可以使用 `stratify` 參數再切割一次。 ```py # Generate stratified split X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y) # Look at the class weights of train set pd.Series(y_train).value_counts(normalize=True) # Look at the class weights of the test set pd.Series(y_test).value_counts(normalize=True) ``` ``` # 訓練集三種類別比例 1 0.400000 0 0.333333 2 0.266667 dtype: float64 # 測試集三種類別比例 1 0.398496 0 0.330827 2 0.270677 dtype: float64 ``` 我們可以發現將 `stratify` 設置為目標 (y) 在訓練和測試集中產生相同的分佈。