# Compute the weight vector of linear SVM based on the model file
# Original Perl Author: Thorsten Joachims (thorsten@joachims.org)
# Python Version: Ori Cohen (orioric@gmail.com)
# Call: python svm2weights.py svm_model
#
# Output: one "feature_id : weight" line per feature, sorted by feature id.
# Exit status: 0 on success, 1 if the model file cannot be opened or is not a
# linear-kernel model in the expected layout.

import sys
from operator import itemgetter

try:
    import psyco
    psyco.full()
except ImportError:
    # Single-argument print(...) behaves the same as the print statement on
    # Python 2 and is also valid on Python 3.
    print('Psyco not installed, the program will just run slower')

# Zero-based indices of the header lines this script inspects.  Every line
# after THRESHOLD_LINE is treated as a support vector.
KERNEL_TYPE_LINE = 1
THRESHOLD_LINE = 10


class ModelFormatError(ValueError):
    """Raised when the model is not linear or its header is not as expected."""


def sortbyvalue(d, reverse=True):
    """Return the (key, value) pairs of ``d`` as a list sorted by value.

    Proposed in PEP 265, using itemgetter.

    ``reverse=True`` (the default) gives the largest value first;
    ``reverse=False`` gives the smallest value first.
    """
    return sorted(d.items(), key=itemgetter(1), reverse=reverse)


def sortbykey(d, reverse=True):
    """Return the (key, value) pairs of ``d`` as a list sorted by ascending key.

    Proposed in PEP 265, using itemgetter.

    NOTE: ``reverse`` is accepted but ignored.  This function has always
    returned keys in ascending order regardless of the argument (and despite
    its default of True); that behaviour is kept so existing callers and the
    script's output order do not change.
    """
    return sorted(d.items(), key=itemgetter(0))


def get_file():
    """Open the model file named on the command line and return it.

    If no command-line argument is given, the name ``svm_model`` (the default
    svmLight output) is assumed.  If the file cannot be opened, an error
    message is printed and the program exits with status 1.
    """
    if len(sys.argv) < 2:
        # assume file to be svm_model (default svmLight output)
        print("Assuming file as svm_model")
        filename = 'svm_model'
    else:
        filename = sys.argv[1]

    try:
        return open(filename, "r")
    except IOError:
        print("Error: The file '%s' was not found on this system." % filename)
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)


def compute_weights(lines):
    """Accumulate the weight vector from the lines of a model file.

    ``lines`` is the sequence of text lines of the model.  The line at index
    KERNEL_TYPE_LINE must hold the value ``0`` before any ``#`` comment, and
    the line at index THRESHOLD_LINE must contain the text ``threshold b``.
    Every later line is read as ``alpha feature:value ... [# comment]`` and
    contributes ``alpha * value`` to the weight of each listed feature.

    Returns a dict mapping integer feature id to its float weight.  A model
    with no lines after the header yields an empty dict.

    Raises ModelFormatError if the kernel-type or threshold line does not
    match.  A malformed number or ``feature:value`` token raises a plain
    ValueError from the failed conversion.
    """
    weights = {}
    for index, line in enumerate(lines):
        if index == KERNEL_TYPE_LINE:
            # Compare the actual value rather than searching the whole line
            # for the character '0', which could also match the comment.
            if line.partition('#')[0].strip() != '0':
                raise ModelFormatError('Not linear Kernel!\n')
        elif index == THRESHOLD_LINE:
            if 'threshold b' not in line:
                raise ModelFormatError("Parsing error!\n")
        elif index > THRESHOLD_LINE:
            # Drop the trailing comment (if any) and split on any run of
            # whitespace, so a missing '#', repeated spaces or a blank line
            # do not corrupt or crash the parse.
            tokens = line.partition('#')[0].split()
            if not tokens:
                continue
            alpha = float(tokens[0])
            for pair in tokens[1:]:
                feature, value = pair.split(':')
                feature = int(feature)
                weights[feature] = weights.get(feature, 0) + alpha * float(value)
    return weights


def main():
    """Read the model file, print ``feature : weight`` lines, return exit status."""
    with get_file() as model_file:
        lines = model_file.readlines()

    try:
        weights = compute_weights(lines)
    except ModelFormatError as err:
        # The message already ends with a newline, so this prints a trailing
        # blank line after it.
        print(err)
        return 1

    # if you need to sort the features by value and not by feature ID then
    # use sortbyvalue(weights) instead:
    for feature, weight in sortbykey(weights):
        print("%s : %s" % (feature, weight))
    return 0


if __name__ == "__main__":
    sys.exit(main())