import numpy as np


def balance_weights(y: np.ndarray, bins=10, normalize: bool = True, **kwargs) -> np.ndarray:
    """Return per-sample weights inversely proportional to bin occupancy.

    The values of ``y`` are histogrammed into ``bins`` equal-width bins and
    each sample receives ``1 / (number of samples in its bin)``, so values
    from sparsely populated regions get larger weights.

    Args:
        y: Values to balance. Any shape is accepted; the histogram is taken
            over all elements and the result has the same shape as ``y``.
        bins: Number of equal-width bins; coerced with ``int()``.
        normalize: If true, divide the weights by their mean so that the
            returned weights average to 1.
        **kwargs: Accepted and ignored, so every registered weight function
            can be called with the same keyword set.

    Returns:
        Array of float weights, one per element of ``y``.
    """
    bins = int(bins)
    counts, bin_edges = np.histogram(y, bins=bins)

    # Digitize against the interior edges only (right=False): values below
    # the first interior edge map to bin 0 and values at or above the last
    # interior edge map to bin ``bins - 1``, so every index is a valid
    # position in ``counts``.
    bin_indices = np.digitize(y, bin_edges[1:-1], right=False)

    # Vectorised lookup of each sample's bin population.
    weights = 1.0 / counts[bin_indices]

    # Skip normalisation for empty input: the mean of an empty array is NaN
    # and only produces a runtime warning.
    if normalize and weights.size:
        weights /= np.mean(weights)
    return weights


def south_weight(target: np.ndarray, cap, **kwargs) -> np.ndarray:
    """Return weights equal to ``1 / sqrt(max(target, 0.2 * cap))``.

    Written for the unusual assessment rule of the southern grid
    (NOTE(review): the original comment presumably refers to China Southern
    Power Grid -- confirm). Targets below 20% of ``cap`` are floored at that
    value, and the square root tempers the weights so the fitted curve is
    not pushed too low and the weight distribution does not become too
    dispersed.

    Args:
        target: Target values.
        cap: Reference value used for the floor; coerced with ``float()``.
        **kwargs: Accepted and ignored, so every registered weight function
            can be called with the same keyword set.

    Returns:
        Array of weights with the same shape as ``target``.
    """
    floor = 0.2 * float(cap)
    floored = np.where(target < floor, floor, target)
    return 1 / np.sqrt(floored)


def standard_weight(target: np.ndarray, **kwargs) -> np.ndarray:
    """Return ``sqrt(|target - mean(target)|) / std(target)`` per element.

    Args:
        target: Target values.
        **kwargs: Accepted and ignored, so every registered weight function
            can be called with the same keyword set.

    Returns:
        Array of weights with the same shape as ``target``. When the
        standard deviation of ``target`` is exactly zero (constant or
        single-element input) uniform weights of 1.0 are returned.
    """
    spread = np.std(target)
    # A zero spread would make the formula 0 / 0 and yield NaN for every
    # sample; all samples are identical, so weight them equally instead.
    # The ndim check keeps the guard to the scalar case only.
    if np.ndim(spread) == 0 and spread == 0:
        return np.ones(np.shape(target), dtype=float)
    return np.sqrt(np.abs(target - np.mean(target))) / spread


# ------------------------------ weight function registry ------------------------------
WEIGHT_REGISTER = {
    "balance": balance_weights,
    "south": south_weight,
    "std": standard_weight,
}