every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog
0. 前言
torch.quantile 分位数计算方法
1. 正文
简单讲一句话,计算分位数。
啥玩意是分位数,
简单讲就是,把数据从小到大排序,然后取百分位置的值,就是第几个分位数。
比如,第0.5分位数,就是取排序后中间位置的值。(也是我们的中位数)
第0.25分位数,就是取排序后的前25%位置的值。
第0.75分位数,就是取排序后的后25%位置的值。
第0.1分位数,就是取排序后前10%位置的值。
第0.9分位数,就是取排序后的后10%位置的值。
2. 代码
2.1 案例一
我们先看一个简单的例子
import torch
import torch.nn.functional as F
x = torch.tensor([1,2,3,4,5])
quantile = F.quantile(x,0.5)print(quantile)
输出:
tensor(3.)
2.2 案例二
x = torch.tensor([2.0,3,4,5,7])
表格:
索引01234值2.03457
# 计算 0.5 分位数(即中位数)
median = torch.quantile(x,0.5)
median
上面一个5个数,索引是[0,4],那么,
0.5 * 4 = 2,所以,取索引为2的数,即4.0
输出:
tensor(4.)
再进一步,
median = torch.quantile(x,0.9)
同样的原理,先计算索引,
0.9 * 4 = 3.6,
这次的索引不是整数了,而是小数,所以要找到这个索引左右两侧的值a和b。
索引3.6位于索引3和索引4之间,索引3和4对应的值为5和7。
所以计算:
5 + (7-5)*0.6 = 6.2
输出:
tensor(6.2000)
如果索引不是整数,找到该索引左右两侧的值,进行计算。
如果索引不是整数,找到该索引左右两侧的值,进行计算。
如果索引不是整数,找到该索引左右两侧的值,进行计算。
上面为什么乘以0.6呢,
因为索引3.6,比索引3大0.6,取小数部分,具体可以参考下图。
这是对应的linear插值法。默认 选项
此外还有
lower
、
higher
、
nearest
和
midpoint
方法。
lower: 索引3.6向下取整,即索引3对应的数
higher: 索引3.6向上取整,即索引4对应的数
midpoint: 索引3.6,取索引3和索引4对应数值的平均值。
nearest: 索引3.6最接近的数,即索引4对应的数
3. 小结
torch.quantile函数用于计算张量的分位数,可以指定不同的插值方法来计算非整数索引处的值。默认情况下,使用线性插值法,但也可以选择其他插值方法,如lower、higher、nearest和midpoint。
找到对应的索引,如果索引为整数可以直接获取其值,该值就是分位数!!!
如果索引为小数,找到该索引左右两侧的值,进行计算。
参考
版权归原作者 胡侃有料 所有, 如有侵权,请联系我们删除。