city2graph、OSMnx、PyTorch Geometricを用いた都市機能推論のための空間グラフニューラルネットワークの実装チュートリアル
このチュートリアルは、city2graph と OSMnx を活用して OpenStreetMap データから都市空間グラフを構築し、PyTorch Geometric で GraphSAGE モデルを訓練することで、都市機能の推論を行う実用的なエンドツーエンド・パイプラインを提示している。
キーポイント
統合されたデータ収集と前処理ワークフロー
OpenStreetMap から POI(施設)および道路ネットワークデータを取得し、信頼性を確保するための合成データフォールバック機能を実装した自動パイプラインを構築している。
多様なグラフ構造の構築と比較
異なる近接性戦略に基づいて複数のグラフファミリー(均質・非均質)を作成し、同じ都市環境をどのように表現するかの比較分析を行っている。
GraphSAGE による都市機能推論の実装
構築されたグラフ構造を PyTorch Geometric 形式に変換し、空間構造から POI カテゴリ(飲食店、教育施設など)を予測する GraphSAGE モデルの訓練と評価を実施している。
OSM データ取得とフォールバック戦略
Shibuya 周辺から OpenStreetMap を利用して POI と歩行ネットワークを取得し、通信エラー発生時にはランダム生成された合成データセットで処理を継続する堅牢な設計を採用している。
座標系の変換とサンプリング
空間分析に適した UTM 座標系へ変換し、POI の数が 700 を超える場合はランダムサンプリングで固定サイズに調整して計算負荷を管理している。
都市機能カテゴリの定義
food, retail, education, health の 4 つのカテゴリに対して、amenity や shop タグに基づいた具体的なクエリ設定を行い、都市機能を構造化して分類する。
POI の空間特徴量エンジニアリング
各 POI について投影座標の抽出、半径 150m 以内の近傍点による局所密度計算、および最寄りの道路セグメントまでの距離推定を行い、グラフ構築用の特徴量を準備します。
影響分析・編集コメントを表示
影響分析
本記事は、都市計画やスマートシティ分野における GNN の実装ハードルを下げ、研究者やエンジニアが即座に実験を開始できる具体的なコードベースを提供することで、空間グラフ学習の実用化を加速させる。特に、データの欠落に対するフォールバック戦略を含んでいる点は、現場での信頼性確保において重要な示唆を与える。
編集コメント
都市機能推論という具体的な課題に対し、ライブラリの連携とモデル設計の両面から詳細な実装例を示しており、空間 AI 分野の実践者にとって非常に価値の高いリソースです。
本チュートリアルでは、city2graph を用いたエンドツーエンドの空間グラフ学習パイプラインを構築します。まず、OpenStreetMap から実際の都市 POI(ポイント・オブ・インタレスト)データとストリートネットワーク情報を収集し、ワークフローの信頼性を確保するために合成データのフォールバックも用意します。次に、空間特徴量をエンジニアリングし、複数の近接性グラフファミリーを構築して、異なるグラフ構築戦略が同じ都市環境をどのように表現するかを比較します。その後、異種グラフ構造と同種グラフ構造の両方を作成し、PyTorch Geometric 形式に変換した上で、GraphSAGE モデルを訓練して空間構造から POI カテゴリを予測します。このプロセスを通じて、地理空間データ処理、グラフ構築、GNN(グラフニューラルネットワーク)に基づく都市機能推論を単一の実践的なワークフローに統合します。
city2graph のインストールと地理空間・グラフ学習ライブラリのインポート
コードをコピーしました
別のブラウザを使用してください
!pip -q install "city2graph[cpu]" osmnx contextily scikit-learn 2>/dev/null
import warnings, numpy as np, pandas as pd, geopandas as gpd
warnings.filterwarnings("ignore")
from shapely.geometry import Point
import matplotlib.pyplot as plt
import city2graph as c2g
print("city2graph version:", getattr(c2g, "__version__", "unknown"))
print("PyTorch / PyG available:", c2g.is_torch_available())
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, f1_score
from sklearn.decomposition import PCA
SEED = 42
np.random.seed(SEED); torch.manual_seed(SEED)
まず、必要なライブラリのインストールと、本チュートリアル全体で使用される地理空間データ処理、グラフ学習、機械学習ツールのインポートを行います。city2graph と PyTorch Geometric が利用可能であることを確認し、以降のワークフローが正常に実行できるようにします。また、グラフ構築、トレーニングデータの分割、モデルの結果の再現性を高めるために、固定された乱数シードを設定しています。
合成フォールバックを用いた OpenStreetMap POI データの収集
CENTER = (35.6595, 139.7005)
DIST_M = 1100
TAG_QUERIES = {
"food": {"amenity": ["restaurant", "cafe", "fast_food", "bar", "pub"]},
"retail": {"shop": True},
"education": {"amenity": ["school", "university", "college", "kindergarten", "library"]},
"health": {"amenity": ["hospital", "clinic", "pharmacy", "doctors", "dentist"]},
}
def to_points(gdf):
g = gdf.copy()
g["geometry"] = g.geometry.representative_point()
return g
poi_gdf, segments_gdf = None, None
try:
import osmnx as ox
ox.settings.use_cache = True
ox.settings.log_console = False
frames = []
for label, tags in TAG_QUERIES.items():
try:
f = ox.features_from_point(CENTER, tags=tags, dist=DIST_M)
f = f[f.geometry.notna()]
if len(f):
f = to_points(f)[["geometry"]].copy()
f["category"] = label
frames.append(f)
except Exception as e:
print(f" (skip {label}: {e})")
if not frames:
raise RuntimeError("No POIs returned from Overpass.")
poi_gdf = gpd.GeoDataFrame(pd.concat(frames, ignore_index=True), crs="EPSG:4326")
G = ox.graph_from_point(CENTER, dist=DIST_M, network_type="walk")
segments_gdf = ox.graph_to_gdfs(G, nodes=False, edges=True).reset_index(drop=True)[["geometry"]]
print(f"OSM acquisition OK -> {len(poi_gdf)} POIs, {len(segments_gdf)} street segments")
except Exception as e:
print(f"OSM unavailable ({e}) -> generating synthetic clustered POIs.")
rng = np.random.default_rng(SEED)
cats = list(TAG_QUERIES.keys())
centers = rng.uniform(-0.01, 0.01, size=(8, 2)) + np.array(CENTER[::-1])
rows = []
for ci, c in enumerate(centers):
dom = cats[ci % len(cats)]
n = rng.integers(40, 90)
pts = c + rng.normal(0, 0.0016, size=(n, 2))
for (lon, lat) in pts:
cat = dom if rng.random() < 0.75 else rng.choice(cats)
rows.append({"geometry": Point(lon, lat), "category": cat})
poi_gdf = gpd.GeoDataFrame(rows, crs="EPSG:4326")
segments_gdf = None
print(f"Synthetic dataset -> {len(poi_gdf)} POIs")
if len(poi_gdf) > 700:
poi_gdf = poi_gdf.sample(700, random_state=SEED).reset_index(drop=True)
metric_crs = poi_gdf.estimate_utm_crs()
poi_gdf = poi_gdf.to_crs(metric_crs).reset_index(drop=True)
if segments_gdf is not None:
segments_gdf = segments_gdf.to_crs(metric_crs)
print("Class balance:\n", poi_gdf["category"].value_counts())
東京渋谷周辺から OpenStreetMap の実在 POI データを収集し、食品、小売、教育、健康といった広範な都市機能カテゴリに位置情報をグループ化します。また、後で POI を都市形態特徴と接続できるように、歩行者可能な道路ネットワークもダウンロードします。OSM へのリクエストが失敗した場合、チュートリアルがオンラインデータアクセス不可時でも実行可能となるよう、合成されたクラスタ型データを生成します。
空間特徴のエンジニアリングと近接グラフファミリーの構築
Copy CodeCopiedUse a different Browser
poi_gdf["cx"] = poi_gdf.geometry.x
poi_gdf["cy"] = poi_gdf.geometry.y
coords = poi_gdf[["cx", "cy"]].to_numpy()
nn = NearestNeighbors(radius=150.0).fit(coords)
poi_gdf["local_density"] = [len(idx) - 1 for idx in nn.radius_neighbors(coords, return_distance=False)]
if segments_gdf is not None and len(segments_gdf):
try:
joined = gpd.sjoin_nearest(poi_gdf[["geometry"]], segments_gdf[["geometry"]],
distance_col="dist_street")
poi_gdf["dist_street"] = joined.groupby(level=0)["dist_street"].min().reindex(poi_gdf.index).fillna(0.0)
except Exception:
poi_gdf["dist_street"] = 0.0
else:
poi_gdf["dist_street"] = 0.0
poi_gdf["category"] = poi_gdf["category"].astype("category")
poi_gdf["label"] = poi_gdf["category"].cat.codes.astype(int)
CLASS_NAMES = list(poi_gdf["category"].cat.categories)
print("Classes:", CLASS_NAMES)
def graph_stats(name, builder):
try:
nodes, edges = builder()
deg = pd.Series(np.r_[edges.index.get_level_values(0),
edges.index.get_level_values(1)]).value_counts()
return name, len(edges), round(deg.mean(), 2), (nodes, edges)
except Exception as e:
return name, f"ERR: {e}", None, None
builders = {
"KNN (k=8)": lambda: c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False),
"Delaunay": lambda: c2g.delaunay_graph(poi_gdf, as_nx=False),
"Gabriel": lambda: c2g.gabriel_graph(poi_gdf, as_nx=False),
"RNG": lambda: c2g.relative_neighborhood_graph(poi_gdf, as_nx=False),
"EMST": lambda: c2g.euclidean_minimum_spanning_tree(poi_gdf, as_nx=False),
"Waxman": lambda: c2g.waxman_graph(poi_gdf, distance_metric="euclidean", r0=150, beta=0.6),
}
print("\n--- Proximity graph comparison ---")
print(f"{'graph':<14}{'#edges':>10}{'avg_degree':>12}")
built = {}
for nm, b in builders.items():
name, ne, avgdeg, payload = graph_stats(nm, b)
print(f"{name:<14}{str(ne):>10}{str(avgdeg):>12}")
if payload: built[nm] = payload
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, key in zip(axes, ["KNN (k=8)", "Delaunay", "EMST"]):
if key in built:
n_, e_ = built[key]
e_.plot(ax=ax, linewidth=0.4, color="#3b7dd8", alpha=0.6)
poi_gdf.plot(ax=ax, markersize=4, color="#d83b5c")
ax.set_title(key); ax.set_axis_off()
plt.suptitle("Spatial graph topologies on the same POI set", y=1.02)
plt.tight_layout(); plt.show()
各 POI の空間特徴量は、投影座標の抽出、局所密度の計算、最寄りの道路セグメントまでの距離推定によって構築します。その後、カテゴリラベルを割り当て、KNN、Delaunay、Gabriel、RNG(Random Geometric Network)、EMST(Euclidean Minimum Spanning Tree)、Waxman など複数の近接性グラフファミリーを構築します。これらのグラフの辺数と平均次数を比較し、選択したグラフトポロジーを可視化して、同じ POI 集合をどのように異なる方法で接続しているかを確認します。
PyTorch Geometric における異種および同種グラフの構築
Copy CodeCopiedUse a different Browser
nodes_dict = {}
for cat in CLASS_NAMES:
sub = poi_gdf[poi_gdf["category"] == cat].copy().reset_index(drop=True)
nodes_dict[cat] = sub[["geometry", "cx", "cy", "local_density"]]
try:
_, bridge_edges = c2g.bridge_nodes(nodes_dict, proximity_method="knn", k=3,
distance_metric="euclidean")
hetero = c2g.gdf_to_pyg(
nodes_dict, bridge_edges,
node_feature_cols={cat: ["cx", "cy", "local_density"] for cat in CLASS_NAMES},
)
print("\nHeteroData node types:", hetero.node_types)
print("HeteroData edge types:")
for et in hetero.edge_types:
print(f" {et}: {hetero[et].edge_index.shape[1]} edges")
except Exception as e:
hetero = None
print("Heterogeneous build skipped:", e)
nodes, edges = c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False)
deg = pd.Series(np.r_[edges.index.get_level_values(0),
edges.index.get_level_values(1)]).value_counts()
nodes["degree"] = deg.reindex(nodes.index).fillna(0).astype(float)
for col in ["cx", "cy", "local_density", "dist_street", "label"]:
if col not in nodes.columns:
nodes[col] = poi_gdf.loc[nodes.index, col].values
FEATS = ["cx", "cy", "local_density", "dist_street", "degree"]
nodes[FEATS] = StandardScaler().fit_transform(nodes[FEATS].astype(float))
data = c2g.gdf_to_pyg(nodes, edges, node_feature_cols=FEATS, node_label_cols=["label"])
data.edge_index = to_undirected(data.edge_index)
data.x = data.x.float()
y = data.y.long().view(-1)
N, num_classes = data.num_nodes, int(y.max()) + 1
print(f"\nHomogeneous Data: {N} nodes, {data.edge_index.shape[1]} directed-edges, "
f"{data.x.shape[1]} features, {num_classes} classes")
都市機能カテゴリに基づいて POI をノードタイプに分離し、異種多層グラフを構築します。その後、ブリッジエッジを用いて異なるレイヤー間の近接ノードを接続し、その結果を PyTorch Geometric の HeteroData 形式に変換します。次に、均質 KNN グラフを構築し、次数とエンジニアリングされた特徴量を付加して標準化し、GraphSAGE 学習用の最終的な PyG Data オブジェクトを用意します。
POI 分類のための GraphSAGE モデルの定義とトレーニング
Copy CodeCopiedUse a different Browser
perm = torch.randperm(N, generator=torch.Generator().manual_seed(SEED))
n_tr, n_va = int(0.6 * N), int(0.2 * N)
train_mask = torch.zeros(N, dtype=torch.bool); train_mask[perm[:n_tr]] = True
val_mask = torch.zeros(N, dtype=torch.bool); val_mask[perm[n_tr:n_tr + n_va]] = True
test_mask = torch.zeros(N, dtype=torch.bool); test_mask[perm[n_tr + n_va:]] = True
class GraphSAGE(torch.nn.Module):
def __init__(self, in_dim, hidden, out_dim, p=0.3):
super().__init__()
self.c1 = SAGEConv(in_dim, hidden)
self.c2 = SAGEConv(hidden, hidden)
self.lin = torch.nn.Linear(hidden, out_dim)
self.p = p
def forward(self, x, ei, return_emb=False):
h = F.relu(self.c1(x, ei))
h = F.dropout(h, p=self.p, training=self.training)
h = F.relu(self.c2(h, ei))
out = self.lin(h)
return (out, h) if return_emb else out
model = GraphSAGE(data.x.shape[1], 64, num_classes)
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def evaluate(mask):
model.eval()
with torch.no_grad():
pred = model(data.x, data.edge_index).argmax(1)
yt, yp = y[mask].numpy(), pred[mask].numpy()
return accuracy_score(yt, yp), f1_score(yt, yp, average="macro")
print("\n--- Training GraphSAGE ---")
best_val, best_state = 0.0, None
for epoch in range(1, 201):
model.train(); opt.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[train_mask], y[train_mask])
loss.backward(); opt.step()
if epoch % 20 == 0:
va_acc, va_f1 = evaluate(val_mask)
if va_acc > best_val:
best_val, best_state = va_acc, {k: v.clone() for k, v in model.state_dict().items()}
print(f"epoch {epoch:3d} | loss {loss.item():.3f} | val_acc {va_acc:.3f} | val_f1 {va_f1:.3f}")
if best_state: model.load_state_dict(best_state)
te_acc, te_f1 = evaluate(test_mask)
print(f"\nTEST accuracy={te_acc:.3f} macro-F1={te_f1:.3f}")
グラフノードをトレーニング、バリデーション、テスト用のマスクに分割し、モデルが適切に学習・評価できるようにします。ノード特徴とグラフ構造の両方からノード表現を学習する 2 レイヤーの GraphSAGE モデルを定義します。モデルは 200 エポックにわたってトレーニングし、バリデーション精度とマクロ F1 スコアを監視しながら、最良のモデル状態を保存し、最終的にテストパフォーマンスを報告します。
埋め込みの可視化と異種グラフニューラルネットワーク(Heterogeneous GNN)の順伝播実行
Copy CodeCopiedUse a different Browser
model.eval()
with torch.no_grad():
logits, emb = model(data.x, data.edge_index, return_emb=True)
pred = logits.argmax(1).numpy()
emb2d = PCA(n_components=2).fit_transform(emb.numpy())
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
for cls in range(num_classes):
m = y.numpy() == cls
axes[0].scatter(emb2d[m, 0], emb2d[m, 1], s=10, label=CLASS_NAMES[cls], alpha=0.7)
axes[0].set_title("GraphSAGE node embeddings (PCA), coloured by TRUE class")
axes[0].legend(fontsize=8); axes[0].set_xticks([]); axes[0].set_yticks([])
plot_gdf = nodes.copy(); plot_gdf["pred"] = pred
plot_gdf["pred_name"] = [CLASS_NAMES[p] for p in pred]
plot_gdf.plot(ax=axes[1], column="pred_name", legend=True, markersize=12, cmap="tab10")
axes[1].set_title("Predicted urban function (mapped back to geography)")
axes[1].set_axis_off()
try:
import contextily as ctx
ctx.add_basemap(axes[1], crs=plot_gdf.crs, source=ctx.providers.CartoDB.Positron)
except Exception:
pass
plt.tight_layout(); plt.show()
if hetero is not None:
try:
for nt in hetero.node_types:
hetero[nt].x = hetero[nt].x.float()
class HGNN(torch.nn.Module):
def __init__(self, hid, out):
super().__init__()
self.c1 = SAGEConv((-1, -1), hid)
self.c2 = SAGEConv((-1, -1), out)
def forward(self, x, ei):
x = {k: F.relu(v) for k, v in self.c1(x, ei).items()}
return self.c2(x, ei)
hmodel = to_hetero(HGNN(32, 16), hetero.metadata(), aggr="sum")
out_dict = hmodel(hetero.x_dict, hetero.edge_index_dict)
print("\nHeterogeneous GNN output embedding shapes:")
for nt, t in out_dict.items():
print(f" {nt}: {tuple(t.shape)}")
except Exception as e:
print("Hetero GNN forward skipped:", e)
print("\n
image Done — proximity comparison, hetero construction, and a trained spatial GNN.")
訓練済みの GraphSAGE モデルを使用して、均質なグラフからノード埋め込みと予測結果を抽出します。学習された埋め込みを PCA で次元削減し、地理的な予測マップと共に可視化することで、モデルがどのように都市機能を分離しているかを理解します。また、to_hetero を用いた異種 GNN の順伝播も実行し、本チュートリアルが均質トレーニングと異種グラフの実験の両方をサポートしていることを示します。
キーポイント
city2graph は、生の OpenStreetMap POI(Points of Interest)およびストリートデータを空間グラフに変換します。
6 つの近接性グラフファミリー(KNN, Delaunay, Gabriel, RNG, EMST, Waxman)は、同じ POI を異なる方法で接続します。
合成されたクラスタリングフォールバック機能により、OSM へのアクセスがなくてもワークフローを実行可能に保ちます。
2 層の GraphSAGE モデルが、空間構造から都市機能カテゴリを予測します。
to_hetero を通じて、均質トレーニングと異種グラフの実験の両方をサポートするパイプラインです。
結論
結論として、私たちは生きた都市データをグラフベースの学習と可視化へ変換する完全な空間 GNN パイプラインを完成させました。いくつかの近接性グラフ手法を比較し、異種多層グラフを構築し、均質な GraphSAGE クラスファを訓練し、学習された埋め込み表現と地理的予測を検証しました。これにより、POI 間の空間関係をグラフ構造として表現し、都市機能を予測するためにどのように活用できるかという実践的な理解が得られました。また、city2graph、GeoPandas、OSMnx、PyTorch Geometric が Colab に親和的な環境で高度な地理空間機械学習実験を支援するためにどのように連携するかを示すものでもあります。
ノートブック付きの完全なコードはこちらでご覧ください。Twitter で私たちをフォローすることもお気軽にどうぞ。また、150k 人以上の ML サブレッドに参加し、ニュースレターを購読することを忘れないでください。待ってください!Telegram をご利用ですか?今なら Telegram でも私たちに参加できます。
GitHub リポジトリや Hugging Face ページ、製品リリース、ウェビナーなどのプロモーションのためにパートナーシップをご検討の場合は、ぜひご連絡ください。
本記事「city2graph、OSMnx、PyTorch Geometric を用いた都市機能推論のための空間グラフニューラルネットワークのコーディング実装」は、MarkTechPost で最初に公開されました。
原文を表示
In this tutorial, we build an end-to-end spatial graph learning pipeline using city2graph. We start by collecting real urban POI data and street network information from OpenStreetMap, with a synthetic fallback to ensure the workflow remains reliable. We then engineer spatial features, construct multiple proximity graph families, and compare how different graph-building strategies represent the same urban environment. After that, we create both heterogeneous and homogeneous graph structures, convert them into PyTorch Geometric format, and train a GraphSAGE model to predict POI categories from spatial structure. Through this process, we integrate geospatial data processing, graph construction, and GNN-based urban function inference into a single practical workflow.
Installing city2graph and Importing Geospatial and Graph Learning Libraries
Copy CodeCopiedUse a different Browser
!pip -q install "city2graph[cpu]" osmnx contextily scikit-learn 2>/dev/null
import warnings, numpy as np, pandas as pd, geopandas as gpd
warnings.filterwarnings("ignore")
from shapely.geometry import Point
import matplotlib.pyplot as plt
import city2graph as c2g
print("city2graph version:", getattr(c2g, "__version__", "unknown"))
print("PyTorch / PyG available:", c2g.is_torch_available())
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.utils import to_undirected
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, f1_score
from sklearn.decomposition import PCA
SEED = 42
np.random.seed(SEED); torch.manual_seed(SEED)
We begin by installing the required libraries and importing the geospatial, graph learning, and machine learning tools used throughout the tutorial. We verify that city2graph and PyTorch Geometric are available so the rest of the workflow can run properly. We also set a fixed random seed to make the graph construction, training split, and model results more reproducible.
Collecting OpenStreetMap POI Data with a Synthetic Fallback
Copy CodeCopiedUse a different Browser
CENTER = (35.6595, 139.7005)
DIST_M = 1100
TAG_QUERIES = {
"food": {"amenity": ["restaurant", "cafe", "fast_food", "bar", "pub"]},
"retail": {"shop": True},
"education": {"amenity": ["school", "university", "college", "kindergarten", "library"]},
"health": {"amenity": ["hospital", "clinic", "pharmacy", "doctors", "dentist"]},
}
def to_points(gdf):
g = gdf.copy()
g["geometry"] = g.geometry.representative_point()
return g
poi_gdf, segments_gdf = None, None
try:
import osmnx as ox
ox.settings.use_cache = True
ox.settings.log_console = False
frames = []
for label, tags in TAG_QUERIES.items():
try:
f = ox.features_from_point(CENTER, tags=tags, dist=DIST_M)
f = f[f.geometry.notna()]
if len(f):
f = to_points(f)[["geometry"]].copy()
f["category"] = label
frames.append(f)
except Exception as e:
print(f" (skip {label}: {e})")
if not frames:
raise RuntimeError("No POIs returned from Overpass.")
poi_gdf = gpd.GeoDataFrame(pd.concat(frames, ignore_index=True), crs="EPSG:4326")
G = ox.graph_from_point(CENTER, dist=DIST_M, network_type="walk")
segments_gdf = ox.graph_to_gdfs(G, nodes=False, edges=True).reset_index(drop=True)[["geometry"]]
print(f"OSM acquisition OK -> {len(poi_gdf)} POIs, {len(segments_gdf)} street segments")
except Exception as e:
print(f"OSM unavailable ({e}) -> generating synthetic clustered POIs.")
rng = np.random.default_rng(SEED)
cats = list(TAG_QUERIES.keys())
centers = rng.uniform(-0.01, 0.01, size=(8, 2)) + np.array(CENTER[::-1])
rows = []
for ci, c in enumerate(centers):
dom = cats[ci % len(cats)]
n = rng.integers(40, 90)
pts = c + rng.normal(0, 0.0016, size=(n, 2))
for (lon, lat) in pts:
cat = dom if rng.random() < 0.75 else rng.choice(cats)
rows.append({"geometry": Point(lon, lat), "category": cat})
poi_gdf = gpd.GeoDataFrame(rows, crs="EPSG:4326")
segments_gdf = None
print(f"Synthetic dataset -> {len(poi_gdf)} POIs")
if len(poi_gdf) > 700:
poi_gdf = poi_gdf.sample(700, random_state=SEED).reset_index(drop=True)
metric_crs = poi_gdf.estimate_utm_crs()
poi_gdf = poi_gdf.to_crs(metric_crs).reset_index(drop=True)
if segments_gdf is not None:
segments_gdf = segments_gdf.to_crs(metric_crs)
print("Class balance:\n", poi_gdf["category"].value_counts())
We collect real POI data from OpenStreetMap around Shibuya, Tokyo, and group the locations into broad urban function categories such as food, retail, education, and health. We also download the walkable street network so that the POIs can later be connected with urban-form features. If the OSM request fails, we generate a synthetic clustered dataset, which keeps the tutorial runnable even when online data access is unavailable.
Engineering Spatial Features and Building Proximity Graph Families
Copy CodeCopiedUse a different Browser
poi_gdf["cx"] = poi_gdf.geometry.x
poi_gdf["cy"] = poi_gdf.geometry.y
coords = poi_gdf[["cx", "cy"]].to_numpy()
nn = NearestNeighbors(radius=150.0).fit(coords)
poi_gdf["local_density"] = [len(idx) - 1 for idx in nn.radius_neighbors(coords, return_distance=False)]
if segments_gdf is not None and len(segments_gdf):
try:
joined = gpd.sjoin_nearest(poi_gdf[["geometry"]], segments_gdf[["geometry"]],
distance_col="dist_street")
poi_gdf["dist_street"] = joined.groupby(level=0)["dist_street"].min().reindex(poi_gdf.index).fillna(0.0)
except Exception:
poi_gdf["dist_street"] = 0.0
else:
poi_gdf["dist_street"] = 0.0
poi_gdf["category"] = poi_gdf["category"].astype("category")
poi_gdf["label"] = poi_gdf["category"].cat.codes.astype(int)
CLASS_NAMES = list(poi_gdf["category"].cat.categories)
print("Classes:", CLASS_NAMES)
def graph_stats(name, builder):
try:
nodes, edges = builder()
deg = pd.Series(np.r_[edges.index.get_level_values(0),
edges.index.get_level_values(1)]).value_counts()
return name, len(edges), round(deg.mean(), 2), (nodes, edges)
except Exception as e:
return name, f"ERR: {e}", None, None
builders = {
"KNN (k=8)": lambda: c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False),
"Delaunay": lambda: c2g.delaunay_graph(poi_gdf, as_nx=False),
"Gabriel": lambda: c2g.gabriel_graph(poi_gdf, as_nx=False),
"RNG": lambda: c2g.relative_neighborhood_graph(poi_gdf, as_nx=False),
"EMST": lambda: c2g.euclidean_minimum_spanning_tree(poi_gdf, as_nx=False),
"Waxman": lambda: c2g.waxman_graph(poi_gdf, distance_metric="euclidean", r0=150, beta=0.6),
}
print("\n--- Proximity graph comparison ---")
print(f"{'graph':<14}{'#edges':>10}{'avg_degree':>12}")
built = {}
for nm, b in builders.items():
name, ne, avgdeg, payload = graph_stats(nm, b)
print(f"{name:<14}{str(ne):>10}{str(avgdeg):>12}")
if payload: built[nm] = payload
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, key in zip(axes, ["KNN (k=8)", "Delaunay", "EMST"]):
if key in built:
n_, e_ = built[key]
e_.plot(ax=ax, linewidth=0.4, color="#3b7dd8", alpha=0.6)
poi_gdf.plot(ax=ax, markersize=4, color="#d83b5c")
ax.set_title(key); ax.set_axis_off()
plt.suptitle("Spatial graph topologies on the same POI set", y=1.02)
plt.tight_layout(); plt.show()
We engineer spatial features for each POI by extracting its projected coordinates, calculating local density, and estimating distance to the nearest street segment. We then assign category labels and build several families of proximity graphs, including KNN, Delaunay, Gabriel, RNG, EMST, and Waxman. We compare their edge counts and average degrees, then visualize selected graph topologies to see how differently they connect the same set of POIs.
Constructing Heterogeneous and Homogeneous Graphs in PyTorch Geometric
Copy CodeCopiedUse a different Browser
nodes_dict = {}
for cat in CLASS_NAMES:
sub = poi_gdf[poi_gdf["category"] == cat].copy().reset_index(drop=True)
nodes_dict[cat] = sub[["geometry", "cx", "cy", "local_density"]]
try:
_, bridge_edges = c2g.bridge_nodes(nodes_dict, proximity_method="knn", k=3,
distance_metric="euclidean")
hetero = c2g.gdf_to_pyg(
nodes_dict, bridge_edges,
node_feature_cols={cat: ["cx", "cy", "local_density"] for cat in CLASS_NAMES},
)
print("\nHeteroData node types:", hetero.node_types)
print("HeteroData edge types:")
for et in hetero.edge_types:
print(f" {et}: {hetero[et].edge_index.shape[1]} edges")
except Exception as e:
hetero = None
print("Heterogeneous build skipped:", e)
nodes, edges = c2g.knn_graph(poi_gdf, distance_metric="euclidean", k=8, as_nx=False)
deg = pd.Series(np.r_[edges.index.get_level_values(0),
edges.index.get_level_values(1)]).value_counts()
nodes["degree"] = deg.reindex(nodes.index).fillna(0).astype(float)
for col in ["cx", "cy", "local_density", "dist_street", "label"]:
if col not in nodes.columns:
nodes[col] = poi_gdf.loc[nodes.index, col].values
FEATS = ["cx", "cy", "local_density", "dist_street", "degree"]
nodes[FEATS] = StandardScaler().fit_transform(nodes[FEATS].astype(float))
data = c2g.gdf_to_pyg(nodes, edges, node_feature_cols=FEATS, node_label_cols=["label"])
data.edge_index = to_undirected(data.edge_index)
data.x = data.x.float()
y = data.y.long().view(-1)
N, num_classes = data.num_nodes, int(y.max()) + 1
print(f"\nHomogeneous Data: {N} nodes, {data.edge_index.shape[1]} directed-edges, "
f"{data.x.shape[1]} features, {num_classes} classes")
We construct a heterogeneous multi-layer graph by separating POIs into node types based on their urban function categories. We then use bridge edges to connect nearby nodes across different layers and convert the result into PyTorch Geometric HeteroData format. After that, we build a homogeneous KNN graph, attach degree and engineered features, standardize them, and prepare the final PyG Data object for GraphSAGE training.
Defining and Training a GraphSAGE Model for POI Classification
Copy CodeCopiedUse a different Browser
perm = torch.randperm(N, generator=torch.Generator().manual_seed(SEED))
n_tr, n_va = int(0.6 * N), int(0.2 * N)
train_mask = torch.zeros(N, dtype=torch.bool); train_mask[perm[:n_tr]] = True
val_mask = torch.zeros(N, dtype=torch.bool); val_mask[perm[n_tr:n_tr + n_va]] = True
test_mask = torch.zeros(N, dtype=torch.bool); test_mask[perm[n_tr + n_va:]] = True
class GraphSAGE(torch.nn.Module):
def __init__(self, in_dim, hidden, out_dim, p=0.3):
super().__init__()
self.c1 = SAGEConv(in_dim, hidden)
self.c2 = SAGEConv(hidden, hidden)
self.lin = torch.nn.Linear(hidden, out_dim)
self.p = p
def forward(self, x, ei, return_emb=False):
h = F.relu(self.c1(x, ei))
h = F.dropout(h, p=self.p, training=self.training)
h = F.relu(self.c2(h, ei))
out = self.lin(h)
return (out, h) if return_emb else out
model = GraphSAGE(data.x.shape[1], 64, num_classes)
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def evaluate(mask):
model.eval()
with torch.no_grad():
pred = model(data.x, data.edge_index).argmax(1)
yt, yp = y[mask].numpy(), pred[mask].numpy()
return accuracy_score(yt, yp), f1_score(yt, yp, average="macro")
print("\n--- Training GraphSAGE ---")
best_val, best_state = 0.0, None
for epoch in range(1, 201):
model.train(); opt.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[train_mask], y[train_mask])
loss.backward(); opt.step()
if epoch % 20 == 0:
va_acc, va_f1 = evaluate(val_mask)
if va_acc > best_val:
best_val, best_state = va_acc, {k: v.clone() for k, v in model.state_dict().items()}
print(f"epoch {epoch:3d} | loss {loss.item():.3f} | val_acc {va_acc:.3f} | val_f1 {va_f1:.3f}")
if best_state: model.load_state_dict(best_state)
te_acc, te_f1 = evaluate(test_mask)
print(f"\nTEST accuracy={te_acc:.3f} macro-F1={te_f1:.3f}")
We split the graph nodes into training, validation, and test masks so the model can learn and be evaluated properly. We define a two-layer GraphSAGE model that learns node representations from both node features and graph structure. We train the model for 200 epochs, monitor validation accuracy and macro-F1, save the best model state, and finally report test performance.
Visualizing Embeddings and Running a Heterogeneous GNN Forward Pass
Copy CodeCopiedUse a different Browser
model.eval()
with torch.no_grad():
logits, emb = model(data.x, data.edge_index, return_emb=True)
pred = logits.argmax(1).numpy()
emb2d = PCA(n_components=2).fit_transform(emb.numpy())
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
for cls in range(num_classes):
m = y.numpy() == cls
axes[0].scatter(emb2d[m, 0], emb2d[m, 1], s=10, label=CLASS_NAMES[cls], alpha=0.7)
axes[0].set_title("GraphSAGE node embeddings (PCA), coloured by TRUE class")
axes[0].legend(fontsize=8); axes[0].set_xticks([]); axes[0].set_yticks([])
plot_gdf = nodes.copy(); plot_gdf["pred"] = pred
plot_gdf["pred_name"] = [CLASS_NAMES[p] for p in pred]
plot_gdf.plot(ax=axes[1], column="pred_name", legend=True, markersize=12, cmap="tab10")
axes[1].set_title("Predicted urban function (mapped back to geography)")
axes[1].set_axis_off()
try:
import contextily as ctx
ctx.add_basemap(axes[1], crs=plot_gdf.crs, source=ctx.providers.CartoDB.Positron)
except Exception:
pass
plt.tight_layout(); plt.show()
if hetero is not None:
try:
for nt in hetero.node_types:
hetero[nt].x = hetero[nt].x.float()
class HGNN(torch.nn.Module):
def __init__(self, hid, out):
super().__init__()
self.c1 = SAGEConv((-1, -1), hid)
self.c2 = SAGEConv((-1, -1), out)
def forward(self, x, ei):
x = {k: F.relu(v) for k, v in self.c1(x, ei).items()}
return self.c2(x, ei)
hmodel = to_hetero(HGNN(32, 16), hetero.metadata(), aggr="sum")
out_dict = hmodel(hetero.x_dict, hetero.edge_index_dict)
print("\nHeterogeneous GNN output embedding shapes:")
for nt, t in out_dict.items():
print(f" {nt}: {tuple(t.shape)}")
except Exception as e:
print("Hetero GNN forward skipped:", e)
print("\n
image Done — proximity comparison, hetero construction, and a trained spatial GNN.")
We use the trained GraphSAGE model to extract node embeddings and predictions from the homogeneous graph. We reduce the learned embeddings with PCA and visualize them alongside a geographic prediction map to understand how the model separates urban functions. We also run a heterogeneous GNN forward pass with to_hetero, showing that the tutorial supports both homogeneous training and heterogeneous graph experimentation.
Key Takeaways
city2graph turns raw OpenStreetMap POI and street data into spatial graphs.
Six proximity graph families (KNN, Delaunay, Gabriel, RNG, EMST, Waxman) connect the same POIs differently.
A synthetic clustered fallback keeps the workflow runnable without OSM access.
A two-layer GraphSAGE model predicts urban function categories from spatial structure.
The pipeline supports both homogeneous training and heterogeneous graph experimentation via to_hetero.
Conclusion
In conclusion, we completed a full spatial GNN pipeline that transforms raw city data into graph-based learning and visualization. We compared several proximity graph methods, built a heterogeneous multi-layer graph, trained a homogeneous GraphSAGE classifier, and inspected the learned embeddings and geographic predictions. It gives us a practical understanding of how spatial relationships among POIs can be represented as graph structures and used to predict urban functions. It also shows how city2graph, GeoPandas, OSMnx, and PyTorch Geometric work together to support advanced geospatial machine learning experiments in a Colab-friendly setup.
Check out the Full Codes with Notebook here. Also, feel free to follow us on Twitter and don’t forget to join our 150k+ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.
Need to partner with us for promoting your GitHub Repo OR Hugging Face Page OR Product Release OR Webinar etc.? Connect with us
The post A Coding Implementation on Spatial Graph Neural Networks for Urban Function Inference Using city2graph, OSMnx, and PyTorch Geometric appeared first on MarkTechPost.
関連記事
スタートアップが小売業者の製品をリアルタイムで追跡する支援を開始
MIT ML News が紹介したスタートアップは、在庫管理に時間を要する課題に対し、製品の位置情報をリアルタイムで把握できる技術を提供し、従業員が在庫を追跡する負担を軽減する。
プレゼンテーション:グラフニューラルネットワークによるプラットフォームエンゲージメントの再構築
Mariia Bulycheva氏が、Zalandoのランディングページにおいて、従来の深層学習からグラフニューラルネットワーク(GNN)への移行について説明した。ユーザーログを異種グラフに変換する複雑さや「メッセージパッシング」の学習プロセス、グラフデータリークの技術的課題を解説し、ハイブリッドアーキテクチャによる推論遅延の解決と文脈埋め込みの実現を共有した。
ヤンデックスがプロトコルバッファ用ゼロコピーワイヤーフォーマット「YaFF」をオープンソース化
ヤンデックスは、プロトコルバッファの物理メモリレイアウトを変更するゼロコピー形式「YaFF」を Apache 2.0 ライセンスで公開した。同社ベンチマークではホットデータ読み込み速度が FlatBuffers の約 3.8 倍に達し、広告推薦システムでも 10〜20% の性能向上を確認している。
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み