Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
eclipse
GitHub Repository: eclipse/sumo
Path: blob/main/tools/net/patchRailConflicts.py
193758 views
1
#!/usr/bin/env python
2
# Eclipse SUMO, Simulation of Urban MObility; see https://eclipse.dev/sumo
3
# Copyright (C) 2007-2026 German Aerospace Center (DLR) and others.
4
# This program and the accompanying materials are made available under the
5
# terms of the Eclipse Public License 2.0 which is available at
6
# https://www.eclipse.org/legal/epl-2.0/
7
# This Source Code may also be made available under the following Secondary
8
# Licenses when the conditions for such availability set forth in the Eclipse
9
# Public License 2.0 are satisfied: GNU General Public License, version 2
10
# or later which is available at
11
# https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
12
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
13
14
# @file patchRailConflicts.py
15
# @author Jakob Erdmann
16
# @date 2026-01-17
17
18
"""
19
Identifies all rail crossings (two one-directional tracks crossing each other)
20
and converts the junction type to the given value
21
"""
22
import os
23
import sys
24
import subprocess
25
from collections import defaultdict
26
SUMO_HOME = os.environ.get('SUMO_HOME',
27
os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..'))
28
sys.path.append(os.path.join(SUMO_HOME, 'tools'))
29
import sumolib # noqa
30
from sumolib.options import ArgumentParser # noqa
31
32
try:
    sys.stdout.reconfigure(encoding='utf-8')
except Exception:
    # best effort only: stdout may not support reconfigure() (e.g. when it has
    # been replaced by another stream object); never swallow KeyboardInterrupt
    pass

NETCONVERT = sumolib.checkBinary('netconvert')
38
39
40
def get_options(args=None):
    """Parse the command line and derive the names of the patch files.

    @param args: optional list of command line arguments (defaults to the
        arguments of the current process)
    @return: the parsed options, post-processed as follows:
        - vclass is wrapped in a single-element list
        - keepJunctionType is turned into a set of junction type names
        - output_nodes, output_edges and output_conns hold the names of the
          plain-xml patch files, derived from the output network name with a
          trailing ".net.xml" or ".net.xml.gz" removed
    """
    ap = ArgumentParser()
    ap.add_argument("-n", "--net-file", category="input", dest="netfile", required=True, type=ap.net_file,
                    help="the network to read lane and edge permissions")
    ap.add_argument("-o", "--output-file", category="output", dest="output", required=True, type=ap.file,
                    help="output network file")
    ap.add_argument("--vclass", default="tram",
                    help="the vehicle class that restricts which junctions are considered")
    ap.add_argument("-t", "--junction-type", dest="junctionType", default="rail_signal",
                    help="the new junction type for rail/rail crossings")
    ap.add_argument("-k", "--keep-junction-type", dest="keepJunctionType",
                    default="traffic_light,traffic_light_unregulated,rail_signal",
                    help="comma-separated list of junction types that are left unchanged")
    ap.add_argument("-e", "--end-offset", dest="endOffset", type=float, default=0,
                    help="move back the stop line from the crossing")
    ap.add_argument("-j", "--join-distance", dest="joinDist", type=float, default=200,
                    help="The distance for joining clusters which are guarded only from the outside")
    ap.add_argument("-v", "--verbose", action="store_true", default=False,
                    help="tell me what you are doing")
    options = ap.parse_args(args)

    # the patch files share the base name of the output network
    outputBase = options.output
    for suffix in (".net.xml.gz", ".net.xml"):
        if outputBase.endswith(suffix):
            outputBase = outputBase[:-len(suffix)]
            break

    options.vclass = [options.vclass]
    # tolerate blanks around the commas ("a, b") and ignore empty entries
    options.keepJunctionType = {t.strip() for t in options.keepJunctionType.split(',') if t.strip()}
    options.output_nodes = outputBase + ".nod.xml"
    options.output_edges = outputBase + ".edg.xml"
    options.output_conns = outputBase + ".con.xml"
    return options
73
74
75
def getDownstream(node, joinDist, nodes):
    """Return the members of *nodes* that lie downstream of *node*.

    A node counts as downstream when the shortest path along outgoing edges
    from *node* to it is strictly shorter than *joinDist*. The start node
    itself is never part of the result.

    Shortest distances are computed with Dijkstra's algorithm. A plain
    depth-first search that marks nodes on first visit can reach a node over a
    long path first and then refuse the shorter path, missing nodes that are
    actually within *joinDist*.

    @param node: start node (must provide getOutgoing())
    @param joinDist: exclusive upper bound for the accumulated edge length
    @param nodes: iterable of candidate nodes to report
    @return: set of nodes from *nodes* within *joinDist* downstream of *node*
    """
    import heapq

    targets = set(nodes)
    result = set()
    best = {node: 0}
    # the running counter breaks distance ties so node objects are never compared
    counter = 0
    heap = [(0, counter, node)]
    while heap:
        dist, _, n = heapq.heappop(heap)
        if dist > best[n]:
            # stale entry: a shorter path to n has already been processed
            continue
        if n != node and n in targets:
            result.add(n)
        for e in n.getOutgoing():
            toNode = e.getToNode()
            dist2 = dist + e.getLength()
            if dist2 < joinDist and dist2 < best.get(toNode, float('inf')):
                best[toNode] = dist2
                counter += 1
                heapq.heappush(heap, (dist2, counter, toNode))
    return result
89
90
91
def findClusters(joinDist, nodes):
    """Group the given crossing nodes into clusters of nearby crossings.

    Two nodes belong to the same cluster when one of them lies downstream of
    the other within *joinDist* (see getDownstream). The relation is applied
    transitively (connected components via union-find), so every node ends up
    in exactly one cluster. A node downstream of two different crossings joins
    both of them into a single cluster instead of appearing in two
    overlapping ones.

    @param joinDist: distance threshold passed to getDownstream
    @param nodes: the candidate crossing nodes
    @return: a list of (clusterNodes, incomingEdges) tuples. There is one
        tuple per cluster with more than one node, where incomingEdges are the
        edges entering the cluster from a node outside of it. A final tuple
        always follows; it collects every node without a cluster partner
        together with all of their incoming edges (it may be empty).
    """
    nodeSet = set(nodes)
    # dict keeps first-occurrence order, which makes the clustering deterministic
    parent = {n: n for n in nodes}

    def find(n):
        # union-find root lookup with path halving
        while parent[n] is not n:
            parent[n] = parent[parent[n]]
            n = parent[n]
        return n

    for n in parent:
        # getDownstream only reports members of nodeSet, so every d has a parent entry
        for d in getDownstream(n, joinDist, nodeSet):
            rootN = find(n)
            rootD = find(d)
            if rootN is not rootD:
                parent[rootD] = rootN

    members = {}
    for n in parent:
        members.setdefault(find(n), []).append(n)

    result = []
    unconnectedNodes = []
    for cluster in members.values():
        if len(cluster) > 1:
            inCluster = set(cluster)
            incoming = [e for n in cluster for e in n.getIncoming()
                        if e.getFromNode() not in inCluster]
            result.append((cluster, incoming))
        else:
            unconnectedNodes.extend(cluster)

    result.append((unconnectedNodes, [e for n in unconnectedNodes for e in n.getIncoming()]))
    return result
123
124
125
def _findCrossingNodes(net, options):
    """Collect the rail/rail crossings of *net* whose junction type should change.

    A node qualifies when all of the following hold:
    - every incoming and outgoing edge permits exactly the classes in
      options.vclass
    - the node has foes
    - it has at least two incoming and one outgoing edge
    - at least three edges remain once each bidirectional pair is counted
      only once

    Qualifying nodes whose current type is in options.keepJunctionType are
    not returned; they are counted instead.

    @return: tuple (crossingNodes, skipped), where skipped maps a junction
        type to the number of qualifying nodes kept because of that type
    """
    crossingNodes = []
    skipped = defaultdict(lambda: 0)
    vclasses = set(options.vclass)
    for node in net.getNodes():
        incoming = node.getIncoming()
        outgoing = node.getOutgoing()
        edges = incoming + outgoing
        # compare as sets so the check does not depend on the container type
        # returned by getPermissions()
        if any(set(e.getPermissions()) != vclasses for e in edges):
            continue
        if not node.hasFoes():
            continue
        nIn = len(incoming)
        nOut = len(outgoing)
        # presumably each bidi pair contributes one incoming and one outgoing
        # edge here, so halving yields the number of pairs
        nBidi = len([e for e in edges if e.getBidi() is not None]) / 2
        if nIn >= 2 and nOut >= 1 and (nIn + nOut - nBidi) >= 3:
            if node.getType() in options.keepJunctionType:
                skipped[node.getType()] += 1
                continue
            crossingNodes.append(node)
    return crossingNodes, skipped


def _writePatchFiles(options, clusters):
    """Write the node, edge and connection patch files for *clusters*.

    For every (nodes, incomingEdges) cluster:
    - incoming edges of cluster nodes that are not listed in incomingEdges get
      all their connections marked uncontrolled
    - every cluster node that still has another incoming edge gets the type
      options.junctionType
    - if options.endOffset > 0, incomingEdges receive that endOffset
    """
    conFormat = ' <connection from="%s" fromLane="%s" to="%s" toLane="%s" uncontrolled="1"/>\n'
    # utf8 so that non-ASCII ids can be written regardless of the locale
    # (the header written by sumolib presumably declares UTF-8 - confirm)
    with open(options.output_nodes, 'w', encoding='utf8') as outf_nod, \
            open(options.output_edges, 'w', encoding='utf8') as outf_edg, \
            open(options.output_conns, 'w', encoding='utf8') as outf_con:
        sumolib.writeXMLHeader(outf_nod, "$Id$", "nodes", options=options)
        sumolib.writeXMLHeader(outf_edg, "$Id$", "edges", schemaPath="edgediff_file.xsd", options=options)
        sumolib.writeXMLHeader(outf_con, "$Id$", "connections", options=options)

        for outerNodes, incomingEdges in clusters:
            allIncoming = set()
            for node in outerNodes:
                allIncoming.update(node.getIncoming())
            uncontrolled = allIncoming.difference(incomingEdges)

            controlledNodes = [n.getID() for n in outerNodes
                               if set(n.getIncoming()).difference(uncontrolled)]
            for nodeID in sorted(controlledNodes):
                outf_nod.write(' <node id="%s" type="%s"/>\n' % (nodeID, options.junctionType))
            outf_nod.write('\n')

            if options.endOffset > 0:
                for edgeID in sorted(e.getID() for e in incomingEdges):
                    outf_edg.write(' <edge id="%s" endOffset="%s"/>\n' % (edgeID, options.endOffset))
                outf_edg.write('\n')

            if uncontrolled:
                # sort by id so the output does not depend on set (hash) order
                for edge in sorted(uncontrolled, key=lambda e: e.getID()):
                    for cons in edge.getOutgoing().values():
                        for con in cons:
                            outf_con.write(conFormat % (
                                con.getFrom().getID(), con.getFromLane().getIndex(),
                                con.getTo().getID(), con.getToLane().getIndex()))
                outf_con.write('\n')

        outf_nod.write("</nodes>\n")
        outf_edg.write("</edges>\n")
        outf_con.write("</connections>\n")


def _runNetconvert(options):
    """Apply the patch files to the input network; return netconvert's exit code."""
    return subprocess.call([NETCONVERT,
                            '-s', options.netfile,
                            '-n', options.output_nodes,
                            '-e', options.output_edges,
                            '-x', options.output_conns,
                            '-o', options.output,
                            ], stdout=subprocess.DEVNULL)


def main(options):
    """Detect rail crossings, write the patch files and rebuild the network.

    @param options: the options returned by get_options()
    @return: the exit code of the netconvert call that builds options.output
    """
    net = sumolib.net.readNet(options.netfile)
    crossingNodes, skipped = _findCrossingNodes(net, options)
    clusters = findClusters(options.joinDist, crossingNodes)
    _writePatchFiles(options, clusters)

    if options.verbose:
        for junctionType, count in sorted(skipped.items()):
            print("Skipped %s crossings of type %s" % (count, junctionType))
        print("Building new net")
    # make sure our own output appears before anything netconvert prints
    sys.stdout.flush()
    sys.stderr.flush()
    return _runNetconvert(options)
202
203
204
if __name__ == "__main__":
205
main(get_options())
206
207