You are looking for Dijkstra's shortest path algorithm. Coincidentally, just this week I implemented this in Python. The implementation is a little heavy (since my goal was to show how to employ `unittest` to perform unit testing), but it works nevertheless. Dijkstra's algorithm works on directed graphs, but you can convert an undirected graph to a directed one by creating `A -> B` and `B -> A` directed edges for each `A -- B` undirected edge, just as you have.
from collections import defaultdict
from itertools import product, chain
class DijkstraNegativeWeightException(Exception):
    """Raised when the graph contains an edge with a negative weight."""
class DijkstraDisconnectedGraphException(Exception):
    """Raised by path lookups when the target has not been reached from the source."""
class Dijkstra:
    """Single-source shortest paths on a weighted directed graph.

    The graph is a nested mapping ``{tail: {head: weight}}``. Every key and
    every head counts as a node, so a key mapped to an empty dict is an
    isolated node. Weights must be non-negative.
    """

    def __init__(self, graph_data, source):
        """Store the graph, validate its weights and select ``source``.

        Args:
            graph_data: Mapping ``{tail: {head: weight}}``.
            source: Node from which shortest paths are computed.

        Raises:
            DijkstraNegativeWeightException: If any edge weight is negative.
            ValueError: If ``source`` is not a node of the graph.
        """
        # defaultdict lets sink nodes (heads that are never a key) be looked
        # up as having no outgoing edges.
        self._graph = defaultdict(dict, graph_data)
        self._check_edge_weights()
        self.reset_source(source)

    @property
    def edges(self):
        """List of ``(tail, head)`` tuples, one per directed edge."""
        return [(i, j) for i, heads in self._graph.items() for j in heads]

    @property
    def nodes(self):
        """List of all nodes: every key plus every edge head, without duplicates.

        Keys come first, in mapping order, followed by heads not seen before.
        """
        heads = chain.from_iterable(self._graph.values())
        return list(dict.fromkeys(chain(self._graph, heads)))

    def _check_source_in_nodes(self):
        """Raise ``ValueError`` if the current source is not a node."""
        msg = 'Source node \'{}\' not in graph.'
        if self._source not in self.nodes:
            raise ValueError(msg.format(self._source))

    def _check_edge_weights(self):
        """Raise ``DijkstraNegativeWeightException`` on any negative weight."""
        msg = 'Graph has negative weights, but weights must be non-negative.'
        if any(weight < 0
               for heads in self._graph.values()
               for weight in heads.values()):
            raise DijkstraNegativeWeightException(msg)

    def reset_source(self, source):
        """Select a new source node and discard any previous solution.

        Raises:
            ValueError: If ``source`` is not a node of the graph.
        """
        self._source = source
        self._check_source_in_nodes()
        self._solution_x = []           # tree edges (pred, node), in settle order
        self._solution_z = {source: 0}  # shortest distance per settled node
        self._visited = {source}
        self._unvisited = set(self.nodes) - self._visited
        self._solved = False

    def run(self):
        """Build the shortest-path tree from the source to all reachable nodes.

        Calling this again without an intervening ``reset_source`` is a
        no-op, so an existing solution is never recomputed or altered.
        """
        if self._solved:
            return
        # candidates: frontier node -> (tentative distance, predecessor).
        # Only unvisited heads are seeded, so a self-loop on the source can
        # never replace its distance of 0.
        candidates = {
            node: (weight, self._source)
            for node, weight in self._graph[self._source].items()
            if node in self._unvisited
        }
        while candidates:
            # Linear scan for the closest frontier node.
            j = min(candidates, key=lambda node: candidates[node][0])
            weight_best, i = candidates.pop(j)
            # Settle j before relaxing its edges, so that a self-loop on j
            # cannot put j back on the frontier.
            self._solution_x.append((i, j))
            self._solution_z[j] = weight_best
            self._visited.add(j)
            self._unvisited.discard(j)
            for k in self._graph[j].keys() & self._unvisited:
                weight_new = weight_best + self._graph[j][k]
                if k not in candidates or candidates[k][0] > weight_new:
                    candidates[k] = (weight_new, j)
        self._solved = True

    def path_to(self, target):
        """Return the shortest path from the source to ``target``.

        Runs the algorithm first if it has not been run for the current
        source.

        Returns:
            A tuple ``(edges, weight)``: the list of ``(tail, head)`` edges
            from the source to ``target`` (empty when ``target`` is the
            source) and the total weight of those edges.

        Raises:
            DijkstraDisconnectedGraphException: If ``target`` is a node that
                cannot be reached from the source.
            KeyError: If ``target`` is not a node of the graph.
        """
        if not self._solved:
            self.run()
        if target in self._unvisited:
            msg = 'No path from {} to {}; graph is disconnected.'
            msg = msg.format(self._source, target)
            raise DijkstraDisconnectedGraphException(msg)
        # Each settled node has exactly one tree edge pointing at it.
        predecessor = {j: i for i, j in self._solution_x}
        path = []
        node = target
        while node != self._source:
            previous = predecessor[node]
            path.append((previous, node))
            node = previous
        path.reverse()
        return path, self._solution_z[target]

    def visualize(self, source=None, target=None):
        """Draw the graph with the solution highlighted.

        Requires ``networkx`` and ``matplotlib``, imported on first use.

        Args:
            source: Optional new source; if it differs from the current one
                the solution is reset and recomputed.
            target: If given, only the shortest path to ``target`` is
                highlighted; otherwise the whole shortest-path tree is.
        """
        import networkx as nx
        import matplotlib.pyplot as plt
        if source is not None and source != self._source:
            self.reset_source(source)
        if not self._solved:
            self.run()
        if target is not None:
            path, _ = self.path_to(target=target)
        else:
            path = self._solution_x
        edgelist = self.edges
        nodelist = self.nodes
        nxgraph = nx.DiGraph()
        # Add nodes explicitly so isolated nodes also get a layout position.
        nxgraph.add_nodes_from(nodelist)
        nxgraph.add_edges_from(edgelist)
        weights = {(i, j): self._graph[i][j] for (i, j) in edgelist}
        path_edges = set(path)
        found = set(chain.from_iterable(path))
        ncolors = ['springgreen' if node in found else 'lightcoral'
                   for node in nodelist]
        ecolors = ['dodgerblue' if edge in path_edges else 'black'
                   for edge in edgelist]
        sizes = [3 if edge in path_edges else 1 for edge in edgelist]
        pos = nx.kamada_kawai_layout(nxgraph)
        nx.draw_networkx(nxgraph, pos=pos,
                         nodelist=nodelist, node_color=ncolors,
                         edgelist=edgelist, edge_color=ecolors, width=sizes)
        nx.draw_networkx_edge_labels(nxgraph, pos=pos, edge_labels=weights)
        plt.axis('equal')
        plt.show()
For your graph (with weights added),
# Adjacency mapping: outer key is the edge tail, inner dict maps head -> weight.
graph_data = {
    "A": {"B": 1, "C": 1, "E": 1},
    "B": {"A": 1, "D": 1, "E": 1},
    "C": {"A": 1, "F": 1, "G": 1},
    "D": {"B": 1},
    "E": {"A": 1, "B": 1, "D": 1},
    "F": {"C": 1},
    "G": {"C": 1},
}

algo = Dijkstra(graph_data, 'A')
# Build the shortest-path tree (SPT) covering every node reachable from 'A'.
algo.run()
# x: edges along the path to 'G'; z: combined weight of those edges.
x, z = algo.path_to(target='G')
print(x, z, sep='\n')
algo.visualize(target='G', source='A')
Output (the traversed edges and the combined weight of those edges).
[('A', 'C'), ('C', 'G')]
2
The visualization is rough-looking for undirected graphs, but it still gives you the gist of the solution.
