comparison core/tools/machine_learning/svmlight2weight.py @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 # Compute the weight vector of linear SVM based on the model file
2 # Original Perl Author: Thorsten Joachims (thorsten@joachims.org)
3 # Python Version: Ori Cohen (orioric@gmail.com)
4 # Call: python svmlight2weight.py svm_model
5
6 import sys
7 from operator import itemgetter
8
# Optional speed-up: enable Psyco when it is importable. The script works
# without it, only slower, so a missing module is reported and ignored.
try:
    import psyco
    psyco.full()
except ImportError:
    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print('Psyco not installed, the program will just run slower')
14
def sortbyvalue(d, reverse=True):
    """Return the (key, value) pairs of ``d`` as a list sorted by value.

    Uses ``itemgetter`` as proposed in PEP 265.

    d       -- dictionary to sort.
    reverse -- when true (the default) the largest value comes first;
               pass False to get ascending order instead.
    """
    # items() rather than iteritems() so the function runs on Python 2 and 3.
    return sorted(d.items(), key=itemgetter(1), reverse=reverse)
18
def sortbykey(d, reverse=True):
    """Return the (key, value) pairs of ``d`` as a list sorted by ascending key.

    Uses ``itemgetter`` as proposed in PEP 265.

    d       -- dictionary to sort.
    reverse -- accepted but ignored: the result is always in ascending key
               order. The default value is True, so honouring it would flip
               the order for every caller that relies on the default.
    """
    # items() rather than iteritems() so the function runs on Python 2 and 3.
    return sorted(d.items(), key=itemgetter(0))
22
def get_file():
    """
    Open the model file named on the command line and return it.

    The filename is taken from the first command-line argument. If none is
    present, the file is assumed to be 'svm_model' (default svmLight output).

    Returns the file object opened for reading. If the file cannot be
    opened, an error message is printed and execution ends with exit
    status 1.
    """
    if len(sys.argv) < 2:
        # assume file to be svm_model (default svmLight output)
        print("Assuming file as svm_model")
        filename = 'svm_model'
    else:
        filename = sys.argv[1]

    try:
        return open(filename, "r")
    except IOError:
        print("Error: The file '%s' was not found on this system." % filename)
        # Non-zero status so a calling script can tell that the run failed.
        sys.exit(1)
47
48
49
50
# 0-based positions of the header lines this script inspects.
_KERNEL_LINE = 1       # "<kernel type> # kernel type"
_THRESHOLD_LINE = 10   # "<b> # threshold b, ..."; every later line is a SV


def _accumulate_weights(sv_lines):
    """Return {feature id: weight} summed over the support-vector lines.

    Each line has the form "<alpha> <id>:<value> <id>:<value> ... #"; every
    pair adds alpha * value to the weight of feature <id>.
    """
    weights = {}
    for line in sv_lines:
        # Drop the trailing '#' comment. partition() also copes with a line
        # that has no '#', and split() with repeated or trailing blanks.
        tokens = line.partition('#')[0].split()
        if not tokens:
            # Blank line (e.g. at the end of the file): nothing to add.
            continue
        alpha = float(tokens[0])
        for pair in tokens[1:]:
            feature_id, value = pair.split(':')
            feature_id = int(feature_id)
            weights[feature_id] = weights.get(feature_id, 0) + alpha * float(value)
    return weights


def _weights_from_model(lines):
    """Return (weights, error) for the lines of a model file.

    On success ``error`` is None and ``weights`` maps feature id to weight.
    If the kernel-type field is not 0, or the threshold line is not where
    it is expected, ``weights`` is None and ``error`` is the message to
    print.
    """
    if len(lines) > _KERNEL_LINE:
        # Compare the field before the comment, not any '0' in the line.
        kernel_type = lines[_KERNEL_LINE].partition('#')[0].strip()
        if kernel_type != '0':
            return None, 'Not linear Kernel!\n'
    if len(lines) > _THRESHOLD_LINE and 'threshold b' not in lines[_THRESHOLD_LINE]:
        return None, "Parsing error!\n"
    return _accumulate_weights(lines[_THRESHOLD_LINE + 1:]), None


if __name__ == "__main__":
    f = get_file()
    try:
        lines = f.readlines()
    finally:
        f.close()

    w, error = _weights_from_model(lines)
    if error is not None:
        print(error)
    else:
        # If you need to sort the features by value and not by feature ID,
        # use sortbyvalue(w) here instead.
        for feature_id, weight in sortbykey(w):
            print("%s : %s" % (feature_id, weight))
92
93