weight.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import numpy as np
  2. def balance_weights(y: np.ndarray, bins: int = 10, normalize: bool = True, **kwargs) -> np.ndarray:
  3. """
  4. 平衡权重,分布数量越少权重越大
  5. """
  6. counts, bin_edges = np.histogram(y, bins=bins)
  7. # digitize 不使用 right=True,这样最小值也能落在 bin 0 开始
  8. bin_indices = np.digitize(y, bin_edges[1:-1], right=False)
  9. # bin_counts 用 0 到 bins-1 的索引
  10. bin_counts = {i: count for i, count in enumerate(counts)}
  11. # 对于每个样本分配权重(加个兜底:出现异常时给个较大默认值)
  12. weights = np.array([1.0 / bin_counts.get(b, 1e-6) for b in bin_indices])
  13. if normalize:
  14. weights /= np.mean(weights)
  15. return weights
  16. def south_weight(target: np.ndarray, cap, **kwargs) -> np.ndarray:
  17. """
  18. 应付南方点网的奇怪考核
  19. 为了不把曲线压太低,这里有所收敛(添加开方处理,不让权重分布过于离散)
  20. """
  21. weight = 1 / np.sqrt(np.where(target < 0.2 * cap, 0.2 * cap, target))
  22. return weight
  23. def standard_weight(target: np.array, **kwargs) -> np.ndarray:
  24. """
  25. 标准化权重
  26. """
  27. weight = np.sqrt(np.abs(target - np.mean(target))) / np.std(target)
  28. return weight
  29. # ------------------------------权重函数注册------------------------------------------------
  30. WEIGHT_REGISTER = {
  31. "balance": balance_weights,
  32. "south": south_weight,
  33. "std": standard_weight
  34. }