12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- import numpy as np
def balance_weights(y: np.ndarray, bins=10, normalize: bool = True, **kwargs) -> np.ndarray:
    """
    Inverse-frequency ("balancing") sample weights: the fewer samples share a
    value's histogram bin, the larger that sample's weight.

    Parameters
    ----------
    y : np.ndarray
        Target values. Any shape is accepted; the returned weights have the
        same shape as ``y``.
    bins : int-like, default 10
        Number of equal-width histogram bins (coerced with ``int()``).
    normalize : bool, default True
        If True, rescale the weights so that their mean is 1.
    **kwargs
        Ignored. Presumably accepted so every weight function can be called
        with the same keyword arguments.

    Returns
    -------
    np.ndarray
        Float weights equal to ``1 / (number of samples in the sample's bin)``,
        divided by their mean when ``normalize`` is True.
    """
    bins = int(bins)
    counts, bin_edges = np.histogram(y, bins=bins)
    # Digitize against the interior edges only, with right=False: values below
    # the second edge map to bin 0 and the maximum maps to bin ``bins - 1``,
    # matching np.histogram's half-open bins and closed last bin.
    bin_indices = np.digitize(y, bin_edges[1:-1], right=False)
    # Every sample falls into a bin it was counted in, so each looked-up count
    # is >= 1 and the division can never hit zero (no fallback needed).
    weights = 1.0 / counts[bin_indices]
    # Skip normalization for empty input: the mean of an empty array is NaN.
    if normalize and weights.size:
        weights /= np.mean(weights)
    return weights
def south_weight(target: np.ndarray, cap, **kwargs) -> np.ndarray:
    """
    Sample weights for an unusual evaluation metric used by the southern grid
    (presumably China Southern Power Grid).

    Each weight is ``1 / sqrt(max-like(target, 0.2 * cap))``: small targets
    get larger weights. The square root tempers the weights so that the
    fitted curve is not pushed too low and the weight distribution does not
    become too dispersed.

    Parameters
    ----------
    target : np.ndarray
        Target values.
    cap : float-like
        Reference value (coerced with ``float()``). Targets below ``0.2 * cap``
        are lifted to that floor before weighting.
    **kwargs
        Ignored.

    Returns
    -------
    np.ndarray
        The weights.

    NOTE(review): with ``cap <= 0`` the floor is non-positive, so zero or
    negative targets yield inf/nan weights.
    """
    floor = 0.2 * float(cap)
    # Anything under 20% of cap is treated as exactly 20% of cap, which bounds
    # the largest weight at 1 / sqrt(floor).
    clamped = np.where(target < floor, floor, target)
    return 1 / np.sqrt(clamped)
def standard_weight(target: np.ndarray, **kwargs) -> np.ndarray:
    """
    Standardization-style weights: ``sqrt(|target - mean|) / std``.

    Samples far from the mean get larger weights; a sample exactly at the
    mean gets weight 0.

    NOTE(review): this equals ``sqrt(|z|) / sqrt(std)`` rather than
    ``sqrt(|z|)``. The relative weights are the same, but the overall scale
    depends on ``std``. Confirm this is intended.

    Parameters
    ----------
    target : np.ndarray
        Target values.
    **kwargs
        Ignored.

    Returns
    -------
    np.ndarray
        The weights. For an empty target, an empty float array is returned.
        For a constant target (std == 0), all weights are 1.0 instead of the
        0/0 NaN the formula would produce.
    """
    if np.size(target) == 0:
        # Avoid mean/std of an empty slice (NaN plus runtime warnings).
        return np.zeros(np.shape(target))
    center = np.mean(target)
    spread = np.std(target)
    if spread == 0:
        # Constant target: every deviation is 0 and std is 0, so the formula
        # would give NaN weights. Fall back to uniform weights.
        return np.ones(np.shape(target))
    return np.sqrt(np.abs(target - center)) / spread
# ------------------------------ Weight-function registry ------------------------------
# Maps a lookup name to the weight function it selects.
WEIGHT_REGISTER = dict(
    balance=balance_weights,
    south=south_weight,
    std=standard_weight,
)
|