10
10
import numpy as np
11
11
from typing import Callable
12
12
from tqdm import tqdm
13
- from sklearn .linear_model import Ridge , LinearRegression , Lasso
14
13
15
14
16
15
def vwap (price : np .array , volume : np .array , ** kwargs ) -> float :
@@ -45,6 +44,7 @@ def remove_beta_effects(df, **kwargs):
45
44
46
45
:return: DataFrame
47
46
"""
47
+ from sklearn .linear_model import Ridge , LinearRegression , Lasso
48
48
49
49
linear_model = kwargs .get ("linear_model" , "ridge" )
50
50
linear_model_params = kwargs .get ("linear_model_params" , {})
@@ -268,7 +268,7 @@ def cal_yearly_days(dts: list, **kwargs):
268
268
269
269
# 按年重采样并计算每年的交易日数量,取最大值
270
270
yearly_days = dts .resample ('YE' ).size ().max ()
271
- return yearly_days
271
+ return min ( yearly_days , 365 )
272
272
273
273
274
274
def cal_symbols_factor (dfk : pd .DataFrame , factor_function : Callable , ** kwargs ):
@@ -282,44 +282,55 @@ def cal_symbols_factor(dfk: pd.DataFrame, factor_function: Callable, **kwargs):
282
282
- factor_params: dict, 因子计算参数
283
283
- min_klines: int, 最小K线数据量,默认为 300
284
284
- price_type: str, 交易价格类型,默认为 close,可选值为 close 或 next_open
285
+ - strict: bool, 是否严格模式,默认为 True, 严格模式下,计算因子出错会抛出异常
285
286
286
287
:return: dff, pd.DataFrame, 计算后的因子数据
287
288
"""
288
289
logger = kwargs .get ("logger" , loguru .logger )
289
290
min_klines = kwargs .get ("min_klines" , 300 )
290
291
factor_params = kwargs .get ("factor_params" , {})
291
292
price_type = kwargs .get ("price_type" , "close" )
293
+ strict = kwargs .get ("strict" , True )
292
294
293
295
symbols = dfk ["symbol" ].unique ().tolist ()
294
296
factor_name = factor_function .__name__
295
297
298
+ def __one_symbol (symbol ):
299
+ df = dfk [(dfk ["symbol" ] == symbol )].copy ()
300
+ df = df .sort_values ("dt" , ascending = True ).reset_index (drop = True )
301
+ if len (df ) < min_klines :
302
+ logger .warning (f"{ symbol } 数据量过小,跳过;仅有 { len (df )} 条数据,需要 { min_klines } 条数据" )
303
+ return None
304
+
305
+ df = factor_function (df , ** factor_params )
306
+ if price_type == 'next_open' :
307
+ df ["price" ] = df ["open" ].shift (- 1 ).fillna (df ["close" ])
308
+ elif price_type == 'close' :
309
+ df ["price" ] = df ["close" ]
310
+ else :
311
+ raise ValueError ("price_type 参数错误, 可选值为 close 或 next_open" )
312
+
313
+ df ["n1b" ] = (df ["price" ].shift (- 1 ) / df ["price" ] - 1 ).fillna (0 )
314
+ factor = [x for x in df .columns if x .startswith ("F#" )][0 ]
315
+
316
+ # df[factor] = df[factor].replace([np.inf, -np.inf], np.nan).ffill().fillna(0)
317
+ # factor 中不能有 inf 和 -inf 值,也不能有 nan 值
318
+ assert df [factor ].isna ().sum () == 0 , f"{ symbol } { factor } 存在 nan 值"
319
+ assert df [factor ].isin ([np .inf , - np .inf ]).sum () == 0 , f"{ symbol } { factor } 存在 inf 值"
320
+ assert df [factor ].var () != 0 and not np .isnan (df [factor ].var ()), f"{ symbol } { factor } var is 0 or nan"
321
+ return df
322
+
296
323
rows = []
297
- for symbol in tqdm (symbols , desc = f"{ factor_name } 因子计算" ):
298
- try :
299
- df = dfk [(dfk ["symbol" ] == symbol )].copy ()
300
- df = df .sort_values ("dt" , ascending = True ).reset_index (drop = True )
301
- if len (df ) < min_klines :
302
- logger .warning (f"{ symbol } 数据量过小,跳过;仅有 { len (df )} 条数据,需要 { min_klines } 条数据" )
324
+ for _symbol in tqdm (symbols , desc = f"{ factor_name } 因子计算" ):
325
+ if strict :
326
+ dfx = __one_symbol (_symbol )
327
+ else :
328
+ try :
329
+ dfx = __one_symbol (_symbol )
330
+ except Exception as e :
331
+ logger .error (f"{ factor_name } - { _symbol } - 计算因子出错:{ e } " )
303
332
continue
304
-
305
- df = factor_function (df , ** factor_params )
306
- if price_type == 'next_open' :
307
- df ["price" ] = df ["open" ].shift (- 1 ).fillna (df ["close" ])
308
- elif price_type == 'close' :
309
- df ["price" ] = df ["close" ]
310
- else :
311
- raise ValueError ("price_type 参数错误, 可选值为 close 或 next_open" )
312
-
313
- df ["n1b" ] = (df ["price" ].shift (- 1 ) / df ["price" ] - 1 ).fillna (0 )
314
-
315
- factor = [x for x in df .columns if x .startswith ("F#" )][0 ]
316
- df [factor ] = df [factor ].replace ([np .inf , - np .inf ], np .nan ).ffill ().fillna (0 )
317
- if df [factor ].var () == 0 or np .isnan (df [factor ].var ()):
318
- logger .warning (f"{ symbol } { factor } var is 0 or nan" )
319
- else :
320
- rows .append (df .copy ())
321
- except Exception as e :
322
- logger .error (f"{ factor_name } - { symbol } - 计算因子出错:{ e } " )
333
+ rows .append (dfx )
323
334
324
335
dff = pd .concat (rows , ignore_index = True )
325
336
return dff
0 commit comments