diff --git a/f1.py b/f1.py index 52e2554..3f62fe7 100644 --- a/f1.py +++ b/f1.py @@ -1,43 +1,58 @@ import torch +import sys def reverse(t): idx = [i for i in range(t.size(0)-1, -1, -1)] idx = torch.LongTensor(idx) it = t.index_select(0, idx) return it def PRFromHistograms(hPos,hNeg): - print("hPos",hPos,"hNeg",hNeg) + #print("hPos",hPos,"hNeg",hNeg) positives=hPos.sum() negatives=hNeg.sum() - print("positives, negatives", positives, negatives) + #print("positives, negatives", positives, negatives) truepositives =reverse(hPos.clone()).long().cumsum(dim=0) falsepositives=reverse(hNeg.clone()).long().cumsum(dim=0) predpositives=torch.add(truepositives,falsepositives) #protect against zero division predpositives[predpositives==0]=1 precision=torch.Tensor.div_(truepositives.float(),predpositives.float()) recall =torch.Tensor.div (truepositives.float(),positives.float()) - precision[precision<=0]=1e-12 - recall [recall <=0]=1e-12 - print("precision,recall",precision,recall) + precision[precision<=0]=sys.float_info.epsilon + recall [recall <=0]=sys.float_info.epsilon + #print("precision,recall",precision,recall) + return precision,recall + +def PRFromDiscrete(nTruePos,nPredPos,nGtPos): + truepositives =nTruePos + falsepositives=nPredPos-nTruePos + positives =nGtPos + predpositives =nPredPos + if predpositives==0: predpositives =1 #protect against zero division + precision=truepositives/float(predpositives) + recall =truepositives/float(positives) + if precision<=0: precision=sys.float_info.epsilon + if recall <=0: recall =sys.float_info.epsilon return precision,recall def PRFromOutGt(outps,targs,nbins=10000): hPos=torch.zeros(nbins) hNeg=torch.zeros(nbins) for o,t in zip(outps,targs): pos=o[lbl==1] neg=t[lbl==0] hPos+=torch.from_numpy(pos.astype(np.float32)).histc(nbins,0,1) hNeg+=torch.from_numpy(neg.astype(np.float32)).histc(nbins,0,1) precision,recall=f1.PRFromHistograms(hPos,hNeg) f1s=f1.F1FromPR(precision,recall) f=f1s.max() return f def F1FromPR(p,r): suminv=torch.pow(p,-1)+torch.pow(r,-1) f1s=torch.pow(suminv,-1).mul(2) - print("f1s",f1s) + #print("f1s",f1s) return f1s + +