diff utils.py @ 19:2837cfeaf353

Fixed plot functions in utils.py Started working for test_approx.py
author Nic Cleju <nikcleju@gmail.com>
date Thu, 05 Apr 2012 13:59:22 +0300
parents 7fdf964f4edd
children d395461b92ae
line wrap: on
line diff
--- a/utils.py	Thu Apr 05 11:01:22 2012 +0300
+++ b/utils.py	Thu Apr 05 13:59:22 2012 +0300
@@ -14,25 +14,58 @@
 # Sample call
 #utils.loadshowmatrices_multipixels('H:\\CS\\Python\\Results\\pt_std1\\approx_pt_std1.mat', dosave=True, saveplotbase='approx_pt_std1_',saveplotexts=('png','eps','pdf'))
 
-def loadshowmatrices_multipixels(filename, algonames = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
+def replot_exact(filename, algonames = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
+    """
+    Replot exact recovery results from mat file, with better axis ticks and 
+    other custom tweaked options.
+    """
+
+    mdict = scipy.io.loadmat(filename)
+
+    if algonames == None:
+      if 'algonames' in mdict:
+        algonames = mdict['algonames']
+      else:
+        print "No algonames given, and couldn't find them in mat file."
+        print "Exiting."
+        return
+ 
+    loadshowmatrices_multipixels(filename, algonames, [], [], algonames, [], doshow, dosave, saveplotbase, saveplotexts)
+      
+def replot_approx(filename, algonames = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
+    """
+    Replot exact recovery results from mat file, with better axis ticks and 
+    other custom tweaked options.
+    """
+
+    mdict = scipy.io.loadmat(filename)
+
+    if algonames == None:
+      if 'algosNnames' in mdict and 'algosLnames' in mdict:
+        algonames = numpy.vstack((mdict['algosNnames'], mdict['algosLnames']))
+      else:
+        print "No algonames given, and couldn't find them in mat file."
+        print "Exiting."
+        return
+        
+    if dosave:
+      lambdas = mdict['lambdas']
+    threshs = [(0.85,2,0),(0.8,2,0.4),(0.5,2,1)]
+    withticks = ['GAP']
+    withnoaxes = [algoname[0][0] for algoname in algonames if algoname not in withticks]
+    #withnoaxes.remove('GAP')
+    loadshowmatrices_multipixels(filename, algonames, lambdas, threshs, withticks, withnoaxes, doshow, dosave, saveplotbase, saveplotexts)
+
+def loadshowmatrices_multipixels(filename, algonames, lambdas, threshs = [], withticks = [], withnoaxes = [], doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
   
     if dosave and (saveplotbase is None or saveplotexts is None):
       print('Error: please specify name and extensions for saving')
       raise Exception('Name or extensions for saving not specified')
       
     mdict = scipy.io.loadmat(filename)
-    
-    if dosave:
-      lambdas = mdict['lambdas']
-      
-    if algonames == None:
-      algonames = mdict['algonames']
-    
-#    thresh = 0.90
-    N = 10
-#    border = 2
-#    bordercolor = 0 # black
-    
+
+    N = 10  # one data box = NxN
+
     for algonameobj in algonames:
         algoname = algonameobj[0][0]
         print algoname
@@ -44,10 +77,14 @@
             for i in numpy.arange(rows):
               for j in numpy.arange(cols):
                 bigmatrix[i*N:i*N+N,j*N:j*N+N] = mdict['meanmatrix'][algoname][0,0][i,j]
-            bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.95,2,0)
-            #bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.9, 2,0.2)
-            bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.8, 2,0.4)
-            bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.5, 2,1)
+            
+            for thrval,width,color in threshs:
+              bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,N,thrval,width,color)
+
+            #bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.95,2,0)
+            #bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.8, 2,0.4)
+            #bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0],bigmatrix,10,0.5, 2,1)
+            
 #                # Mark 95% border
 #                if mdict['meanmatrix'][algoname][0,0][i,j] > thresh:
 #                  # Top border
@@ -64,11 +101,11 @@
 #                    bigmatrix[i*N:i*N+N,j*N+N-border:j*N+N] = bordercolor
                     
             plt.figure()
-            #plt.imshow(mdict['meanmatrix'][algoname][0,0], cmap=cm.gray, interpolation='nearest',origin='lower')            
             plt.imshow(bigmatrix, cmap=cm.gray, norm=mcolors.Normalize(0,1), interpolation='nearest',origin='lower')
-            if algoname == 'GAP':
+            
+            if algoname in withticks:
               int_setticks()
-            else:
+            if algoname in withnoaxes:
               plt.gca().get_xaxis().set_visible(False)
               plt.gca().get_yaxis().set_visible(False)
               
@@ -87,10 +124,13 @@
                 for i in numpy.arange(rows):
                   for j in numpy.arange(cols):
                     bigmatrix[i*N:i*N+N,j*N:j*N+N] = matrix[i,j]
-                bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.95,2,0)
-                #bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.9, 2,0.2)
-                bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.8, 2,0.4)
-                bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.5, 2,1)
+
+                for thrval,width,color in threshs:
+                  bigmatrix = int_drawseparation(mdict['meanmatrix'][algoname][0,0][ilbd],bigmatrix,N,thrval,width,color)
+
+                #bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.95,2,0)
+                #bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.8, 2,0.4)
+                #bigmatrix = int_drawseparation(matrix,bigmatrix,10,0.5, 2,1)
 #                    # Mark 95% border
 #                    if matrix[i,j] > thresh:
 #                      # Top border
@@ -109,13 +149,21 @@
                 plt.figure()
                 #plt.imshow(matrix, cmap=cm.gray, interpolation='nearest',origin='lower')
                 plt.imshow(bigmatrix, cmap=cm.gray, norm=mcolors.Normalize(0,1), interpolation='nearest',origin='lower')
-                plt.gca().get_xaxis().set_visible(False)
-                plt.gca().get_yaxis().set_visible(False)
+
+                #plt.gca().get_xaxis().set_visible(False)
+                #plt.gca().get_yaxis().set_visible(False)
                 #int_setticks()
+                if algoname in withticks:
+                  int_setticks()
+                if algoname in withnoaxes:
+                  plt.gca().get_xaxis().set_visible(False)
+                  plt.gca().get_yaxis().set_visible(False)                
+                
                 if dosave:
                   for ext in saveplotexts:
                     plt.savefig(saveplotbase + algoname + ('_lbd%.0e' % lambdas[ilbd]) + '.' + ext, bbox_inches='tight')
-                  ilbd = ilbd + 1                
+                
+                ilbd = ilbd + 1                
     if doshow:
       plt.show()
     print "Finished."    
@@ -154,6 +202,7 @@
           bigmatrix[i*N:i*N+N,j*N+N-border:j*N+N] = bordercolor  
   
   return bigmatrix
+
   
 def int_setticks():