annotate utils.py @ 19:2837cfeaf353

Fixed plot functions in utils.py Started working for test_approx.py
author Nic Cleju <nikcleju@gmail.com>
date Thu, 05 Apr 2012 13:59:22 +0300
parents 7fdf964f4edd
children d395461b92ae
rev   line source
nikcleju@0 1 # -*- coding: utf-8 -*-
nikcleju@0 2 """
nikcleju@17 3 Some utility functions.
nikcleju@0 4
nikcleju@17 5 Author: Nicolae Cleju
nikcleju@0 6 """
nikcleju@0 7
nikcleju@0 8 import numpy
nikcleju@0 9 import scipy.io
nikcleju@0 10 import matplotlib.pyplot as plt
nikcleju@0 11 import matplotlib.cm as cm
nikcleju@0 12 import matplotlib.colors as mcolors
nikcleju@0 13
nikcleju@0 14 # Sample call
nikcleju@0 15 #utils.loadshowmatrices_multipixels('H:\\CS\\Python\\Results\\pt_std1\\approx_pt_std1.mat', dosave=True, saveplotbase='approx_pt_std1_',saveplotexts=('png','eps','pdf'))
nikcleju@0 16
def replot_exact(filename, algonames=None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Replot exact recovery results from a .mat file, with better axis ticks and
    other custom tweaked options.

    Parameters
    ----------
    filename : str
        Path of the .mat file holding the 'meanmatrix' results and, unless
        `algonames` is given, an 'algonames' variable.
    algonames : array-like, optional
        Algorithm name objects as returned by scipy.io.loadmat; the plain name
        of each entry is entry[0][0]. If None, they are read from the file.
    doshow, dosave, saveplotbase, saveplotexts
        Forwarded unchanged to loadshowmatrices_multipixels.

    Every algorithm is plotted with the custom ticks, and no separation lines
    are drawn. If no names are given or found, a message is printed and the
    function returns None without plotting.
    """
    mdict = scipy.io.loadmat(filename)

    # `is None`: `== None` compares elementwise when algonames is a numpy array.
    if algonames is None:
        if 'algonames' in mdict:
            algonames = mdict['algonames']
        else:
            print("No algonames given, and couldn't find them in mat file.")
            print("Exiting.")
            return

    # Plain-string names, so the `algoname in withticks` test downstream
    # compares strings rather than nested numpy object arrays.
    withticks = [algonameobj[0][0] for algonameobj in algonames]

    loadshowmatrices_multipixels(filename, algonames, [], [], withticks, [], doshow, dosave, saveplotbase, saveplotexts)
nikcleju@19 34
def replot_approx(filename, algonames=None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Replot approximate recovery results from a .mat file, with separation
    lines, better axis ticks and other custom tweaked options.

    Parameters
    ----------
    filename : str
        Path of the .mat file holding the 'meanmatrix' results, 'lambdas' and,
        unless `algonames` is given, 'algosNnames' and 'algosLnames'.
    algonames : array-like, optional
        Algorithm name objects as returned by scipy.io.loadmat; the plain name
        of each entry is entry[0][0]. If None, 'algosNnames' and 'algosLnames'
        from the file are stacked vertically.
    doshow, dosave, saveplotbase, saveplotexts
        Forwarded unchanged to loadshowmatrices_multipixels.

    Only 'GAP' gets custom ticks, and all other algorithms are drawn without
    axes. If no names are given or found, a message is printed and the
    function returns None without plotting.
    """
    mdict = scipy.io.loadmat(filename)

    # `is None`: `== None` compares elementwise when algonames is a numpy array.
    if algonames is None:
        if 'algosNnames' in mdict and 'algosLnames' in mdict:
            algonames = numpy.vstack((mdict['algosNnames'], mdict['algosLnames']))
        else:
            print("No algonames given, and couldn't find them in mat file.")
            print("Exiting.")
            return

    # lambdas is passed on in every case; previously it was bound only when
    # saving, which caused a NameError for dosave=False. It is only needed to
    # build file names, so it is mandatory only when saving. ravel(): loadmat
    # returns vectors as 2-D arrays, but one value per index is used downstream.
    if dosave:
        lambdas = numpy.ravel(mdict['lambdas'])
    else:
        lambdas = numpy.ravel(mdict.get('lambdas', []))

    # (threshold, border width in pixels, gray level) of each separation line
    threshs = [(0.85, 2, 0), (0.8, 2, 0.4), (0.5, 2, 1)]
    withticks = ['GAP']
    # Compare plain-string names, not the raw loadmat object arrays.
    names = [algonameobj[0][0] for algonameobj in algonames]
    withnoaxes = [name for name in names if name not in withticks]

    loadshowmatrices_multipixels(filename, algonames, lambdas, threshs, withticks, withnoaxes, doshow, dosave, saveplotbase, saveplotexts)
nikcleju@19 58
def loadshowmatrices_multipixels(filename, algonames, lambdas, threshs=None, withticks=None, withnoaxes=None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Plot the 'meanmatrix' results of a .mat file as enlarged gray images.

    Each matrix entry is drawn as an N x N pixel box (N = 10), optional
    separation lines are drawn, and each figure is optionally saved.

    Parameters
    ----------
    filename : str
        .mat file containing a 'meanmatrix' struct with one field per
        algorithm. A 2-D field gives one figure; a 3-D field gives one figure
        per slice along its first axis (one per lambda).
    algonames : iterable
        Algorithm name objects as returned by scipy.io.loadmat; the plain name
        of each entry is entry[0][0].
    lambdas : sequence of float
        Lambda values, used only to build file names of saved 3-D results.
        The values are flattened, so a loadmat row/column vector also works.
    threshs : list of (thresh, width, color) tuples, optional
        Separation lines to draw (see int_drawseparation).
    withticks : list of str, optional
        Algorithm names whose plots get the custom ticks of int_setticks.
    withnoaxes : list of str, optional
        Algorithm names whose plots have both axes hidden.
    doshow : bool
        Call plt.show() once all figures are created.
    dosave : bool
        Save each figure as saveplotbase + name [+ '_lbd%.0e'] + '.' + ext,
        for every ext in saveplotexts.

    Raises
    ------
    Exception
        If dosave is True but saveplotbase or saveplotexts is missing.
    """
    if dosave and (saveplotbase is None or saveplotexts is None):
        print('Error: please specify name and extensions for saving')
        raise Exception('Name or extensions for saving not specified')

    # None defaults instead of shared mutable [] defaults.
    threshs = [] if threshs is None else threshs
    withticks = [] if withticks is None else withticks
    withnoaxes = [] if withnoaxes is None else withnoaxes
    # loadmat returns vectors as 2-D arrays; one scalar per slice is needed.
    lambdas = numpy.ravel(lambdas)

    mdict = scipy.io.loadmat(filename)

    N = 10  # one data box = NxN pixels

    for algonameobj in algonames:
        algoname = algonameobj[0][0]
        print(algoname)
        data = mdict['meanmatrix'][algoname][0, 0]

        if data.ndim == 2:
            int_plotbigmatrix(data, N, threshs, algoname, withticks, withnoaxes)
            if dosave:
                for ext in saveplotexts:
                    plt.savefig(saveplotbase + algoname + '.' + ext, bbox_inches='tight')
        elif data.ndim == 3:
            # One figure per slice; ilbd indexes the matching lambda.
            # (ilbd used to be defined only when saving, but it was also used
            # for the separation lines.)
            for ilbd, matrix in enumerate(data):
                int_plotbigmatrix(matrix, N, threshs, algoname, withticks, withnoaxes)
                if dosave:
                    for ext in saveplotexts:
                        plt.savefig(saveplotbase + algoname + ('_lbd%.0e' % lambdas[ilbd]) + '.' + ext, bbox_inches='tight')

    if doshow:
        plt.show()
    print("Finished.")


def int_plotbigmatrix(matrix, N, threshs, algoname, withticks, withnoaxes):
    """
    Draw one 2-D results matrix, enlarged N times, in a new figure.

    Separation lines are added for each (thresh, width, color) in `threshs`.
    Ticks are customized if `algoname` is in `withticks`, and axes are hidden
    if it is in `withnoaxes`.
    """
    # Each entry becomes an N x N box of constant value (float result).
    bigmatrix = numpy.kron(matrix, numpy.ones((N, N)))

    for thrval, width, color in threshs:
        bigmatrix = int_drawseparation(matrix, bigmatrix, N, thrval, width, color)

    plt.figure()
    plt.imshow(bigmatrix, cmap=cm.gray, norm=mcolors.Normalize(0, 1), interpolation='nearest', origin='lower')

    if algoname in withticks:
        int_setticks()
    if algoname in withnoaxes:
        plt.gca().get_xaxis().set_visible(False)
        plt.gca().get_yaxis().set_visible(False)
nikcleju@0 170
def appendtomatfile(filename, toappend, toappendname):
    """
    Add (or overwrite) one variable in an existing .mat file.

    The file is loaded with scipy.io.loadmat, `toappend` is stored under the
    name `toappendname`, and all variables are written back with
    scipy.io.savemat. A failure while saving is reported on stdout, not
    raised; a failure while loading propagates.

    To save to a MATLAB cell array, pass a numpy object array, e.g.:

        obj_arr = numpy.zeros((2,), dtype=object)
        obj_arr[0] = 1
        obj_arr[1] = 'a string'
    """
    # Skip loadmat's '__...__' metadata entries (e.g. '__header__'); they are
    # not MATLAB variables and should not be written back as such.
    mdict = dict((key, value) for key, value in scipy.io.loadmat(filename).items()
                 if not key.startswith('__'))
    mdict[toappendname] = toappend
    try:
        scipy.io.savemat(filename, mdict)
    except Exception as err:
        # Best-effort save: report the reason, but unlike a bare `except:`
        # do not swallow KeyboardInterrupt / SystemExit.
        print("Save error: %s" % (err,))
nikcleju@0 183
def int_drawseparation(matrix, bigmatrix, N, thresh, border, bordercolor):
    """
    Paint the boundary of the region where `matrix` exceeds `thresh`.

    `bigmatrix` is the N-times enlarged image of `matrix`: entry [i, j]
    occupies bigmatrix[i*N:(i+1)*N, j*N:(j+1)*N]. For every cell with
    value > thresh, each side facing a neighbouring cell with value <= thresh
    gets a band of `bordercolor` that is `border` pixels thick, painted on the
    inside of the cell. Sides on the outer edge of the matrix are not painted.

    Parameters
    ----------
    matrix : 2-D array
        Original (small) matrix of values.
    bigmatrix : 2-D float array
        Enlarged image of `matrix`; modified in place.
    N : int
        Enlargement factor (pixels per cell side).
    thresh : float
        Threshold defining the region.
    border : int
        Band thickness in pixels.
    bordercolor : float
        Value written into the band.

    Returns
    -------
    bigmatrix, the same object that was passed in.
    """
    rows, cols = matrix.shape
    for i in range(rows):
        for j in range(cols):
            if not matrix[i, j] > thresh:
                continue
            r0, c0 = i * N, j * N  # first pixel row / column of this cell
            # The neighbour test is `<= thresh` so that it exactly complements
            # `> thresh`; with `< thresh`, a boundary against a neighbour equal
            # to the threshold was never drawn. The index guards come first so
            # that index -1 never wraps around to the last row/column.
            # Side toward row i-1
            if i > 0 and matrix[i - 1, j] <= thresh:
                bigmatrix[r0:r0 + border, c0:c0 + N] = bordercolor
            # Side toward row i+1
            if i < rows - 1 and matrix[i + 1, j] <= thresh:
                bigmatrix[r0 + N - border:r0 + N, c0:c0 + N] = bordercolor
            # Side toward column j-1
            if j > 0 and matrix[i, j - 1] <= thresh:
                bigmatrix[r0:r0 + N, c0:c0 + border] = bordercolor
            # Side toward column j+1
            if j < cols - 1 and matrix[i, j + 1] <= thresh:
                bigmatrix[r0:r0 + N, c0 + N - border:c0 + N] = bordercolor

    return bigmatrix
nikcleju@19 205
nikcleju@7 206
def int_setticks():
    """
    Put the fixed "0.05" / "0.95" ticks on both axes of the current plot.

    The tick positions (pixels 10 and 179) are hard-coded for the enlarged
    result images. Tick labels use font size 42, and the x / y axis labels are
    the LaTeX symbols delta and rho at size 60.
    """
    positions = [10, 179]
    labels = ["0.05", "0.95"]

    axes = plt.gca()
    # Same call order as setting x ticks, x labels, y ticks, y labels.
    for set_positions, set_labels in ((axes.set_xticks, axes.set_xticklabels),
                                      (axes.set_yticks, axes.set_yticklabels)):
        set_positions(positions)
        set_labels(labels)

    for get_labels in (axes.get_xticklabels, axes.get_yticklabels):
        for text in get_labels():
            text.set_fontsize(42)

    axes.set_xlabel(r'$\delta$', size=60)
    axes.set_ylabel(r'$\rho$', size=60)
nikcleju@8 227