view utils.py @ 19:2837cfeaf353

Fixed plot functions in utils.py. Started working on test_approx.py.
author Nic Cleju <nikcleju@gmail.com>
date Thu, 05 Apr 2012 13:59:22 +0300
parents 7fdf964f4edd
children d395461b92ae
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Some utility functions.

Author: Nicolae Cleju
"""

import numpy
import scipy.io
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Sample call (algorithm names and lambdas are read from the .mat file)
#utils.replot_approx('H:\\CS\\Python\\Results\\pt_std1\\approx_pt_std1.mat', dosave=True, saveplotbase='approx_pt_std1_',saveplotexts=('png','eps','pdf'))

def replot_exact(filename, algonames = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Replot exact recovery results from a .mat file, with better axis ticks
    and other custom tweaked options.

    Every algorithm's plot gets the custom delta/rho ticks. No threshold
    contours are drawn, and no axes are hidden.

    Parameters
    ----------
    filename : str
        Path to the .mat file containing the 'meanmatrix' struct.
    algonames : array-like, optional
        Algorithm names in the nested form produced by scipy.io.loadmat for
        a cell array (``entry[0][0]`` is the name string). If None, the
        file's 'algonames' variable is used.
    doshow, dosave, saveplotbase, saveplotexts :
        Forwarded unchanged to loadshowmatrices_multipixels.
    """
    mdict = scipy.io.loadmat(filename)

    # `is None`, not `== None`: for a numpy array, `== None` compares
    # elementwise and using the result in `if` can raise ValueError.
    if algonames is None:
        if 'algonames' in mdict:
            algonames = mdict['algonames']
        else:
            print("No algonames given, and couldn't find them in mat file.")
            print("Exiting.")
            return

    # Pass plain name strings so the `name in withticks` test downstream
    # compares strings rather than the nested arrays loadmat produces.
    withticks = [algonameobj[0][0] for algonameobj in algonames]
    loadshowmatrices_multipixels(filename, algonames, [], [], withticks, [], doshow, dosave, saveplotbase, saveplotexts)
      
def replot_approx(filename, algonames = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Replot approximate recovery results from a .mat file, with better axis
    ticks and other custom tweaked options.

    Three threshold contours (0.85, 0.8 and 0.5) are drawn on every plot.
    Only the 'GAP' plot gets the custom delta/rho ticks; the axes of every
    other algorithm are hidden.

    Parameters
    ----------
    filename : str
        Path to the .mat file holding 'meanmatrix' and, when saving,
        'lambdas'.
    algonames : array-like, optional
        Algorithm names in the nested form produced by scipy.io.loadmat for
        a cell array (``entry[0][0]`` is the name string). If None, the
        file's 'algosNnames' and 'algosLnames' cell arrays are stacked and
        used.
    doshow, dosave, saveplotbase, saveplotexts :
        Forwarded unchanged to loadshowmatrices_multipixels.
    """
    mdict = scipy.io.loadmat(filename)

    # `is None`, not `== None`: for a numpy array, `== None` compares
    # elementwise and using the result in `if` can raise ValueError.
    if algonames is None:
        if 'algosNnames' in mdict and 'algosLnames' in mdict:
            # NOTE(review): vstack followed by entry[0][0] indexing assumes
            # both cell arrays are stored as columns (k x 1) - confirm.
            algonames = numpy.vstack((mdict['algosNnames'], mdict['algosLnames']))
        else:
            print("No algonames given, and couldn't find them in mat file.")
            print("Exiting.")
            return

    # The plotting routine only reads lambdas when saving, to name the
    # per-lambda files. Bind the name in both cases; it used to be bound only
    # when dosave was True, which made dosave=False fail with NameError.
    if dosave:
        lambdas = mdict['lambdas']
    else:
        lambdas = mdict.get('lambdas', [])

    # Each entry: (threshold, border width in pixels, gray level in [0, 1]).
    threshs = [(0.85, 2, 0), (0.8, 2, 0.4), (0.5, 2, 1)]
    withticks = ['GAP']
    # Filter on the extracted name string, not the nested array wrapping it.
    withnoaxes = [algonameobj[0][0] for algonameobj in algonames
                  if algonameobj[0][0] not in withticks]
    loadshowmatrices_multipixels(filename, algonames, lambdas, threshs, withticks, withnoaxes, doshow, dosave, saveplotbase, saveplotexts)

def _upsample_matrix(matrix, N):
    """Return a float copy of 2-D `matrix` in which every entry fills an N x N square."""
    upsampled = numpy.repeat(numpy.asarray(matrix, dtype=float), N, axis=0)
    return numpy.repeat(upsampled, N, axis=1)


def _plot_bigmatrix(matrix, N, threshs, showticks, hideaxes):
    """Plot one result matrix, upsampled and with threshold contours, in a new figure."""
    bigmatrix = _upsample_matrix(matrix, N)
    for thrval, width, color in threshs:
        bigmatrix = int_drawseparation(matrix, bigmatrix, N, thrval, width, color)

    plt.figure()
    plt.imshow(bigmatrix, cmap=cm.gray, norm=mcolors.Normalize(0, 1), interpolation='nearest', origin='lower')

    if showticks:
        int_setticks()
    if hideaxes:
        plt.gca().get_xaxis().set_visible(False)
        plt.gca().get_yaxis().set_visible(False)


def _save_current_figure(basename, exts):
    """Save the current figure once per extension, as basename + '.' + ext."""
    for ext in exts:
        plt.savefig(basename + '.' + ext, bbox_inches='tight')


def loadshowmatrices_multipixels(filename, algonames, lambdas, threshs = None, withticks = None, withnoaxes = None, doshow=True, dosave=False, saveplotbase=None, saveplotexts=None):
    """
    Load result matrices from a .mat file and plot each one as an upsampled
    grayscale image, optionally with threshold contours.

    For every algorithm name, mdict['meanmatrix'][name][0, 0] is read:

    - A 2-D array produces one figure.
    - A 3-D array produces one figure per slice along its first axis; slice
      k is labelled with lambdas[k] when saving.
    - Arrays of any other rank are skipped.

    Each matrix entry is drawn as an N x N pixel square (N = 10), which leaves
    room to draw contour borders between cells.

    Parameters
    ----------
    filename : str
        Path to the .mat file containing the 'meanmatrix' struct.
    algonames : iterable
        Algorithm names in the nested loadmat cell-array form; each entry's
        name string is ``entry[0][0]``.
    lambdas : sequence of float
        Only used when saving 3-D results; it may be a flat list or a row or
        column vector.
    threshs : list of (thresh, width, color) tuples, optional
        Contours drawn with int_drawseparation; color is a gray level in
        [0, 1].
    withticks : container of str, optional
        Algorithms whose plots get the custom ticks from int_setticks.
    withnoaxes : container of str, optional
        Algorithms whose plots have both axes hidden.
    doshow : bool
        Call plt.show() once all figures are created.
    dosave : bool
        Save each figure as saveplotbase + name [+ '_lbd<lambda>'] + '.' + ext
        for every ext in saveplotexts.

    Raises
    ------
    ValueError
        If dosave is set but saveplotbase or saveplotexts is None.
    """
    if dosave and (saveplotbase is None or saveplotexts is None):
        print('Error: please specify name and extensions for saving')
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError('Name or extensions for saving not specified')

    # None defaults avoid shared mutable default lists.
    threshs = [] if threshs is None else threshs
    withticks = [] if withticks is None else withticks
    withnoaxes = [] if withnoaxes is None else withnoaxes

    mdict = scipy.io.loadmat(filename)
    meanmatrix = mdict['meanmatrix']

    N = 10  # one data box = NxN pixels

    for algonameobj in algonames:
        algoname = algonameobj[0][0]
        print(algoname)
        results = meanmatrix[algoname][0, 0]
        showticks = algoname in withticks
        hideaxes = algoname in withnoaxes

        if results.ndim == 2:
            _plot_bigmatrix(results, N, threshs, showticks, hideaxes)
            if dosave:
                _save_current_figure(saveplotbase + algoname, saveplotexts)

        elif results.ndim == 3:
            # Flatten lambdas so that a row (1 x k) or column (k x 1) vector,
            # as returned by loadmat, is indexed by slice position either way.
            lambdavec = numpy.ravel(lambdas) if dosave else None
            # enumerate() supplies the slice index whether or not we save; it
            # used to be initialised only when dosave was True (NameError).
            for ilbd, matrix in enumerate(results):
                _plot_bigmatrix(matrix, N, threshs, showticks, hideaxes)
                if dosave:
                    lbdtag = '_lbd%.0e' % lambdavec[ilbd]
                    _save_current_figure(saveplotbase + algoname + lbdtag, saveplotexts)

    if doshow:
        plt.show()
    print("Finished.")
    
def appendtomatfile(filename, toappend, toappendname):
  """
  Add, or overwrite, one variable in an existing .mat file.

  The whole file is loaded with scipy.io.loadmat. `toappend` is then stored
  under the name `toappendname`, and everything is written back with
  scipy.io.savemat. Errors while loading propagate to the caller. A failure
  while saving is reported on stdout and is not re-raised.

  To save a value as a MATLAB cell array, pass a numpy object array, e.g.

      obj_arr = numpy.zeros((2,), dtype=object)
      obj_arr[0] = 1
      obj_arr[1] = 'a string'
  """
  mdict = scipy.io.loadmat(filename)
  mdict[toappendname] = toappend
  try:
    scipy.io.savemat(filename, mdict)
  except Exception as err:
    # Deliberately best-effort, but unlike a bare `except:` this no longer
    # swallows KeyboardInterrupt/SystemExit, and it says what went wrong.
    print("Save error: %s" % err)
  
def int_drawseparation(matrix, bigmatrix, N, thresh, border, bordercolor):
    """
    Outline, inside `bigmatrix`, the cells of `matrix` that lie above `thresh`.

    Cell (i, j) of `matrix` is represented by the N x N square
    bigmatrix[i*N:(i+1)*N, j*N:(j+1)*N]. Consider each cell whose value is
    strictly greater than `thresh`. Every side of its square that faces an
    existing neighbour cell with a value strictly less than `thresh` is
    painted with a stripe `border` pixels wide, set to `bordercolor`.

    Cells equal to `thresh` neither receive nor cause a border. The outer
    edges of the grid are never painted.

    `bigmatrix` is modified in place and is also returned.
    """
    nrows, ncols = matrix.shape
    for r in range(nrows):
        for c in range(ncols):
            if not matrix[r, c] > thresh:
                continue

            r0, c0 = r * N, c * N      # first pixel row / column of the square
            r1, c1 = r0 + N, c0 + N    # one past its last pixel row / column

            # Side facing row r-1 (bounds checked before the neighbour is read)
            if r > 0 and matrix[r - 1, c] < thresh:
                bigmatrix[r0:r0 + border, c0:c1] = bordercolor
            # Side facing row r+1
            if r < nrows - 1 and matrix[r + 1, c] < thresh:
                bigmatrix[r1 - border:r1, c0:c1] = bordercolor
            # Side facing column c-1
            if c > 0 and matrix[r, c - 1] < thresh:
                bigmatrix[r0:r1, c0:c0 + border] = bordercolor
            # Side facing column c+1
            if c < ncols - 1 and matrix[r, c + 1] < thresh:
                bigmatrix[r0:r1, c1 - border:c1] = bordercolor

    return bigmatrix

  
def int_setticks():
    """
    Decorate the current matplotlib axes with fixed delta/rho ticks.

    Both axes get two ticks, at positions 10 and 179, labelled "0.05" and
    "0.95", with tick-label font size 42. The x axis is labelled delta and
    the y axis rho, both at size 60.
    """
    # NOTE(review): 10 and 179 are image-pixel positions; they presumably
    # match the 0.05 / 0.95 grid points of the upsampled result images.
    # Confirm if the grid size changes.
    tick_positions = [10, 179]
    tick_texts = ["0.05", "0.95"]
    axes = plt.gca()

    for place_ticks, name_ticks in ((axes.set_xticks, axes.set_xticklabels),
                                    (axes.set_yticks, axes.set_yticklabels)):
        place_ticks(tick_positions)
        name_ticks(tick_texts)

    for get_labels in (axes.get_xticklabels, axes.get_yticklabels):
        for tick_label in get_labels():
            tick_label.set_fontsize(42)

    axes.set_xlabel(r'$\delta$', size=60)
    axes.set_ylabel(r'$\rho$', size=60)