Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/utils/diagrams.py
4918 views
1
import shutil
2
from pathlib import Path
3
from typing import List
4
from xml.dom import minidom
5
import os
6
7
from labml import monit
8
9
# Absolute path of the working directory at import time; main() resolves the
# ``diagrams`` source folder and the ``docs`` output folder against it.
HOME = Path('.').absolute()

# Stylesheet that parse() injects as a <style> element into every processed SVG.
# The ``<name>-stroke`` / ``<name>-fill`` rules restyle the classes derived from
# STROKES / FILLS for a dark background -- note that ``.black-fill`` becomes a
# light colour and ``.white-fill`` a dark one. The ``switch foreignObject ...``
# rules force a light text colour on HTML content inside <foreignObject>, and
# ``.has-background`` overrides inline HTML backgrounds with #1d2127, the same
# colour parse() sets as the SVG background.
STYLES = """
.black-stroke {
stroke: #aaa;
}

rect.black-stroke {
stroke: #444;
}

.black-fill {
fill: #ddd;
}

.white-fill {
fill: #333;
}

.blue-stroke {
stroke: #5b8fab;
}

.blue-fill {
fill: #356782;
}

.yellow-stroke {
stroke: #bbab52;
}

.yellow-fill {
fill: #a7942b;
}

.grey-stroke {
stroke: #484d5a;
}

.grey-fill {
fill: #2e323c;
}

.red-stroke {
stroke: #bb3232;
}

.red-fill {
fill: #901c1c;
}

.orange-stroke {
stroke: #a5753f;
}

.orange-fill {
fill: #82531e;
}

.purple-stroke {
stroke: #a556a5;
}

.purple-fill {
fill: #8a308a;
}

.green-stroke {
stroke: #80cc92;
}

.green-fill {
fill: #499e5d;
}

switch foreignObject div div div {
color: #ddd !important;
}

switch foreignObject div div div span {
color: #ddd !important;
}

.has-background {
background-color: #1d2127 !important;
}
"""

# Stroke colour (lowercase hex) -> class-name prefix. add_stroke_classes()
# replaces a matching ``stroke`` attribute with the class ``<name>-stroke``;
# every name here has a corresponding rule in STYLES.
# NOTE(review): these look like the diagram editor's default palette (draw.io?)
# -- confirm before adding colours.
STROKES = {
    '#000000': 'black',
    '#6c8ebf': 'blue',
    '#d6b656': 'yellow',
    '#666666': 'grey',
    '#b85450': 'red',
    '#d79b00': 'orange',
    '#9673a6': 'purple',
    '#82b366': 'green',
}

# Fill colour (lowercase hex) -> class-name prefix. add_fill_classes() replaces
# a matching ``fill`` attribute with the class ``<name>-fill``; every name here
# has a corresponding rule in STYLES.
FILLS = {
    '#000000': 'black',
    '#ffffff': 'white',
    '#dae8fc': 'blue',
    '#fff2cc': 'yellow',
    '#f5f5f5': 'grey',
    '#f8cecc': 'red',
    '#ffe6cc': 'orange',
    '#e1d5e7': 'purple',
    '#d5e8d4': 'green',
}
119
120
121
def clear_switches(doc: minidom.Document):
    """Flatten every ``<switch>`` element in ``doc`` in place.

    Two layouts are recognised:

    * The first element child is a ``<g>`` carrying a ``requiredFeatures``
      attribute: the whole ``<switch>`` (and everything in it) is removed.
    * Otherwise the switch must hold exactly a ``<foreignObject>`` followed by
      a ``<text>``: the ``<text>`` is moved out to take the switch's place and
      the ``<foreignObject>`` is discarded together with the switch.

    Only element children are considered, so whitespace text nodes from a
    pretty-printed SVG do not break the structure check.

    Args:
        doc: Parsed SVG document; modified in place.

    Raises:
        ValueError: If a ``<switch>`` matches neither layout.
    """
    # minidom's getElementsByTagName returns a snapshot list, so detaching
    # switches while iterating over it is safe.
    for switch in doc.getElementsByTagName('switch'):
        parent = switch.parentNode
        if parent is None:
            # Already detached as part of an enclosing switch removed earlier.
            continue

        children = [c for c in switch.childNodes
                    if c.nodeType == minidom.Node.ELEMENT_NODE]
        if len(children) != 2:
            raise ValueError(
                f'<switch> must have exactly 2 element children, found {len(children)}')
        first, second = children

        if first.tagName == 'g' and first.hasAttribute('requiredFeatures'):
            parent.removeChild(switch)
            switch.unlink()
            continue

        if first.tagName != 'foreignObject' or second.tagName != 'text':
            raise ValueError(
                f'unexpected <switch> layout: <{first.tagName}>, <{second.tagName}>')

        # Hoist the <text> fallback out, then drop the switch (and its
        # <foreignObject>) entirely.
        switch.removeChild(second)
        parent.insertBefore(second, switch)
        parent.removeChild(switch)
        switch.unlink()
136
137
138
def add_class(node: minidom.Node, class_name: str):
    """Append ``class_name`` to the node's space-separated ``class`` attribute.

    When the node has no ``class`` attribute yet, one is created holding just
    ``class_name``.

    Args:
        node: Element to tag; modified in place.
        class_name: Class to add.
    """
    attrs = node.attributes
    if 'class' in attrs:
        attrs['class'] = f"{attrs['class'].value} {class_name}"
    else:
        attrs['class'] = class_name
144
145
146
def add_bg_classes(nodes: List[minidom.Node]):
    """Tag elements whose inline ``style`` sets a background colour.

    Each element whose ``style`` attribute mentions ``background-color`` gets
    the ``has-background`` class, which the injected stylesheet uses to
    override that colour with the dark page background.

    Args:
        nodes: Elements to inspect; modified in place.
    """
    for node in nodes:
        # getAttribute returns '' when the attribute is absent. CSS property
        # names are ASCII case-insensitive, so normalise before matching.
        if 'background-color' in node.getAttribute('style').lower():
            add_class(node, 'has-background')
152
153
154
def _color_attr_to_class(nodes: List[minidom.Node], attr: str, palette: dict, suffix: str):
    """Replace a known colour in attribute ``attr`` with a ``<name>-<suffix>`` class.

    Args:
        nodes: Elements to inspect; modified in place.
        attr: Presentation attribute holding the colour (``stroke``/``fill``).
        palette: Mapping of lowercase hex colour -> class-name prefix.
        suffix: Suffix appended to the prefix to form the class name.
    """
    for node in nodes:
        if not node.hasAttribute(attr):
            continue
        # Hex colours are case-insensitive while the palette keys are
        # lowercase, so normalise before looking the value up.
        name = palette.get(node.getAttribute(attr).strip().lower())
        if name is None:
            # Unknown colour or keyword (e.g. 'none'): keep it inline.
            continue
        node.removeAttribute(attr)
        add_class(node, f'{name}-{suffix}')


def add_stroke_classes(nodes: List[minidom.Node]):
    """Move ``stroke`` colours listed in ``STROKES`` into ``<name>-stroke`` classes.

    Elements with an unlisted or missing ``stroke`` are left unchanged.

    Args:
        nodes: Elements to inspect; modified in place.
    """
    _color_attr_to_class(nodes, 'stroke', STROKES, 'stroke')


def add_fill_classes(nodes: List[minidom.Node]):
    """Move ``fill`` colours listed in ``FILLS`` into ``<name>-fill`` classes.

    Elements with an unlisted or missing ``fill`` are left unchanged.

    Args:
        nodes: Elements to inspect; modified in place.
    """
    _color_attr_to_class(nodes, 'fill', FILLS, 'fill')
174
175
176
def add_classes(doc: minidom.Document):
    """Swap inline colours throughout ``doc`` for stylesheet classes.

    ``path``, ``rect`` and ``ellipse`` elements get stroke and then fill
    classes, ``text`` elements get fill classes only, and ``div``/``span``
    elements get ``has-background`` when they set a background colour.

    Args:
        doc: Parsed SVG document; modified in place.
    """
    shape_handlers = (add_stroke_classes, add_fill_classes)
    # (tag name, handlers applied in order) -- order determines class order.
    plan = (
        ('path', shape_handlers),
        ('rect', shape_handlers),
        ('ellipse', shape_handlers),
        ('text', (add_fill_classes,)),
        ('div', (add_bg_classes,)),
        ('span', (add_bg_classes,)),
    )
    for tag_name, handlers in plan:
        elements = doc.getElementsByTagName(tag_name)
        for handler in handlers:
            handler(elements)
197
198
199
def _pad_view_box(view_box: str, padding: float) -> str:
    """Grow an SVG ``viewBox`` value by ``padding`` user units on every side.

    ``view_box`` holds "min-x min-y width height"; the numbers may be
    separated by whitespace and/or commas.
    """
    min_x, min_y, width, height = (float(v) for v in view_box.replace(',', ' ').split())
    padded = (min_x - padding, min_y - padding, width + 2 * padding, height + 2 * padding)
    return ' '.join(str(v) for v in padded)


def parse(source: Path, dest: Path, padding: float = 10):
    """Restyle one SVG diagram for a dark background and write it to ``dest``.

    The ``viewBox`` is grown by ``padding`` on every side, a dark background and
    the ``STYLES`` sheet are injected, and known colours are replaced with CSS
    classes via ``add_classes``.

    Args:
        source: Path of the input SVG.
        dest: Path the rewritten SVG is written to (overwritten if present).
        padding: Margin added around the drawing, in viewBox user units.
            Defaults to 10, the previously hard-coded value.

    Raises:
        ValueError: If the file does not contain exactly one ``<svg>`` element.
        KeyError: If the ``<svg>`` element has no ``viewBox`` attribute.
    """
    doc: minidom.Document = minidom.parse(str(source))

    svgs = doc.getElementsByTagName('svg')
    if len(svgs) != 1:
        raise ValueError(f'{source}: expected exactly one <svg> element, found {len(svgs)}')
    svg = svgs[0]

    # NOTE(review): ``content`` presumably carries editor metadata that the
    # rendered docs do not need -- confirm.
    if svg.hasAttribute('content'):
        svg.removeAttribute('content')

    svg.setAttribute('viewBox', _pad_view_box(svg.attributes['viewBox'].value, padding))
    # Same colour as the ``.has-background`` rule in STYLES.
    svg.setAttribute('style', 'background: #1d2127;')

    # NOTE: clear_switches(doc) is intentionally not applied here, so
    # <switch> elements are kept as they are.

    style = doc.createElement('style')
    style.appendChild(doc.createTextNode(STYLES))
    # firstChild is None for an empty <svg>; insertBefore then appends.
    svg.insertBefore(style, svg.firstChild)
    add_classes(doc)

    # writexml emits an XML declaration without an encoding, which means
    # UTF-8; write UTF-8 explicitly instead of the platform locale encoding.
    with open(dest, 'w', encoding='utf-8') as f:
        doc.writexml(f)
231
232
233
def recurse(path: Path):
    """Return every file found at or below ``path``.

    A file yields a one-element list. A directory is walked depth-first,
    visiting entries in ``Path.iterdir()`` order; empty directories contribute
    nothing.

    Args:
        path: File or directory to collect from.

    Returns:
        A list of ``Path`` objects, one per file.
    """
    if path.is_file():
        return [path]
    # Flatten the per-child results, preserving pre-order traversal.
    return [found for child in path.iterdir() for found in recurse(child)]
243
244
245
def main():
    """Render every file under ``HOME/diagrams`` into ``HOME/docs``.

    Files with an ``.svg`` suffix (any case) are restyled with ``parse``;
    every other file is copied unchanged. The relative directory layout is
    mirrored, creating output directories as needed.
    """
    diagrams_path = HOME / 'diagrams'
    docs_path = HOME / 'docs'

    # For first invocation: make sure there is a source folder to walk.
    os.makedirs(diagrams_path, exist_ok=True)

    for source_path in recurse(diagrams_path):
        rel_path = source_path.relative_to(diagrams_path)
        dest_path = docs_path / rel_path
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        dest_path.parent.mkdir(parents=True, exist_ok=True)

        with monit.section(str(rel_path)):
            if source_path.suffix.lower() == '.svg':
                parse(source_path, dest_path)
            else:
                shutil.copy(source_path, dest_path)


if __name__ == '__main__':
    main()
268
269