import java.io.*;
import java.util.*;

/*****************************************************************
 CN2
 2/20/98 9:35PM
 *****************************************************************/

/**
 * Rule learner: repeatedly searches for the best-ranked conjunctive cover
 * over the remaining examples, records it as a rule, removes the examples the
 * rule covers, and stops once no examples are left.
 *
 * Examples are grouped per class in an array of {@link MultiArray}; index
 * {@code i} of such an array always corresponds to {@code classes.elementAt(i)}.
 */
class CN2 {
    Vector rules;       // Rule objects, in the order they were learned
    Vector heap;        // Node candidates waiting to enter the next star
    Vector classes;     // String class labels
    Vector attributes;  // String attribute names
    Vector domains;     // int[] of legal values, parallel to attributes
    Vector examples;    // Vector per example: Integer attribute values, then a String label
    int[] globalDistribution;
    Node bestNode;
    String initialMajorityClass;
    int numberOfAttributes;
    int numberOfClasses;
    int numberOfExamples;

    static double threshold = 0.0;
    static int heapSize = 5;
    static int starSize = 5;

    static final int useless = 0;
    static final int betterThanDefault = 1;
    static final int worseThanDefault = -1;

    /**
     * Ranks a freshly specialised node, promotes it to {@code bestNode} when it
     * beats the current best, is significant and is better than the default
     * distribution, and queues it on the heap when further specialisation could
     * still beat the current best.
     */
    void assess(Node node) {
        int status = evaluate(node);
        if (status == useless) {
            return;
        }
        if (node.rank > bestNode.rank && significant(node) && status == betterThanDefault) {
            bestNode = node;
        }
        if (potentialRank(node) > bestNode.rank) {
            tossOntoHeap(node);
        }
    }

    /**
     * Splits the examples by class label and strips the label from each one.
     * Also resets {@code numberOfExamples} to the number of examples whose
     * label is a known class.
     *
     * @param examples Vector of example Vectors (Integer values followed by a String label)
     * @return one MultiArray per entry of {@code classes}, in the same order
     */
    MultiArray[] classify(Vector examples) {
        // Sized from the vector that is actually iterated, so the two can never disagree.
        MultiArray[] byClass = new MultiArray[classes.size()];
        numberOfExamples = 0;
        for (int i = 0; i < classes.size(); i++) {
            String label = (String) classes.elementAt(i);
            Vector members = new Vector();
            for (int j = 0; j < examples.size(); j++) {
                Vector example = (Vector) examples.elementAt(j);
                // Check the label first; only matching examples are copied.
                if (label.equals(example.lastElement())) {
                    members.addElement(toIntArray(copyVector(example, 0, example.size() - 1)));
                    numberOfExamples++;
                }
            }
            byClass[i] = new MultiArray(members);
        }
        return byClass;
    }

    /**
     * Unboxes a Vector of Integer into an int array.
     *
     * @param v Vector whose elements are all Integer
     * @return the values of {@code v}, in order
     */
    int[] toIntArray(Vector v) {
        int[] values = new int[v.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = ((Integer) v.elementAt(i)).intValue();
        }
        return values;
    }

    /**
     * Makes an independent copy of a cover so that it can be specialised
     * without touching the original.
     *
     * @param arr cover to copy
     * @return a new array of new CoverElement objects with the same state
     */
    CoverElement[] copyCover(CoverElement[] arr) {
        CoverElement[] copy = new CoverElement[arr.length];
        for (int i = 0; i < arr.length; i++) {
            CoverElement element = new CoverElement();
            if (arr[i].instantiated) {
                element.setValue(arr[i].value);
            }
            copy[i] = element;
        }
        return copy;
    }

    /**
     * Copies the elements of {@code v} in the range {@code [start, finish)}.
     *
     * @param v      source vector
     * @param start  first index to copy (inclusive)
     * @param finish index to stop at (exclusive)
     * @return a new Vector holding the selected elements (the elements themselves are shared)
     */
    Vector copyVector(Vector v, int start, int finish) {
        Vector slice = new Vector();
        for (int i = start; i < finish; i++) {
            slice.addElement(v.elementAt(i));
        }
        return slice;
    }

    /** @return true when no class has any example left */
    boolean empty(MultiArray[] qqs) {
        for (int i = 0; i < qqs.length; i++) {
            if (!qqs[i].isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets {@code node.rank} to {@code (1 + majority) / (total + #classes)} and
     * reports how the node's majority-class share compares with that class's
     * share of the global distribution.
     *
     * @return {@code useless} when the node covers nothing, otherwise
     *         {@code betterThanDefault} or {@code worseThanDefault}
     */
    int evaluate(Node node) {
        int[] accumulators = new int[3];
        majorityClass(node.qqs, accumulators);
        int index = accumulators[0];
        int length = accumulators[1];
        int total = accumulators[2];
        if (total <= 0) {
            return useless;
        }
        node.rank = (1.0 + length) / (total + classes.size());
        // length/total > global/numberOfExamples, cross-multiplied so the
        // comparison is exact instead of going through float rounding.
        long nodeShare = (long) length * numberOfExamples;
        long globalShare = (long) globalDistribution[index] * total;
        return nodeShare > globalShare ? betterThanDefault : worseThanDefault;
    }

    /**
     * Drops every example covered by {@code bestNode}.
     *
     * @return a new per-class array holding only the uncovered examples
     */
    MultiArray[] filter(MultiArray[] qqs) {
        MultiArray[] remaining = new MultiArray[qqs.length];
        for (int i = 0; i < qqs.length; i++) {
            int size = qqs[i].size;
            MultiArray kept = new MultiArray(size);
            for (int j = 0; j < size; j++) {
                int[] example = qqs[i].arr[j];
                if (!bestNode.covers(example)) {
                    kept.add(example);
                }
            }
            remaining[i] = kept;
        }
        return remaining;
    }

    /*
     * According to my Lisp profiling, most of the execution time gets spent
     * in this routine.
     */
    /**
     * Stores in {@code lower.qqs} the examples of {@code node} whose attribute
     * {@code index} equals {@code val}.
     *
     * @return true when the restriction keeps at least one example and also
     *         removes at least one, i.e. when it is a real specialisation
     */
    boolean filterExamples(Node lower, Node node, int index, int val) {
        MultiArray[] selected = new MultiArray[node.qqs.length];
        int newCount = 0;
        int oldCount = 0;
        for (int i = 0; i < node.qqs.length; i++) {
            MultiArray source = node.qqs[i];
            MultiArray matches = new MultiArray(source.size);
            for (int j = 0; j < source.size; j++) {
                int[] example = source.arr[j];
                if (example[index] == val) {
                    matches.add(example);
                    newCount++;
                }
            }
            oldCount += source.size;
            selected[i] = matches;
        }
        lower.qqs = selected;
        return newCount > 0 && oldCount > newCount;
    }

    /**
     * Specialises the star until the heap yields no more candidates, then turns
     * {@code bestNode} into a rule. When no node ever replaced the initial
     * {@code bestNode} (rank still -1.0) the rule predicts
     * {@code initialMajorityClass} with the global distribution.
     */
    Rule getBestRule(Vector star) {
        while (!star.isEmpty()) {
            for (int i = 0; i < star.size(); i++) {
                specialise((Node) star.elementAt(i));
            }
            star = heapToStar();
        }
        Rule rule = new Rule(bestNode.cover);
        if (bestNode.rank > -1.0) {
            rule.ruleClass = majorityClass(bestNode.qqs, new int[3]);
            rule.distribution = new int[bestNode.qqs.length];
            for (int i = 0; i < bestNode.qqs.length; i++) {
                rule.distribution[i] = bestNode.qqs[i].size;
            }
        } else {
            rule.ruleClass = initialMajorityClass;
            rule.distribution = globalDistribution;
        }
        return rule;
    }

    /**
     * Moves up to {@code starSize} nodes from the heap into a new star.
     *
     * The heap is sorted in ascending rank, so when it holds more than
     * {@code starSize} nodes the star is taken from the high end and the
     * lower-ranked remainder stays on the heap. (Taking the low end, as this
     * method used to, would specialise the worst candidates; the two coincide
     * only while {@code heapSize <= starSize}.)
     *
     * @return the new star, in ascending rank order
     */
    Vector heapToStar() {
        heap = sort(heap);
        int size = heap.size();
        if (size > starSize) {
            int cut = size - starSize;
            Vector star = copyVector(heap, cut, size);
            heap = copyVector(heap, 0, cut);
            return star;
        }
        Vector star = heap;
        heap = new Vector();
        return star;
    }

    /** @return a one-element star holding the unconstrained node over {@code qqs} */
    Vector initialStar(MultiArray[] qqs) {
        Node root = new Node(numberOfAttributes, qqs);
        evaluate(root);
        Vector star = new Vector();
        star.addElement(root);
        return star;
    }

    /**
     * Reads the attribute file and the example file and fills
     * {@code attributes}, {@code domains}, {@code examples} and {@code classes}.
     * On an I/O error the message is printed and the JVM exits with status 1.
     *
     * @param atts path of the attribute file
     * @param exs  path of the example file
     */
    void load(String atts, String exs) {
        attributes = new Vector();
        domains = new Vector();
        classes = new Vector();
        examples = new Vector();
        try {
            loadAttributes(atts);
            loadExamples(exs);
        } catch (IOException e) {
            System.out.println("Error: " + e);
            System.exit(1);
        }
        numberOfClasses = classes.size();
        numberOfAttributes = attributes.size();
    } // end method load

    /**
     * Opens a line reader that maps each byte to one char, which is what the
     * deprecated DataInputStream.readLine this replaces did.
     */
    private BufferedReader openReader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(path), "ISO-8859-1"));
    }

    /**
     * Parses lines of the form {@code name: v1 v2 ...}; lines whose first token
     * does not end in ':' (including blank lines) are ignored.
     */
    private void loadAttributes(String path) throws IOException {
        BufferedReader in = openReader(path);
        try {
            for (String line = in.readLine(); line != null; line = in.readLine()) {
                StringTokenizer tokens = new StringTokenizer(line);
                if (!tokens.hasMoreTokens()) {
                    continue; // blank or whitespace-only line
                }
                String first = tokens.nextToken();
                if (!first.endsWith(":")) {
                    continue;
                }
                attributes.addElement(first.substring(0, first.length() - 1));
                Vector values = new Vector();
                while (tokens.hasMoreTokens()) {
                    values.addElement(Integer.valueOf(tokens.nextToken()));
                }
                domains.addElement(toIntArray(values));
            }
        } finally {
            in.close();
        }
    }

    /**
     * Parses lines of the form {@code v1 v2 ... label;}. Lines that do not end
     * in ';' (ignoring surrounding whitespace) are skipped. Every distinct
     * label becomes an entry of {@code classes}.
     */
    private void loadExamples(String path) throws IOException {
        Hashtable labels = new Hashtable();
        BufferedReader in = openReader(path);
        try {
            for (String line = in.readLine(); line != null; line = in.readLine()) {
                String trimmed = line.trim();
                if (!trimmed.endsWith(";")) {
                    continue;
                }
                StringTokenizer tokens = new StringTokenizer(trimmed);
                Vector example = new Vector();
                while (tokens.hasMoreTokens()) {
                    String token = tokens.nextToken();
                    if (token.endsWith(";")) {
                        String label = token.substring(0, token.length() - 1);
                        labels.put(label, "");
                        example.addElement(label);
                    } else {
                        example.addElement(Integer.valueOf(token));
                    }
                }
                examples.addElement(example);
            }
        } finally {
            in.close();
        }
        for (Enumeration e = labels.keys(); e.hasMoreElements();) {
            classes.addElement(e.nextElement());
        }
    }

    /**
     * Usage: {@code CN2 <atts> <exs>} or {@code CN2 <ignored> <atts> <exs>};
     * with any other argument count the files soya.att and soya.exs are used.
     */
    public static void main(String args[]) {
        String atts;
        String exs;
        if (args.length == 3) {
            atts = args[1];
            exs = args[2];
        } else if (args.length == 2) {
            atts = args[0];
            exs = args[1];
        } else {
            atts = "soya.att";
            exs = "soya.exs";
        }
        CN2 cn2 = new CN2();
        System.out.println("Starting load..");
        cn2.load(atts, exs);
        System.out.println("Load finished.");
        long begin = System.currentTimeMillis();
        cn2.run();
        long end = System.currentTimeMillis();
        cn2.printRules();
        long elapsed = end - begin;
        System.out.println("Elapsed time: " + elapsed + " milliseconds.");
    } // finish main

    /**
     * Finds the class with the most examples (the first one wins ties).
     *
     * @param qqs          examples grouped per class
     * @param accumulators out-parameter of length 3: [0] index of the majority
     *                     class, [1] its example count, [2] total example count
     * @return the label of the majority class
     */
    String majorityClass(MultiArray[] qqs, int[] accumulators) {
        accumulators[0] = 0;
        accumulators[1] = 0;
        accumulators[2] = 0;
        for (int i = 0; i < qqs.length; i++) {
            int len = qqs[i].size;
            if (len > accumulators[1]) {
                accumulators[1] = len;
                accumulators[0] = i;
            }
            accumulators[2] += len;
        }
        return (String) classes.elementAt(accumulators[0]);
    }

    /** @return true when at least one element of the rule's cover is instantiated */
    boolean nonTrivial(Rule rule) {
        for (Enumeration e = rule.cover.elements(); e.hasMoreElements();) {
            if (((CoverElement) e.nextElement()).instantiated) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rank the node would have if it covered only its majority-class examples:
     * {@code (majority + 1) / (majority + #classes)}.
     */
    double potentialRank(Node node) {
        int[] accumulators = new int[3];
        majorityClass(node.qqs, accumulators);
        int length = accumulators[1];
        return (length + 1.0) / (length + classes.size());
    }

    /** Prints one rule, either as "If ... then item is C" or, for an empty cover, ": item is C.". */
    void printRule(Rule rule) {
        if (!nonTrivial(rule)) {
            System.out.println(": item is " + rule.ruleClass + ".");
            return;
        }
        System.out.print("If ");
        boolean first = true;
        for (int i = 0; i < rule.cover.size(); i++) {
            CoverElement elt = (CoverElement) rule.cover.elementAt(i);
            if (!elt.instantiated) {
                continue;
            }
            if (!first) {
                System.out.print(" and ");
            }
            first = false;
            System.out.print((String) attributes.elementAt(i) + " = " + elt.value);
        }
        System.out.println(" then item is " + rule.ruleClass);
    }

    /** Prints all learned rules as an if / else chain. */
    void printRules() {
        System.out.println("Printing rules..");
        for (Enumeration e = rules.elements(); e.hasMoreElements();) {
            Rule rule = (Rule) e.nextElement();
            printRule(rule);
            if (nonTrivial(rule)) {
                System.out.println("else");
            }
        }
    }

    /** Records the per-class example counts of {@code qqs} and their sum. */
    void recordGlobalDist(MultiArray[] qqs) {
        globalDistribution = new int[qqs.length];
        numberOfExamples = 0;
        for (int i = 0; i < qqs.length; i++) {
            globalDistribution[i] = qqs[i].size;
            numberOfExamples += qqs[i].size;
        }
    }

    /**
     * Learns rules until every example is covered. Each pass finds one rule
     * for the remaining examples, then removes the examples it covers.
     */
    void run() {
        heap = new Vector(heapSize);
        rules = new Vector();
        for (MultiArray[] qqs = classify(examples); !empty(qqs); qqs = filter(qqs)) {
            bestNode = new Node(numberOfAttributes);
            bestNode.qqs = qqs;
            recordGlobalDist(qqs);
            initialMajorityClass = majorityClass(qqs, new int[3]);
            rules.addElement(getBestRule(initialStar(qqs)));
        }
    }

    /**
     * Compares the node's class distribution with the global one scaled to the
     * node's size: computes {@code 2 * sum(f * ln(f / e))} over the classes
     * present in both, and tests it against {@code threshold}.
     */
    boolean significant(Node node) {
        int covered = 0;
        for (int i = 0; i < node.qqs.length; i++) {
            covered += node.qqs[i].size;
        }
        double ratio = (double) covered / (double) numberOfExamples;
        double distance = 0.0;
        for (int i = 0; i < node.qqs.length; i++) {
            int observed = node.qqs[i].size;
            int global = globalDistribution[i];
            if (observed > 0 && global > 0) {
                double f = (double) observed;
                distance += f * Math.log(f / (global * ratio));
            }
        }
        return 2 * distance > threshold;
    }

    /**
     * Returns the nodes of {@code v} in ascending rank order; nodes of equal
     * rank keep their relative order. The input vector is left untouched.
     * Quadratic, but it is only ever run on the (small) heap.
     *
     * @param v Vector of Node
     * @return a new sorted Vector
     */
    Vector sort(Vector v) {
        Vector sorted = new Vector(v.size());
        for (int i = 0; i < v.size(); i++) {
            Node node = (Node) v.elementAt(i);
            int pos = sorted.size();
            // Walk back past strictly larger ranks only, which keeps ties stable.
            while (pos > 0 && ((Node) sorted.elementAt(pos - 1)).rank > node.rank) {
                pos--;
            }
            sorted.insertElementAt(node, pos);
        }
        return sorted;
    }

    /** Tries every value of every attribute the node has not fixed yet. */
    void specialise(Node node) {
        for (int i = 0; i < domains.size(); i++) {
            if (!node.cover[i].instantiated) {
                specialiseDomain((int[]) domains.elementAt(i), node, i);
            }
        }
    }

    /**
     * For each value of the domain, builds the node that additionally fixes
     * attribute {@code index} to that value and assesses it when the
     * restriction is a real specialisation (see {@link #filterExamples}).
     */
    void specialiseDomain(int[] domain, Node node, int index) {
        for (int i = 0; i < domain.length; i++) {
            int val = domain[i];
            Node lower = new Node();
            lower.cover = copyCover(node.cover);
            lower.cover[index].setValue(val);
            if (filterExamples(lower, node, index, val)) {
                assess(lower);
            }
        }
    }

    /**
     * Adds the node to the heap unless an equally ranked node with the same
     * cover is already there. A full heap only accepts the node when it
     * outranks the lowest-ranked entry, which it then replaces.
     */
    void tossOntoHeap(Node node) {
        for (int i = 0; i < heap.size(); i++) {
            Node queued = (Node) heap.elementAt(i);
            if (queued.rank == node.rank && sameCover(node.cover, queued.cover)) {
                return; // duplicate
            }
        }
        if (heap.size() < heapSize) {
            heap.addElement(node);
            return;
        }
        heap = sort(heap);
        // After the ascending sort, element 0 is the lowest-ranked node.
        if (node.rank > ((Node) heap.firstElement()).rank) {
            heap.setElementAt(node, 0);
        }
    }

    /** @return true when both covers fix the same attributes to the same values */
    private boolean sameCover(CoverElement[] left, CoverElement[] right) {
        for (int j = 0; j < left.length; j++) {
            if (left[j].instantiated != right[j].instantiated) {
                return false;
            }
            if (left[j].instantiated && left[j].value != right[j].value) {
                return false;
            }
        }
        return true;
    }
} // finish class CN2

/** One attribute slot of a cover: either unconstrained or fixed to a value. */
class CoverElement {
    public boolean instantiated;
    public int value;

    CoverElement() {
        instantiated = false;
    }

    /** Fixes this slot to {@code i}. */
    void setValue(int i) {
        instantiated = true;
        value = i;
    }
}

/** A candidate cover together with its rank and the examples it covers, grouped per class. */
class Node {
    double rank;
    CoverElement[] cover;
    MultiArray[] qqs;

    Node() {
        rank = -1.0;
    }

    /** Creates a node whose cover has {@code n} unconstrained slots. */
    Node(int n) {
        rank = -1.0;
        cover = new CoverElement[n];
        for (int i = 0; i < n; i++) {
            cover[i] = new CoverElement();
        }
    }

    /**
     * Creates an unconstrained node over the given examples.
     * NOTE(review): this constructor's body was missing from the damaged
     * source; it is reconstructed from its only call site,
     * {@code new Node(numberOfAttributes, qqs)} - confirm against the original.
     */
    Node(int n, MultiArray[] q) {
        this(n);
        qqs = q;
    }

    /**
     * Orders nodes by rank: -1 when {@code other} ranks higher, 1 when it
     * ranks lower, 0 on a tie.
     * NOTE(review): only the tail of this method survived in the source; the
     * name and parameter are reconstructed - confirm against the original.
     */
    int compareTo(Node other) {
        double orank = other.rank;
        if (orank > rank) return -1;
        if (orank < rank) return 1;
        return 0;
    }

    /** @return true when every instantiated slot of the cover matches the example */
    boolean covers(int[] arr) {
        for (int i = 0; i < cover.length; i++) {
            CoverElement elt = cover[i];
            if (elt.instantiated && elt.value != arr[i]) {
                return false;
            }
        }
        return true;
    }
} // finish class Node

/** A learned rule: a cover, the class it predicts, and the per-class counts behind it. */
class Rule {
    String ruleClass;
    Vector cover;       // CoverElement per attribute, in attribute order
    int[] distribution;

    Rule(CoverElement[] cov) {
        cover = new Vector(cov.length);
        for (int i = 0; i < cov.length; i++) {
            cover.addElement(cov[i]);
        }
    }
} // finish class Rule

/** Fixed-capacity list of int[] examples; only the first {@code size} slots of {@code arr} are in use. */
class MultiArray {
    /** Creates an empty list that can hold up to {@code maxlength} examples. */
    public MultiArray(int maxlength) {
        arr = new int[maxlength][];
        size = 0;
    }

    /** Creates a full list from a Vector of int[]. */
    public MultiArray(Vector v) {
        size = v.size();
        arr = new int[size][];
        for (int i = 0; i < size; i++) {
            arr[i] = (int[]) v.elementAt(i);
        }
    }

    /** Appends an example; the caller must not exceed the capacity given at construction. */
    public void add(int[] i) {
        arr[size++] = i;
    }

    public boolean isEmpty() {
        return size == 0;
    }

    public int[][] arr;
    public int size;
}