Mercurial > hg > plosone_underreview
comparison scripts/utils_spatial.py @ 4:e50c63cf96be branch-tests
rearranging folders
author: Maria Panteli
date: Mon, 11 Sep 2017 11:51:50 +0100
parents: (none)
children: 0f3eba42b425
comparison
equal
deleted
inserted
replaced
3:230a0cf17de0 | 4:e50c63cf96be |
---|---|
1 # -*- coding: utf-8 -*- | |
2 """ | |
3 Created on Wed May 17 11:35:51 2017 | |
4 | |
5 @author: mariapanteli | |
6 """ | |
7 import numpy as np | |
8 import json | |
9 import pysal # before shapely in util_plots | |
10 import fiona | |
11 import sys | |
12 sys.path.append('../misc') | |
13 import matplotlib.pyplot as plt | |
14 | |
15 | |
def neighbors_from_json_file(data_countries, json_DB='../MergeBL-Smith/data/countries.json'):
    """Build a pysal spatial weights object from country border data.

    Parameters
    ----------
    data_countries : np.ndarray
        Country names present in the dataset.
    json_DB : str
        Path to a countries json file with per-country 'name' (with a
        'common' field), 'cca3' ISO code and 'borders' ISO-code list.

    Returns
    -------
    pysal.weights.W
        Binary contiguity weights; countries absent from the json file,
        or with no listed borders, remain islands.
    """
    # Renames so the json 'common' names match the names used in the dataset.
    NAME_FIXES = {
        'United States': 'United States of America',
        'Tanzania': 'United Republic of Tanzania',
        'DR Congo': 'Democratic Republic of the Congo',
        'Czechia': 'Czech Republic',
    }
    with open(json_DB) as json_file:
        countries_dict = json.load(json_file)
    country_names = []
    country_iso = []
    country_borders_iso = []
    for country_info in countries_dict:
        name = country_info['name']['common']
        country_names.append(NAME_FIXES.get(name, name))
        country_iso.append(country_info['cca3'])
        # .get guards against entries that have no 'borders' key at all
        country_borders_iso.append(country_info.get('borders', []))
    neighbors = {}
    for i, country in enumerate(data_countries):
        neighbors[i] = {}
        if country not in country_names:
            continue
        borders = country_borders_iso[country_names.index(country)]
        if len(borders) == 0:
            # country has no neighbors according to the json file
            continue
        # map bordering ISO codes back to names, keep only the
        # neighbours that actually appear in the dataset
        neighbors_names = [country_names[country_iso.index(nn)] for nn in borders]
        for neighbor in neighbors_names:
            if neighbor in data_countries:
                neighbor_idx = np.where(data_countries == neighbor)[0][0]
                neighbors[i][neighbor_idx] = 1.0
    w = pysal.weights.W(neighbors, id_order=range(len(data_countries)))
    return w
45 | |
46 | |
def get_countries_from_shapefile(shapefile):
    """Return the list of country names stored in a shapefile.

    Parameters
    ----------
    shapefile : str
        Path to an ESRI shapefile of country polygons.

    Returns
    -------
    list of str
        One country name per shapefile record, in file order.
    """
    shp = fiona.open(shapefile, 'r')
    try:
        # different shapefile releases store the country name under
        # different property keys; probe the first record to find it
        # (dict.has_key() was Python-2-only; `in` works everywhere)
        properties = shp[0]["properties"]
        if "ADMIN" in properties:
            country_keyword = "ADMIN"
        elif "NAME" in properties:
            country_keyword = "NAME"
        else:
            country_keyword = "admin"
        countries = [line["properties"][country_keyword] for line in shp]
    finally:
        # close even if a record is missing the chosen key
        shp.close()
    return countries
60 | |
61 | |
def replace_empty_neighbours_with_KNN(data_countries, w,
                                      shapefile="../MergeBL-Smith/shapefiles/ne_10m_admin_0_countries.shp",
                                      knn=10):
    """Assign K-nearest-neighbour contiguity to island countries.

    Countries that ended up with no neighbours in `w` get up to 3
    neighbours taken from a KNN weights matrix built from the shapefile.

    Parameters
    ----------
    data_countries : np.ndarray
        Country names, aligned with the ids of `w`.
    w : pysal.weights.W
        Weights whose islands should be filled in.
    shapefile : str
        Country polygons used to compute KNN contiguity.
    knn : int
        Number of nearest neighbours to compute per country.

    Returns
    -------
    pysal.weights.W
        New weights object with islands (where possible) connected.
    """
    no_neighbors_idx = w.islands
    wknn = pysal.knnW_from_shapefile(shapefile, knn)
    knn_countries = get_countries_from_shapefile(shapefile)
    neighbors = w.neighbors
    for nn_idx in no_neighbors_idx:
        country = data_countries[nn_idx]
        print(country)  # log which islands are being processed
        if country not in knn_countries:
            # island not present in the shapefile; leave it unconnected
            continue
        knn_country_idx = knn_countries.index(country)
        knn_country_neighbors = [knn_countries[nn] for nn in wknn.neighbors[knn_country_idx]]
        for knn_nn in knn_country_neighbors:
            if len(neighbors[nn_idx]) > 2:
                # cap at 3 assigned neighbours per island -- the remaining
                # candidates would never be added, so stop scanning them
                break
            data_country_idx = np.where(data_countries == knn_nn)[0]
            if len(data_country_idx) > 0:
                neighbors[nn_idx][data_country_idx[0]] = 1.0
    w = pysal.weights.W(neighbors, id_order=range(len(data_countries)))
    return w
84 | |
85 | |
def get_neighbors_for_countries_in_dataset(Y):
    """Derive spatial weights for the unique countries appearing in Y.

    Border-based weights are built from the countries json file first;
    island countries are then connected via KNN contiguity.

    Returns the weights object and the sorted unique country array.
    """
    data_countries = np.unique(Y)
    weights = neighbors_from_json_file(data_countries)
    weights = replace_empty_neighbours_with_KNN(data_countries, weights)
    return weights, data_countries
92 | |
93 | |
def from_weights_to_dict(w, data_countries):
    """Translate integer-indexed weights into a name-keyed neighbour dict.

    Each country name maps to the list of its neighbours' names,
    following the id -> name alignment of data_countries.
    """
    return {
        data_countries[idx]: [data_countries[j] for j in nbrs]
        for idx, nbrs in w.neighbors.items()
    }
99 | |
100 | |
def get_LH_HL_idx(lm, p_vals):
    """Indices of significant (p < 0.05) Low-High and High-Low outliers.

    Quadrant codes follow pysal's Moran_Local convention stored in
    lm.q: 2 = LH, 4 = HL.
    """
    significant = np.flatnonzero(p_vals < 0.05)
    quadrants = lm.q[significant]
    LH_idx = significant[quadrants == 2]
    HL_idx = significant[quadrants == 4]
    return LH_idx, HL_idx
106 | |
107 | |
def print_Moran_outliers(y, w, data_countries):
    """Print significant Low-High and High-Low local Moran outliers.

    Parameters
    ----------
    y : array-like
        Variable of interest, one value per country.
    w : pysal.weights.W
        Spatial weights aligned with data_countries.
    data_countries : np.ndarray
        Country names, aligned with y.
    """
    lm = pysal.Moran_Local(y, w)
    p_vals = lm.p_z_sim
    LH_idx, HL_idx = get_LH_HL_idx(lm, p_vals)
    # single-argument print() with %-formatting keeps the output
    # identical across Python 2 and 3 (the old `print a, b` statement
    # was Python-2-only syntax)
    print('LH %s' % list(zip(data_countries[LH_idx], p_vals[LH_idx])))  # low values, high neighbours
    print('HL %s' % list(zip(data_countries[HL_idx], p_vals[HL_idx])))  # high values, low neighbours
114 | |
115 | |
def plot_Moran_scatterplot(y, w, data_countries, out_file=None):
    """Draw a Moran scatterplot (value vs. spatially lagged value).

    Points are coloured by quadrant (HH/LH/LL/HL); significant LH and HL
    outliers (p < 0.05) are enlarged and annotated with country names.

    Parameters
    ----------
    y : array-like
        Variable of interest, one value per country.
    w : pysal.weights.W
        Spatial weights aligned with data_countries.
    data_countries : np.ndarray
        Country names, aligned with y.
    out_file : str or None
        If given, the figure is also saved to this path.
    """
    lm = pysal.Moran_Local(y, w)
    p_vals = lm.p_z_sim
    LH_idx, HL_idx = get_LH_HL_idx(lm, p_vals)

    # spatial lag of y; both axes are then z-scored so the quadrants
    # split at zero
    ylt = pysal.lag_spatial(lm.w, lm.y)
    yt = lm.y
    yt = (yt - np.mean(yt))/np.std(yt)
    ylt = (ylt - np.mean(ylt))/np.std(ylt)
    # NOTE(review): plt.cm.spectral was removed in matplotlib >= 2.2;
    # newer versions need plt.cm.nipy_spectral -- confirm pinned version.
    colors = plt.cm.spectral(np.linspace(0,1,5))
    # quadrant code per point: 0 stays for points exactly on an axis
    quad = np.zeros(yt.shape, dtype=int)
    quad[np.bitwise_and(ylt > 0, yt > 0)]=1 # HH
    quad[np.bitwise_and(ylt > 0, yt < 0)]=2 # LH
    quad[np.bitwise_and(ylt < 0, yt < 0)]=3 # LL
    quad[np.bitwise_and(ylt < 0, yt > 0)]=4 # HL
    marker_color = colors[quad]
    # enlarge the markers of the significant outliers
    marker_size = 40*np.ones(yt.shape, dtype=int)
    marker_size[LH_idx] = 140
    marker_size[HL_idx] = 140

    plt.figure()
    plt.scatter(yt, ylt, c=marker_color, s=marker_size, alpha=0.7)
    plt.xlabel('Value')
    plt.ylabel('Spatially Lagged Value')
    plt.axvline(c='black', ls='--')
    plt.axhline(c='black', ls='--')
    # extra right-hand margin (+1.5) leaves room for text annotations
    plt.ylim(min(ylt)-0.5, max(ylt)+0.5)
    plt.xlim(min(yt)-0.5, max(yt)+1.5)
    # label the significant outliers with arrows
    for i in np.concatenate((LH_idx, HL_idx)):
        plt.annotate(data_countries[i], (yt[i], ylt[i]), xytext=(yt[i]*1.1, ylt[i]*1.1),textcoords='data',arrowprops=dict(arrowstyle="->",connectionstyle="arc3"))
    # also label extreme points (axis extrema and values > ~2.8 sigma)
    # not already covered by the LH/HL annotations.
    # NOTE(review): the last term mixes np.mean(yt) with np.std(ylt) --
    # presumably np.mean(ylt) was intended; numerically harmless here
    # because both axes were standardised above (mean ~0, std ~1), but
    # confirm before reusing elsewhere.
    extreme_points = np.concatenate(([np.argmin(ylt)], [np.argmax(ylt)], np.where(yt>np.mean(yt)+2.8*np.std(yt))[0], np.where(ylt>np.mean(yt)+2.8*np.std(ylt))[0]))
    extreme_points = np.array(list(set(extreme_points) - set(np.concatenate((LH_idx, HL_idx)))))
    for i in extreme_points:
        plt.annotate(data_countries[i], (yt[i]+0.1, ylt[i]))
    if out_file is not None:
        plt.savefig(out_file)