Maria@4
|
1 # -*- coding: utf-8 -*-
|
Maria@4
|
2 """
|
Maria@4
|
3 Created on Wed May 17 11:35:51 2017
|
Maria@4
|
4
|
Maria@4
|
5 @author: mariapanteli
|
Maria@4
|
6 """
|
Maria@4
|
7 import numpy as np
|
Maria@4
|
8 import json
|
Maria@4
|
9 import pysal # before shapely in util_plots
|
Maria@4
|
10 import fiona
|
m@8
|
11 import os
|
Maria@4
|
12 import matplotlib.pyplot as plt
|
Maria@4
|
13
|
Maria@4
|
14
|
m@8
|
# Folder next to this module that holds the reference data files.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'util_data')
# JSON list of country records; the 'name', 'cca3' and 'borders' fields are
# read by neighbors_from_json_file.
JSON_DB = os.path.join(DATA_DIR, 'countries.json')
# Country shapefile used for the k-nearest-neighbour fallback.
SHAPEFILE = os.path.join(
    DATA_DIR, 'shapefiles', 'ne_10m_admin_0_countries.shp')
|
m@8
|
18
|
m@8
|
19
|
m@8
|
def neighbors_from_json_file(data_countries, json_DB=JSON_DB):
    """Build spatial weights from the land borders listed in a JSON file.

    Parameters
    ----------
    data_countries : sequence of str
        Country names of the dataset. The position of a name in this
        sequence is the id used for that country in the weights object.
    json_DB : str
        Path to a JSON file holding a list of country records. Each record
        is read through ``['name']['common']``, ``['cca3']`` and
        ``['borders']`` (a list of ``cca3`` codes).

    Returns
    -------
    pysal.weights.W
        Weights whose ids are ``0 .. len(data_countries) - 1``. Countries
        that are missing from the JSON file, or whose neighbours are not in
        ``data_countries``, end up without neighbours.
    """
    # Names that the JSON file spells differently from the dataset.
    name_fixes = {
        'United States': 'United States of America',
        'Tanzania': 'United Republic of Tanzania',
        'DR Congo': 'Democratic Republic of the Congo',
        'Czechia': 'Czech Republic',
    }

    with open(json_DB) as json_file:
        countries_info = json.load(json_file)

    # Lookup tables replace the repeated list.index() scans; setdefault keeps
    # the first occurrence, which is what list.index() returned.
    borders_by_name = {}
    name_by_iso = {}
    for country_info in countries_info:
        common_name = country_info['name']['common']
        name = name_fixes.get(common_name, common_name)
        borders_by_name.setdefault(name, country_info.get('borders') or [])
        name_by_iso.setdefault(country_info['cca3'], name)

    # Works for lists as well as numpy arrays of names.
    data_index = {}
    for i, country in enumerate(data_countries):
        data_index.setdefault(country, i)

    neighbors = {}
    for i, country in enumerate(data_countries):
        # Each entry is a dict keyed by neighbour id so that
        # replace_empty_neighbours_with_KNN can later add entries by key.
        neighbors[i] = {}
        for neighbor_iso in borders_by_name.get(country, ()):
            # Border codes without a matching record are skipped.
            neighbor_idx = data_index.get(name_by_iso.get(neighbor_iso))
            if neighbor_idx is not None:
                neighbors[i][neighbor_idx] = 1.0

    return pysal.weights.W(neighbors,
                           id_order=list(range(len(data_countries))))
|
Maria@4
|
49
|
Maria@4
|
50
|
Maria@4
|
def get_countries_from_shapefile(shapefile):
    """Return the country name of every record in a shapefile, in file order.

    The name is read from the ``"ADMIN"`` property when the first record has
    it, otherwise from ``"NAME"``, otherwise from ``"admin"``.

    Parameters
    ----------
    shapefile : str
        Path handed to ``fiona.open``.

    Returns
    -------
    list
        One entry per record; empty when the shapefile has no records.
    """
    countries = []
    country_keyword = None
    # The context manager closes the collection even when reading fails.
    with fiona.open(shapefile, 'r') as shp:
        for record in shp:
            properties = record["properties"]
            if country_keyword is None:
                # Decide on the property name once, from the first record.
                if "ADMIN" in properties:
                    country_keyword = "ADMIN"
                elif "NAME" in properties:
                    country_keyword = "NAME"
                else:
                    country_keyword = "admin"
            countries.append(properties[country_keyword])
    return countries
|
Maria@4
|
64
|
Maria@4
|
65
|
Maria@4
|
def replace_empty_neighbours_with_KNN(data_countries, w, shapefile=SHAPEFILE,
                                      knn=10, max_neighbors=3):
    """Give countries without neighbours some from a k-nearest-neighbour search.

    For every id in ``w.islands`` the k nearest shapefile records are looked
    up and those that also occur in ``data_countries`` are added as
    neighbours, until the country has ``max_neighbors`` of them. Countries
    whose name is not found in the shapefile are left untouched. The name of
    every processed country is printed.

    Parameters
    ----------
    data_countries : sequence of str
        Country names; positions are the ids used in ``w``.
    w : pysal.weights.W
        Weights to complete. ``w.neighbors`` is modified in place; its
        entries must accept item assignment by neighbour id, as the ones
        built by ``neighbors_from_json_file`` do.
    shapefile : str
        Shapefile used for the nearest-neighbour search.
    knn : int
        Number of nearest neighbours requested from pysal.
    max_neighbors : int
        Stop adding neighbours to a country once it has this many.

    Returns
    -------
    pysal.weights.W
        New weights built from the completed neighbour mapping.
    """
    wknn = pysal.knnW_from_shapefile(shapefile, knn)
    knn_countries = get_countries_from_shapefile(shapefile)

    # First-occurrence lookups, replacing list.index() / np.where() scans.
    knn_index = {}
    for idx, name in enumerate(knn_countries):
        knn_index.setdefault(name, idx)
    data_index = {}
    for idx, name in enumerate(data_countries):
        data_index.setdefault(name, idx)

    neighbors = w.neighbors
    for nn_idx in w.islands:
        country = data_countries[nn_idx]
        print(country)
        knn_country_idx = knn_index.get(country)
        if knn_country_idx is None:
            continue
        for knn_nn in wknn.neighbors[knn_country_idx]:
            if len(neighbors[nn_idx]) >= max_neighbors:
                # Nothing more can be added for this country.
                break
            data_country_idx = data_index.get(knn_countries[knn_nn])
            if data_country_idx is not None:
                neighbors[nn_idx][data_country_idx] = 1.0

    return pysal.weights.W(neighbors,
                           id_order=list(range(len(data_countries))))
|
Maria@4
|
88
|
Maria@4
|
89
|
Maria@4
|
def get_neighbors_for_countries_in_dataset(Y):
    """Build the weights for the distinct countries found in ``Y``.

    Border-based neighbours come from ``neighbors_from_json_file``; countries
    left without neighbours are then completed by
    ``replace_empty_neighbours_with_KNN``.

    Returns
    -------
    tuple
        ``(w, data_countries)`` where ``data_countries`` is ``np.unique(Y)``.
    """
    data_countries = np.unique(Y)
    border_weights = neighbors_from_json_file(data_countries)
    completed_weights = replace_empty_neighbours_with_KNN(data_countries,
                                                          border_weights)
    return completed_weights, data_countries
|
Maria@4
|
96
|
Maria@4
|
97
|
Maria@4
|
def from_weights_to_dict(w, data_countries):
    """Translate ``w.neighbors`` from ids to country names.

    Returns a dict mapping ``data_countries[i]`` to the list of
    ``data_countries[j]`` for every neighbour ``j`` of ``i``.
    """
    adjacency = w.neighbors
    return {
        data_countries[idx]: [data_countries[other]
                              for other in adjacency[idx]]
        for idx in adjacency
    }
|
Maria@4
|
103
|
Maria@4
|
104
|
Maria@4
|
def get_LH_HL_idx(lm, p_vals):
    """Return the indices of the significant LH and HL observations.

    An observation is significant when its entry in ``p_vals`` is below
    0.05. Among those, ``lm.q == 2`` selects the LH indices and
    ``lm.q == 4`` the HL indices.

    Returns
    -------
    tuple of numpy.ndarray
        ``(LH_idx, HL_idx)``.
    """
    significant = np.nonzero(p_vals < 0.05)[0]
    quadrants = lm.q[significant]
    return significant[quadrants == 2], significant[quadrants == 4]
|
Maria@4
|
110
|
Maria@4
|
111
|
Maria@4
|
def print_Moran_outliers(y, w, data_countries):
    """Print the significant LH and HL outliers of a local Moran analysis.

    Runs ``pysal.Moran_Local(y, w)`` and prints two lines, labelled ``LH``
    and ``HL``, each followed by the list of ``(country, p-value)`` pairs
    selected by ``get_LH_HL_idx``.

    Parameters
    ----------
    y : array-like
        Values handed to ``pysal.Moran_Local``.
    w : pysal.weights.W
        Spatial weights for ``y``.
    data_countries : numpy.ndarray
        Country names, indexed with the outlier index arrays.
    """
    lm = pysal.Moran_Local(y, w)
    p_vals = lm.p_z_sim
    LH_idx, HL_idx = get_LH_HL_idx(lm, p_vals)
    for label, idx in (('LH', LH_idx), ('HL', HL_idx)):
        # list() so the pairs are shown on Python 3 as well; the formatted
        # string prints the same text under both print semantics.
        pairs = list(zip(data_countries[idx], p_vals[idx]))
        print('%s %s' % (label, pairs))
|
Maria@4
|
118
|
Maria@4
|
119
|
Maria@4
|
def plot_Moran_scatterplot(y, w, data_countries, out_file=None):
    """Draw a Moran scatterplot and label its outliers.

    The standardised values are plotted against the standardised spatially
    lagged values. Points are coloured by quadrant, the significant LH and
    HL points (see ``get_LH_HL_idx``) are drawn larger and annotated with an
    arrow, and a few extreme points are annotated as well.

    Parameters
    ----------
    y : array-like
        Values handed to ``pysal.Moran_Local``.
    w : pysal.weights.W
        Spatial weights for ``y``.
    data_countries : sequence of str
        Country names used as annotation labels.
    out_file : str, optional
        When given, the figure is saved to this path.
    """
    lm = pysal.Moran_Local(y, w)
    p_vals = lm.p_z_sim
    LH_idx, HL_idx = get_LH_HL_idx(lm, p_vals)
    outlier_idx = np.concatenate((LH_idx, HL_idx))

    # Standardise the values and their spatial lag.
    ylt = pysal.lag_spatial(lm.w, lm.y)
    yt = lm.y
    yt = (yt - np.mean(yt)) / np.std(yt)
    ylt = (ylt - np.mean(ylt)) / np.std(ylt)

    # 'spectral' was renamed 'nipy_spectral' and later removed from
    # matplotlib; fall back to the old name on old installations.
    colormap = getattr(plt.cm, 'nipy_spectral', None)
    if colormap is None:
        colormap = plt.cm.spectral
    colors = colormap(np.linspace(0, 1, 5))

    # Quadrant 0 is kept for points lying exactly on an axis.
    quad = np.zeros(yt.shape, dtype=int)
    quad[np.bitwise_and(ylt > 0, yt > 0)] = 1  # HH
    quad[np.bitwise_and(ylt > 0, yt < 0)] = 2  # LH
    quad[np.bitwise_and(ylt < 0, yt < 0)] = 3  # LL
    quad[np.bitwise_and(ylt < 0, yt > 0)] = 4  # HL
    marker_color = colors[quad]
    marker_size = 40 * np.ones(yt.shape, dtype=int)
    marker_size[LH_idx] = 140
    marker_size[HL_idx] = 140

    plt.figure()
    plt.scatter(yt, ylt, c=marker_color, s=marker_size, alpha=0.7)
    plt.xlabel('Value')
    plt.ylabel('Spatially Lagged Value')
    plt.axvline(c='black', ls='--')
    plt.axhline(c='black', ls='--')
    plt.ylim(min(ylt) - 0.5, max(ylt) + 0.5)
    plt.xlim(min(yt) - 0.5, max(yt) + 1.5)
    for i in outlier_idx:
        plt.annotate(data_countries[i], (yt[i], ylt[i]),
                     xytext=(yt[i] * 1.1, ylt[i] * 1.1), textcoords='data',
                     arrowprops=dict(arrowstyle="->",
                                     connectionstyle="arc3"))

    # Extreme points: lowest and highest lag, plus anything more than 2.8
    # standard deviations above the mean on either axis. The lag threshold
    # uses the mean of the lag (the original mixed in the mean of yt).
    extreme_points = np.concatenate((
        [np.argmin(ylt)],
        [np.argmax(ylt)],
        np.where(yt > np.mean(yt) + 2.8 * np.std(yt))[0],
        np.where(ylt > np.mean(ylt) + 2.8 * np.std(ylt))[0]))
    # Outliers are already annotated above; setdiff1d also removes
    # duplicates and gives a deterministic order.
    extreme_points = np.setdiff1d(extreme_points, outlier_idx)
    for i in extreme_points:
        plt.annotate(data_countries[i], (yt[i] + 0.1, ylt[i]))

    if out_file is not None:
        plt.savefig(out_file)
|