diff core/tools/machine_learning/svmlight2weight.py @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/core/tools/machine_learning/svmlight2weight.py	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,93 @@
+# Compute the weight vector of linear SVM based on the model file
+# Original Perl Author: Thorsten Joachims (thorsten@joachims.org)
+# Python Version: Ori Cohen (orioric@gmail.com)
+# Call: python svm2weights.py svm_model
+
+import sys
+from operator import itemgetter
+
+try:
+    import psyco
+    psyco.full()
+except ImportError:
+    print 'Psyco not installed, the program will just run slower'
+
+def sortbyvalue(d,reverse=True):
+    ''' proposed in PEP 265, using  the itemgetter this function sorts a dictionary'''
+    return sorted(d.iteritems(), key=itemgetter(1), reverse=True)
+
+def sortbykey(d,reverse=True):
+    ''' proposed in PEP 265, using  the itemgetter this function sorts a dictionary'''
+    return sorted(d.iteritems(), key=itemgetter(0), reverse=False)
+
+def get_file():
+    """
+    Tries to extract a filename from the command line.  If none is present, it
+    assumes file to be svm_model (default svmLight output).  If the file 
+    exists, it returns it, otherwise it prints an error message and ends
+    execution. 
+    """
+    # Get the name of the data file and load it into 
+    if len(sys.argv) < 2:
+        # assume file to be svm_model (default svmLight output)
+        print "Assuming file as svm_model"
+        filename = 'svm_model' 
+        #filename = sys.stdin.readline().strip()
+    else:
+        filename = sys.argv[1]
+
+    
+    try:
+        f = open(filename, "r")
+    except IOError:
+        print "Error: The file '%s' was not found on this system." % filename
+        sys.exit(0)
+
+    return f
+
+
+
+
+if __name__ == "__main__":
+    f = get_file()
+    i=0
+    lines = f.readlines()
+    printOutput = True
+    w = {}
+    for line in lines:
+        if i>10:
+            features = line[:line.find('#')-1]
+            comments = line[line.find('#'):]
+            alpha = features[:features.find(' ')]
+            feat = features[features.find(' ')+1:]
+            for p in feat.split(' '): # Changed the code here. 
+                a,v = p.split(':')
+                if not (int(a) in w):
+                    w[int(a)] = 0
+            for p in feat.split(' '): 
+                a,v = p.split(':')
+                w[int(a)] +=float(alpha)*float(v)
+        elif i==1:
+            if line.find('0')==-1:
+                print 'Not linear Kernel!\n'
+                printOutput = False
+                break
+        elif i==10:
+            if line.find('threshold b')==-1:
+                print "Parsing error!\n"
+                printOutput = False
+                break
+        
+        i+=1    
+    f.close()
+
+    #if you need to sort the features by value and not by feature ID then use this line intead:
+    #ws = sortbyvalue(w) 
+    
+    ws = sortbykey(w)
+    if printOutput == True:
+        for (i,j) in ws:
+            print i,':',j
+            i+=1
+
+