package base.evaluate; import java.util.HashMap; import java.util.List; public class ConfusionMatrix { private HashMap<String, Integer> labelCntMap = new HashMap<>(); private String combineLabel(String trueLabel, String predictLabel) { return trueLabel + "\t" + predictLabel; } private void sysPrint(String value) { System.out.print(value); } public void add(int trueLabel, int predictLabel) { add(String.valueOf(trueLabel), String.valueOf(predictLabel)); } public void add(String trueLabel, String predictLabel) { String tp = combineLabel(trueLabel, predictLabel); labelCntMap.put(tp, labelCntMap.getOrDefault(tp, 0) + 1); } public void printResult(List<String> labels) { sysPrint("\t"); for (String predictLabel: labels) { sysPrint("\t" + predictLabel); } sysPrint("\r\n"); for (String trueLabel: labels) { sysPrint(trueLabel); for (String predictLabel: labels) { String tp = combineLabel(trueLabel, predictLabel); int cnt = labelCntMap.getOrDefault(tp, 0); sysPrint("\t"+cnt); } sysPrint("\r\n"); } for (String label: labels) { double precision = .0; double recall = 0.; int pu = labelCntMap.getOrDefault(combineLabel(label, label), 0); int ru = pu; int pd = 0; int rd = 0; for (String otherLabel: labels) { pd += labelCntMap.getOrDefault(combineLabel(otherLabel, label), 0); rd += labelCntMap.getOrDefault(combineLabel(label, otherLabel), 0); } precision = Double.valueOf(String.format("%.2f", pu * 100 / Double.valueOf(pd))); recall = Double.valueOf(String.format("%.2f", ru * 100 / Double.valueOf(rd))); double f1 = Double.valueOf(String.format("%.2f", 2 * precision * recall / (precision + recall))); sysPrint(label + "\t" + precision + "%\t" + recall + "%\t" + f1 + "%\r\n"); } } }
混淆矩阵 Java 代码
标签：
发表于：2018-06-12
阅读次数：1726
博文推荐
3 OCPC
4 动态规划
10 Apache Arrow
15 幸存者偏差
16 MLE - 最大似然估计
17 批处理和流处理
18 XGBoost
19 Pretext task
20 Redis 实战