weight.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import numpy as np
  2. def balance_weights(y: np.ndarray, bins=10, normalize: bool = True, **kwargs) -> np.ndarray:
  3. """
  4. 平衡权重,分布数量越少权重越大
  5. """
  6. bins = int(bins)
  7. counts, bin_edges = np.histogram(y, bins=bins)
  8. # digitize 不使用 right=True,这样最小值也能落在 bin 0 开始
  9. bin_indices = np.digitize(y, bin_edges[1:-1], right=False)
  10. # bin_counts 用 0 到 bins-1 的索引
  11. bin_counts = {i: count for i, count in enumerate(counts)}
  12. # 对于每个样本分配权重(加个兜底:出现异常时给个较大默认值)
  13. weights = np.array([1.0 / bin_counts.get(b, 1e-6) for b in bin_indices])
  14. if normalize:
  15. weights /= np.mean(weights)
  16. return weights
  17. def south_weight(target: np.ndarray, cap, **kwargs) -> np.ndarray:
  18. """
  19. 应付南方点网的奇怪考核
  20. 为了不把曲线压太低,这里有所收敛(添加开方处理,不让权重分布过于离散)
  21. """
  22. cap = float(cap)
  23. weight = 1 / np.sqrt(np.where(target < 0.2 * cap, 0.2 * cap, target))
  24. return weight
  25. def standard_weight(target: np.array, **kwargs) -> np.ndarray:
  26. """
  27. 标准化权重
  28. """
  29. weight = np.sqrt(np.abs(target - np.mean(target))) / np.std(target)
  30. return weight
  31. # ------------------------------权重函数注册------------------------------------------------
  32. WEIGHT_REGISTER = {
  33. "balance": balance_weights,
  34. "south": south_weight,
  35. "std": standard_weight
  36. }