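1. Preface
I added Mi Ge (米哥) as a friend on WeChat and saw him recommend a kesci competition in his Moments: predicting the movement of A-share industry sectors. I trade stocks myself, Mi Ge strikes me as reliable, and the competition data is provided by tushare, so I entered and used a few fairly common algorithms.
Competition link
The competition explicitly asks participants to use news information, i.e. NLP. However, the news source this time is Xinwen Lianbo (新聞聯播), while the prediction horizons are 1, 2 and 3 days, which seemed to me to have little to do with Xinwen Lianbo (though perhaps I just don't follow the news closely). I entered just for fun, skipped the news data, used only the trading prices, and did some simple indicator analysis. The objective score turned out fairly well; the subjective score depends on kesci experts reviewing the code, which is out of my hands.
The objective score should be in the top three (probably second place).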
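2. Indicator selection
I picked a handful of common technical indicators, all of which can be looked up on Baidu.
BOLL: value of the upper Bollinger band, value of the middle band, value of the lower band, and whether the lower band is crossed (1 for an upward cross through the lower band from below, 0 otherwise)
KDJ: the K, D and J values, K/D golden cross (1 for a golden cross, 0 otherwise), whether K/J form a golden cross, whether J is oversold
MACD: the MACD value
CCI: 6-day CCI value, 10-day CCI value, whether the 6-day CCI is oversold, whether the 10-day CCI is oversold, whether the 6-day CCI golden-crosses the 10-day CCI
WILLR: 6-day WILLR value, 10-day WILLR value, whether the 6-day WILLR is oversold, whether the 10-day WILLR is oversold, whether the 6-day WILLR golden-crosses the 10-day WILLR
RSI: 6-day RSI value, 10-day RSI value, whether the 6-day RSI is oversold, whether the 10-day RSI is oversold, whether the 6-day RSI golden-crosses the 10-day RSI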
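As a rough illustration of how one value-type feature and one 0/1 flag from the list above might be derived, here is a minimal sketch for the Bollinger Bands. The function names, the 20-day window, the 2-standard-deviation band width, the use of the population standard deviation, and the choice of the closing price as the series that crosses the band are all assumptions for illustration; they are not taken from the competition code.

```python
from statistics import mean, pstdev


def bollinger_bands(closes, window=20, width=2.0):
    """Return (upper, middle, lower) lists; entries are None until a full window exists."""
    upper, middle, lower = [], [], []
    for i in range(len(closes)):
        if i + 1 < window:
            # Not enough history yet for a full window.
            upper.append(None)
            middle.append(None)
            lower.append(None)
            continue
        chunk = closes[i + 1 - window:i + 1]
        mid = mean(chunk)
        dev = pstdev(chunk)
        upper.append(mid + width * dev)
        middle.append(mid)
        lower.append(mid - width * dev)
    return upper, middle, lower


def cross_up_flags(series, reference):
    """1 where series moves from below reference to at-or-above it, else 0."""
    flags = [0] * len(series)
    for i in range(1, len(series)):
        if None in (series[i - 1], series[i], reference[i - 1], reference[i]):
            continue
        if series[i - 1] < reference[i - 1] and series[i] >= reference[i]:
            flags[i] = 1
    return flags


# Made-up closing prices; a short window keeps the example easy to follow.
closes = [10.0, 10.0, 10.0, 7.0, 9.5, 10.0]
upper, middle, lower = bollinger_bands(closes, window=3, width=1.0)
boll_cross = cross_up_flags(closes, lower)
print(boll_cross)
```

The same `cross_up_flags` helper could in principle produce the "golden cross" flags as well, for example by passing a 6-day series as `series` and the matching 10-day series as `reference`.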
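3. Algorithm selection
The algorithms are also very common ones (logistic regression, SVM, random forest). All three are trained on the historical data and then used for prediction: if at least two of the three predict a rise, the result is a rise; if at least two predict a fall, the result is a fall.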
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

# Each entry pairs a classifier with the feature columns it is trained on.
algorithms = [
    [RandomForestClassifier(random_state=1, n_estimators=100, min_samples_split=4, min_samples_leaf=2),
     ["boll_upper", "boll_middle", "boll_lower", "k", "d", "j", "macd", "macd_signal", "macd_hist",
      "cci6", "rsi6", "willr6", "cci10", "rsi10", "willr10",
      'boll_cross', 'kd_cross', 'kj_cross', 'cci_cross', 'rsi_cross', 'willr_cross',
      'over_cci6', 'over_cci10', 'over_willr6', 'over_willr10', 'over_j']],
    [LogisticRegression(random_state=1, solver='liblinear'),
     ["boll_upper", "boll_middle", "boll_lower", "k", "d", "j", "macd", "macd_signal", "macd_hist",
      "cci6", "rsi6", "willr6", "cci10", "rsi10", "willr10",
      'boll_cross', 'kd_cross', 'kj_cross', 'cci_cross', 'rsi_cross', 'willr_cross',
      'over_cci6', 'over_cci10', 'over_willr6', 'over_willr10', 'over_j']],
    [SVC(C=1.0, kernel='linear', probability=True),
     ["boll_upper", "boll_middle", "boll_lower", "k", "d", "j", "macd", "macd_signal", "macd_hist",
      "cci6", "rsi6", "willr6", "cci10", "rsi10", "willr10",
      'boll_cross', 'kd_cross', 'kj_cross', 'cci_cross', 'rsi_cross', 'willr_cross',
      'over_cci6', 'over_cci10', 'over_willr6', 'over_willr10', 'over_j']]
]
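The two-out-of-three voting rule can be sketched as follows. This is only an illustration: the helper name `majority_vote` and the encoding 1 = rise, 0 = fall are assumptions, and the hard-coded lists stand in for the predictions of the three fitted classifiers.

```python
def majority_vote(*prediction_lists):
    """Combine 0/1 predictions from several models; 1 wins when more than half vote 1."""
    n_models = len(prediction_lists)
    combined = []
    for votes in zip(*prediction_lists):
        combined.append(1 if sum(votes) * 2 > n_models else 0)
    return combined


# Placeholder predictions for four samples (1 = rise, 0 = fall).
rf_pred = [1, 0, 1, 0]
lr_pred = [1, 1, 0, 0]
svm_pred = [0, 1, 1, 0]

final_pred = majority_vote(rf_pred, lr_pred, svm_pred)
print(final_pred)
```

With an odd number of binary classifiers there is never a tie, which is one reason three models are convenient here.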
4. Important Python methods
Because the machine learning step is time-consuming, I ran it in parallel with pool.map. The code below is a small pool.map demo.
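from multiprocessing.pool import Pool


def task(i):
    # Return the five consecutive integers starting at i.
    return list(range(i, i + 5))


def pool_method():
    result_list = list()
    pool = Pool()
    # map() blocks until every task has finished and returns results in input order.
    temp_result_list = pool.map(task, [1, 2, 3, 4, 5])
    result_list.extend(temp_result_list)
    pool.close()
    pool.join()
    print(result_list)


if __name__ == '__main__':
    # The guard is needed on platforms that start worker processes with "spawn" (e.g. Windows).
    pool_method()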
5. Full code
Since the competition runs in a notebook, I can only provide a link to the code.
Code link
6. Input data format
The data can be downloaded from tushare. Below is the download code; register for a key (token) before downloading.
import logging
import os

import pandas as pd
import tushare as ts

logger = logging.getLogger(__name__)

# __get_date_range, __get_today_str, down_load_daily_sw_data, BASIC_INFO_START_DATE and
# SW_DATA_STORE_FOLDER are defined in the full code and are not shown here.

# Shenwan (SW) industry index codes to keep.
SW_INDEX_CODES = [
    '801010.SI', '801020.SI', '801030.SI', '801040.SI', '801050.SI', '801080.SI',
    '801110.SI', '801120.SI', '801130.SI', '801140.SI', '801150.SI', '801160.SI',
    '801170.SI', '801180.SI', '801200.SI', '801210.SI', '801230.SI', '801250.SI',
    '801260.SI', '801270.SI', '801280.SI', '801300.SI', '801710.SI', '801720.SI',
    '801730.SI', '801740.SI', '801750.SI', '801760.SI', '801770.SI', '801780.SI',
    '801790.SI', '801880.SI', '801890.SI', '802600.SI',
]


def download_and_save_sw_job(pro):
    """
    Download and save the daily Shenwan (SW) industry index data.
    :param pro: tushare pro API client
    :return:
    """
    logger.info('Downloading Shenwan data')
    for str_datetime in __get_date_range(BASIC_INFO_START_DATE, __get_today_str()):
        logger.info(str_datetime)
        basic_data_dataframe = down_load_daily_sw_data(pro, str_datetime)
        if basic_data_dataframe is None:
            logger.info('Failed to fetch SW data for {0} (pro.sw_daily)'.format(str_datetime))
            continue
        # Keep only the industry indexes listed above.
        basic_data_dataframe = basic_data_dataframe[basic_data_dataframe['ts_code'].isin(SW_INDEX_CODES)]
        # Overwrite any existing file for this date.
        csv_path = os.path.join(SW_DATA_STORE_FOLDER, 'sw_{0}.csv'.format(str_datetime))
        if os.path.exists(csv_path):
            os.remove(csv_path)
        basic_data_dataframe.to_csv(csv_path)
    combine_dataframe()


def combine_dataframe():
    """Merge the per-day CSV files into a single TRAINSET_STOCK.csv."""
    combined_path = os.path.join(SW_DATA_STORE_FOLDER, 'TRAINSET_STOCK.csv')
    if os.path.exists(combined_path):
        os.remove(combined_path)
    file_list = os.listdir(SW_DATA_STORE_FOLDER)
    # DataFrame.append was removed in pandas 2.0, so the frames are merged with pd.concat.
    frames = [pd.read_csv(os.path.join(SW_DATA_STORE_FOLDER, file_name)) for file_name in file_list]
    base_data_frame = pd.concat(frames, ignore_index=True)
    base_data_frame.to_csv(combined_path)


if __name__ == '__main__':
    ts.set_token('***')
    pro = ts.pro_api()
    # Download the Shenwan data; download_and_save_sw_job() also merges the
    # daily files by calling combine_dataframe() when it finishes.
    download_and_save_sw_job(pro)

The resulting input data looks like this (first rows):
,Unnamed: 0,ts_code,trade_date,name,open,low,high,close,change,pct_change,vol,amount,pe,pb
0,4,801010.SI,20170405,農林牧漁,3228.07,3227.19,3271.9,3271.9,49.6,1.54,83229.0,997867.0,29.15,3.98
1,13,801020.SI,20170405,採掘,3499.94,3499.94,3549.68,3549.68,64.05,1.84,130225.0,1082993.0,57.27,1.8
2,18,801030.SI,20170405,化工,3339.79,3339.17,3394.7,3394.7,65.47,1.97,327918.0,4424141.0,42.27,2.95
3,25,801040.SI,20170405,鋼鐵,2792.78,2792.76,2822.85,2822.5,80.8,2.95,126113.0,663732.0,54.97,1.71
4,27,801050.SI,20170405,有色金屬,3774.51,3774.51,3872.85,3872.85,106.59,2.83,285608.0,3330624.0,88.06,3.21
5,37,801080.SI,20170405,電子,3224.77,3224.62,3286.47,3286.47,70.31,2.19,227326.0,3245112.0,64.27,3.77
6,48,801110.SI,20170405,家用電器,5743.21,5718.12,5780.01,5764.28,17.12,0.3,64192.0,1136937.0,19.9,3.36