Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
signalapp
GitHub Repository: signalapp/Signal-iOS
Path: blob/main/Scripts/sds_codegen/sds_generate.py
1 views
1
#!/usr/bin/env python3
2
3
import os
4
import subprocess
5
import argparse
6
import re
7
import json
8
import sds_common
9
from sds_common import fail
10
import random
11
12
# TODO: We should probably generate a class that knows how to set up
13
# the database. It would:
14
#
15
# * Create all tables (or apply database schema).
16
# * Register renamed classes.
17
# [NSKeyedUnarchiver setClass:[OWSUserProfile class] forClassName:[OWSUserProfile collection]];
18
# [NSKeyedUnarchiver setClass:[OWSDatabaseMigration class] forClassName:[OWSDatabaseMigration collection]];
19
20
# Any subclass of TSYapDatabaseObject (or BaseModel) is treated as a
# "serializable model".
#
# Direct subclasses of those base model classes are the "roots" of the model
# class hierarchy; only root models do deserialization.
OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME = (
    "TSYapDatabaseObject",
    "BaseModel",
)

# Delimits the generated region inside Obj-C source files.
CODE_GEN_SNIPPET_MARKER_OBJC = "// --- CODE GENERATION MARKER"

# GRDB seems to encode non-primitives using JSON, and then chokes when
# decoding that JSON because it is a JSON "fragment". Either this is a bug in
# GRDB or we're using GRDB incorrectly. Until that is resolved we
# encode/decode non-primitives ourselves, so Codable is disabled for both.
USE_CODABLE_FOR_PRIMITIVES = USE_CODABLE_FOR_NONPRIMITIVES = False
36
37
38
def update_generated_snippet(file_path, marker, snippet):
    """Replace everything between the first and last `marker` in a file.

    The span from the first occurrence of `marker` through the last one is
    rewritten as: marker, blank line, `snippet`, blank line, marker. Text
    before the first marker is stripped of surrounding whitespace and text
    after the last marker is stripped of leading whitespace; each is joined
    to the markers with a blank line. The result is handed to
    sds_common.write_text_file_if_changed.

    Calls fail() if the file does not exist or does not contain two
    distinct occurrences of `marker`.
    """
    if not os.path.exists(file_path):
        fail("Missing file:", file_path)

    with open(file_path, "rt") as source_file:
        original_text = source_file.read()

    first = original_text.find(marker)
    last = original_text.rfind(marker)
    # first == last means the marker occurs only once.
    if first < 0 or last < 0 or first >= last:
        fail(f"Could not find markers ('{marker}'): {file_path}")

    before = original_text[:first].strip()
    after = original_text[last + len(marker) :].lstrip()
    updated_text = "\n\n".join((before, marker, snippet, marker, after))

    sds_common.write_text_file_if_changed(file_path, updated_text)
64
65
66
def update_objc_snippet(file_path, snippet):
    """Splice a generated Obj-C snippet between the codegen markers of a file.

    The snippet is cleaned with sds_common.clean_up_generated_objc and
    stripped. If nothing remains, the file is left untouched. Otherwise a
    "do not edit" header naming this script is prepended, and the result is
    written between CODE_GEN_SNIPPET_MARKER_OBJC markers.
    """
    cleaned = sds_common.clean_up_generated_objc(snippet).strip()
    if not cleaned:
        return

    header = (
        "// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    update_generated_snippet(
        file_path, CODE_GEN_SNIPPET_MARKER_OBJC, f"{header}\n\n{cleaned}"
    )
80
81
82
# ----

# Mutable module-level state shared by the functions in this file.
# global_class_map maps a class name to its ParsedClass.
# NOTE(review): global_subclass_map and global_args are populated outside
# this excerpt; global_args is presumably the parsed argparse namespace.
global_class_map, global_subclass_map = {}, {}
global_args = None

# ----
89
90
91
def to_swift_identifier_name(identifier_name):
    """Lower-case the first character of `identifier_name` ("SortId" -> "sortId").

    The remaining characters are unchanged. An empty string raises IndexError.
    """
    first_char, remainder = identifier_name[0], identifier_name[1:]
    return first_char.lower() + remainder
93
94
95
class ParsedClass:
    """A model class as described by the parsed-class JSON.

    Properties rejected by should_ignore_property() are dropped at
    construction time; the rest are kept in `property_map`, keyed by name.
    """

    # Migrations are not persisted in the data store, so neither these
    # classes nor their direct subclasses get generated extensions.
    _MIGRATION_CLASS_NAMES = (
        "OWSDatabaseMigration",
        "YDBDatabaseMigration",
        "OWSResaveCollectionDBMigration",
    )

    def __init__(self, json_dict):
        self.name = json_dict.get("name")
        self.super_class_name = json_dict.get("super_class_name")
        self.filepath = sds_common.sds_from_relative_path(json_dict.get("filepath"))
        self.finalize_method_name = json_dict.get("finalize_method_name")
        self.property_map = {}
        for property_dict in json_dict.get("properties"):
            parsed_property = ParsedProperty(property_dict)
            parsed_property.class_name = self.name

            # TODO: We should handle all properties?
            if parsed_property.should_ignore_property():
                continue

            self.property_map[parsed_property.name] = parsed_property

    def properties(self):
        """Return this class's own (non-ignored) properties, sorted by name."""
        return [self.property_map[name] for name in sorted(self.property_map)]

    def database_subclass_properties(self):
        """Return the properties declared by descendant classes, sorted by name.

        Descendants rejected by should_ignore_class() are skipped, and
        properties this class already declares are not included.

        More than one subclass may declare a property with the same name. This
        is fine as long as the Swift types match. If only the optionality
        differs, the optional declaration wins, so the database schema treats
        the property as optional. That is an error only when the name is also
        declared by this class itself.

        Calls fail() on a conflicting declaration.
        """
        own_properties = self.properties()
        root_property_names = {prop.name for prop in own_properties}
        all_property_map = {prop.name: prop for prop in own_properties}
        subclass_property_map = {}

        for subclass in all_descendents_of_class(self):
            if should_ignore_class(subclass):
                continue

            for prop in subclass.properties():
                duplicate_property = all_property_map.get(prop.name)
                if duplicate_property is not None:
                    if prop.swift_type_safe() != duplicate_property.swift_type_safe():
                        self._fail_on_property_mismatch(prop, duplicate_property)
                    elif prop.is_optional != duplicate_property.is_optional:
                        if prop.name in root_property_names:
                            self._fail_on_property_mismatch(prop, duplicate_property)

                        # If one subclass property is optional and the other
                        # isn't, treat both as optional for the purposes of the
                        # database schema: only an optional declaration may
                        # replace the one already recorded.
                        if not prop.is_optional:
                            continue
                    else:
                        # Identical declaration; nothing new to record.
                        continue

                all_property_map[prop.name] = prop
                subclass_property_map[prop.name] = prop

        return [subclass_property_map[name] for name in sorted(subclass_property_map)]

    @staticmethod
    def _fail_on_property_mismatch(prop, duplicate_property):
        """Print both conflicting property declarations, then fail()."""
        for label, declaration in (
            ("property:", prop),
            ("duplicate_property:", duplicate_property),
        ):
            print(
                label,
                declaration.class_name,
                declaration.name,
                declaration.swift_type_safe(),
                declaration.is_optional,
            )
        fail("Duplicate property doesn't match:", prop.name)

    def record_id_source(self):
        """Return "sortId" if this class declares a property with that name, else None."""
        for prop in self.properties():
            if prop.name == "sortId":
                return prop.name
        return None

    def is_sds_model(self):
        """Return True if this class descends from TSYapDatabaseObject or BaseModel.

        The superclass chain is followed through global_class_map. Every
        superclass up to the base model class must be present in the map.
        """
        if self.super_class_name is None:
            return False
        if self.super_class_name not in global_class_map:
            return False
        if self.super_class_name in (
            OLD_BASE_MODEL_CLASS_NAME,
            NEW_BASE_MODEL_CLASS_NAME,
        ):
            return True
        return global_class_map[self.super_class_name].is_sds_model()

    def has_sds_superclass(self):
        """Return True if the superclass is in global_class_map and is not a base model class."""
        return bool(
            self.super_class_name
            and self.super_class_name in global_class_map
            and self.super_class_name
            not in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME)
        )

    def table_superclass(self):
        """Return the root of this class's model hierarchy (possibly self).

        That is the nearest class in the chain whose superclass is missing,
        not in global_class_map, or a base model class.
        """
        super_class_name = self.super_class_name
        if (
            super_class_name is None
            or super_class_name not in global_class_map
            or super_class_name
            in (OLD_BASE_MODEL_CLASS_NAME, NEW_BASE_MODEL_CLASS_NAME)
        ):
            return self
        return global_class_map[super_class_name].table_superclass()

    def all_superclass_names(self):
        """Return [self.name] followed by every ancestor name found in global_class_map."""
        result = [self.name]
        if self.super_class_name is not None:
            if self.super_class_name in global_class_map:
                super_class = global_class_map[self.super_class_name]
                result += super_class.all_superclass_names()
        return result

    def has_any_superclass_with_name(self, name):
        """Return True if `name` is this class's name or one of its known ancestors' names."""
        return name in self.all_superclass_names()

    def should_generate_extensions(self):
        """Return True if serialization extensions should be generated for this class."""
        if self.name in (
            OLD_BASE_MODEL_CLASS_NAME,
            NEW_BASE_MODEL_CLASS_NAME,
        ):
            return False
        if should_ignore_class(self):
            return False

        if not self.is_sds_model():
            # Only write serialization extensions for SDS models.
            return False

        if (
            self.name in self._MIGRATION_CLASS_NAMES
            or self.super_class_name in self._MIGRATION_CLASS_NAMES
        ):
            return False

        return True

    def record_name(self):
        """Return the record type name: the prefix-stripped class name plus "Record"."""
        return remove_prefix_from_class_name(self.name) + "Record"

    def sorted_record_properties(self):
        """Return the properties that become record columns, in stable column order.

        The list contains this class's own properties followed by
        database_subclass_properties(), minus properties with an aliased column
        name. Two attributes are set on each returned property:

        * force_optional: True for enum properties, so deserialization can
          survive unexpected raw values, and for subclass properties, since they
          don't apply to the base model or to other subclasses.
        * property_order: the persisted index from property_order_for_property.
          Properties without one are assigned the next free indices, in name
          order, and recorded with set_property_order_for_property. This keeps
          the column order stable across migrations, e.g. when columns are added.
        """
        record_name = self.record_name()

        # A property with an aliased column name reuses another column, so it
        # doesn't get a column of its own.
        base_properties = [
            prop for prop in self.properties() if not prop.has_aliased_column_name()
        ]
        subclass_properties = [
            prop
            for prop in self.database_subclass_properties()
            if not prop.has_aliased_column_name()
        ]

        record_properties = []
        for prop in base_properties:
            # Treat all enum properties as forced-optional, so that during
            # deserialization we can survive unexpected raw values.
            prop.force_optional = prop.type_info().is_enum
            record_properties.append(prop)
        for prop in subclass_properties:
            # Subclass properties must be optional: they don't apply to the base
            # model or to other subclasses.
            prop.force_optional = True
            record_properties.append(prop)

        # Load the known order of each property. This is None for new
        # properties that have not been assigned an order yet.
        for prop in record_properties:
            prop.property_order = property_order_for_property(prop, record_name)

        known_orders = [
            prop.property_order
            for prop in record_properties
            if prop.property_order is not None
        ]
        next_property_order = 1 + max(known_orders, default=0)

        # Pre-sort by name so that properties added together get stable,
        # name-ordered indices.
        record_properties.sort(key=lambda value: value.name)
        for prop in record_properties:
            if prop.property_order is None:
                prop.property_order = next_property_order
                # Persist the new order (the mapping is saved as JSON) so it
                # stays the same on later runs.
                set_property_order_for_property(prop, record_name, next_property_order)
                next_property_order += 1

        # Finally apply the ordering.
        record_properties.sort(key=lambda value: value.property_order)
        return record_properties
350
351
352
class TypeInfo:
    """Describes how a property type maps to Swift, Obj-C and a database column.

    It also generates the Swift snippets that serialize a model value into a
    record field and deserialize it back out of a record.
    """

    # Swift type -> NSNumber accessor used when archiving an `NSNumber *` value.
    _NSNUMBER_CONVERSION_METHODS = {
        "Int8": "int8Value",
        "UInt8": "uint8Value",
        "Int16": "int16Value",
        "UInt16": "uint16Value",
        "Int32": "int32Value",
        "UInt32": "uint32Value",
        "Int64": "int64Value",
        "UInt64": "uint64Value",
        "Float": "floatValue",
        "Double": "doubleValue",
        "Bool": "boolValue",
        "Int": "intValue",
        "UInt": "uintValue",
    }

    def __init__(
        self,
        swift_type,
        objc_type,
        should_use_blob=False,
        is_codable=False,
        is_enum=False,
        field_override_column_type=None,
        field_override_record_swift_type=None,
    ):
        self._swift_type = swift_type
        self._objc_type = objc_type
        self.should_use_blob = should_use_blob
        self.is_codable = is_codable
        self.is_enum = is_enum
        self.field_override_column_type = field_override_column_type
        self.field_override_record_swift_type = field_override_record_swift_type

    def swift_type(self):
        """Return the Swift type name as a string."""
        return str(self._swift_type)

    def objc_type(self):
        """Return the Obj-C type name as a string."""
        return str(self._objc_type)

    def database_column_type(self, value_name):
        """Return the database column type expression (e.g. ".int64") for this type.

        Sub-models and collections (e.g. [String]) flagged with should_use_blob
        are stored as blobs. NSDates are stored as doubles
        (timeIntervalSince1970). `value_name` is unused.

        Calls fail() for a `Date` swift type and for unknown types.
        """
        if self.field_override_column_type is not None:
            return self.field_override_column_type
        elif self.should_use_blob or self.is_codable:
            return ".blob"
        elif self.is_enum:
            return ".int"
        elif self._swift_type == "String":
            return ".unicodeString"
        elif self._objc_type == "NSDate *":
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            return ".double"
        elif self._swift_type == "Date":
            fail(
                'We should not use `Date` as a "swift type" since all NSDates are serialized as doubles.',
                self._swift_type,
            )
        elif self._swift_type == "Data":
            return ".blob"
        elif self._swift_type == "Bool":
            return ".int"
        elif self._swift_type in ("Double", "Float"):
            return ".double"
        elif self.is_numeric():
            return ".int64"
        else:
            fail("Unknown type(1):", self._swift_type)

    def is_numeric(self):
        """Return True if the Swift type is one of the supported numeric types (Bool included)."""
        # TODO: We need to revisit how we serialize numeric types.
        return self._swift_type in (
            "Bool",
            "UInt64",
            "UInt",
            "Int64",
            "Int",
            "Int32",
            "UInt32",
            "Double",
            "Float",
        )

    def should_cast_to_swift(self):
        """Return True for numeric types other than Bool, Int64 and UInt64."""
        if self._swift_type in (
            "Bool",
            "Int64",
            "UInt64",
        ):
            return False
        return self.is_numeric()

    def deserialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift lines that load local `value_name` from `record`.

        Args:
            property: the ParsedProperty; only its column_name() is used.
            value_name: name of the Swift local to declare. It is also used as
                the field name in generated error messages.
            is_optional: whether the model property is optional.
            did_force_optional: whether the record column was forced optional.

        Returns:
            A list of Swift source lines.
        """
        value_expr = "record.%s" % (property.column_name(),)

        deserialization_optional = None
        deserialization_not_optional = None
        deserialization_conversion = ""
        if self._swift_type == "String":
            deserialization_not_optional = "required"
        elif self._objc_type == "NSDate *":
            # Dates are unpacked below through an intermediate *Interval local.
            pass
        elif self._swift_type == "Date":
            fail("Unknown type(0):", self._swift_type)
        elif self.is_codable:
            deserialization_not_optional = "required"
        elif self._swift_type == "Data":
            deserialization_optional = "optionalData"
            deserialization_not_optional = "required"
        elif self.is_numeric():
            deserialization_optional = "optionalNumericAsNSNumber"
            deserialization_not_optional = "required"
            deserialization_conversion = ", conversion: { NSNumber(value: $0) }"

        initializer_param_type = self.swift_type()
        if is_optional:
            initializer_param_type = initializer_param_type + "?"

        # Special-case the unpacking of the auto-incremented primary key.
        if value_expr == "record.id":
            value_expr = "%s(recordId)" % (initializer_param_type,)
        elif is_optional:
            if deserialization_optional is not None:
                value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % (
                    deserialization_optional,
                    value_expr,
                    value_name,
                    deserialization_conversion,
                )
        elif did_force_optional:
            if deserialization_not_optional is not None:
                value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % (
                    deserialization_not_optional,
                    value_expr,
                    value_name,
                )
        # Otherwise the column is non-optional and needs no unpacking.

        if self.is_codable:
            # Previously this statement was built but never returned, so the
            # method returned None for codable types.
            return [
                "let %s: %s = %s"
                % (value_name, initializer_param_type, value_expr),
            ]
        elif self.should_use_blob:
            blob_name = "%sSerialized" % (str(value_name),)
            if is_optional:
                serialized_statement = "let %s: Data? = %s" % (blob_name, value_expr)
            elif did_force_optional:
                serialized_statement = f'let {blob_name}: Data = try {value_expr} ?? {{ () -> Data in throw SDSError.missingRequiredField(fieldName: "{value_name}") }}()'
            else:
                serialized_statement = "let %s: Data = %s" % (blob_name, value_expr)

            # An optional blob is decoded inside `.map({ ... })`, where the
            # unwrapped blob is `$0`.
            from_name = "$0" if is_optional else blob_name
            decode_statement = self._blob_decode_statement(from_name)
            if is_optional:
                value_statement = "let %s: %s? = try %s.map({ %s })" % (
                    value_name,
                    self._swift_type,
                    blob_name,
                    decode_statement,
                )
            else:
                value_statement = "let %s: %s = %s" % (
                    value_name,
                    self._swift_type,
                    decode_statement,
                )
            return [
                serialized_statement,
                value_statement,
            ]
        elif self.is_enum and did_force_optional and not is_optional:
            return [
                "guard let %s: %s = %s else {"
                % (value_name, initializer_param_type, value_expr),
                " throw SDSError.missingRequiredField()",
                "}",
            ]
        elif is_optional and self._objc_type == "NSNumber *":
            return [
                "let %s: %s = %s" % (value_name, "NSNumber?", value_expr),
            ]
        elif self._objc_type == "NSDate *":
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            interval_name = "%sInterval" % (str(value_name),)
            if did_force_optional:
                serialized_statements = [
                    "guard let %s: Double = %s else {" % (interval_name, value_expr),
                    " throw SDSError.missingRequiredField()",
                    "}",
                ]
            elif is_optional:
                serialized_statements = [
                    "let %s: Double? = %s" % (interval_name, value_expr),
                ]
            else:
                serialized_statements = [
                    "let %s: Double = %s" % (interval_name, value_expr),
                ]
            if is_optional:
                value_statement = (
                    'let %s: Date? = SDSDeserialization.optionalDoubleAsDate(%s, name: "%s")'
                    % (value_name, interval_name, value_name)
                )
            else:
                value_statement = (
                    'let %s: Date = SDSDeserialization.requiredDoubleAsDate(%s, name: "%s")'
                    % (value_name, interval_name, value_name)
                )
            return serialized_statements + [
                value_statement,
            ]
        else:
            return [
                "let %s: %s = %s" % (value_name, initializer_param_type, value_expr),
            ]

    def _blob_decode_statement(self, from_name):
        """Return the Swift expression that unarchives the blob named `from_name`."""
        swift_type = self._swift_type
        if swift_type == "[InfoMessageUserInfoKey: AnyObject]":
            return "try SDSDeserialization.unarchivedInfoDictionary(from: %s)" % (
                from_name,
            )
        if ": " in swift_type:
            # Dictionary type, e.g. "[Key: Value]".
            assert swift_type.startswith("[")
            assert swift_type.endswith("]")
            divider_index = swift_type.index(": ")
            key_type = swift_type[1:divider_index]
            value_type = swift_type[divider_index + 2 : -1]
            return (
                "try SDSDeserialization.unarchivedDictionary(ofKeyClass: %s.self, objectClass: %s.self, from: %s)"
                % (key_type, value_type, from_name)
            )
        if swift_type.startswith("["):
            # Array type, e.g. "[Element]". Swift value types are unarchived via
            # their Obj-C class and then bridged back with `as`.
            assert swift_type.endswith("]")
            array_type = swift_type[1:-1]
            objc_types = {
                "String": "NSString",
            }
            decode_statement = (
                "try SDSDeserialization.unarchivedArrayOfObjects(ofClass: %s.self, from: %s)"
                % (objc_types.get(array_type, array_type), from_name)
            )
            if array_type in objc_types:
                decode_statement += " as " + swift_type
            return decode_statement
        return "try SDSDeserialization.unarchivedObject(ofClass: %s.self, from: %s)" % (
            swift_type,
            from_name,
        )

    def serialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift expression that converts `value_name` into its record field value.

        A manually configured serialize_record_invocation override (a format
        string with one %s) takes precedence. Calls fail() if an `NSNumber *`
        property has a Swift type with no known NSNumber accessor.
        """
        value_expr = value_name

        override_invocation = property.field_override_serialize_record_invocation()
        if override_invocation is not None:
            return override_invocation % (value_expr,)
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            if is_optional or did_force_optional:
                return "optionalArchive(%s)" % (value_expr,)
            else:
                return "requiredArchive(%s)" % (value_expr,)
        elif self._objc_type == "NSDate *":
            if is_optional or did_force_optional:
                return "archiveOptionalDate(%s)" % (value_expr,)
            else:
                return "archiveDate(%s)" % (value_expr,)
        elif self._objc_type == "NSNumber *":
            # .get() so that an unsupported type reaches the fail() below;
            # indexing would raise a bare KeyError first.
            conversion_method = self._NSNUMBER_CONVERSION_METHODS.get(self.swift_type())
            if conversion_method is None:
                fail("Could not convert:", self.swift_type())
            serialization_conversion = "{ $0.%s }" % (conversion_method,)
            if is_optional or did_force_optional:
                return "archiveOptionalNSNumber(%s, conversion: %s)" % (
                    value_expr,
                    serialization_conversion,
                )
            else:
                return "archiveNSNumber(%s, conversion: %s)" % (
                    value_expr,
                    serialization_conversion,
                )

        return value_expr

    def record_field_type(self, value_name):
        """Return the Swift type of the record field; blob-backed types use Data.

        A manually configured record_swift_type override wins. `value_name` is
        unused.
        """
        # Special case this oddball type.
        if self.field_override_record_swift_type is not None:
            return self.field_override_record_swift_type
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            return "Data"
        return self.swift_type()
727
728
729
class ParsedProperty:
730
def __init__(self, json_dict):
731
self.name = json_dict.get("name")
732
self.is_optional = json_dict.get("is_optional")
733
self.objc_type = json_dict.get("objc_type")
734
self.class_name = json_dict.get("class_name")
735
self.swift_type = None
736
737
def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True):
738
if objc_type is None:
739
fail("Missing type")
740
elif objc_type == "NSString *":
741
return "String"
742
elif objc_type == "NSDate *":
743
# Persist dates as NSTimeInterval timeIntervalSince1970.
744
return "Double"
745
elif objc_type == "NSData *":
746
return "Data"
747
elif objc_type == "BOOL":
748
return "Bool"
749
elif objc_type == "NSInteger":
750
return "Int"
751
elif objc_type == "NSUInteger":
752
return "UInt"
753
elif objc_type == "int32_t":
754
return "Int32"
755
elif objc_type == "uint32_t":
756
return "UInt32"
757
elif objc_type == "int64_t":
758
return "Int64"
759
elif objc_type == "long long":
760
return "Int64"
761
elif objc_type == "unsigned long long":
762
return "UInt64"
763
elif objc_type == "uint64_t":
764
return "UInt64"
765
elif objc_type == "unsigned long":
766
return "UInt64"
767
elif objc_type == "unsigned int":
768
return "UInt32"
769
elif objc_type == "double":
770
return "Double"
771
elif objc_type == "float":
772
return "Float"
773
elif objc_type == "CGFloat":
774
return "Double"
775
elif objc_type == "NSNumber *":
776
if unpack_nsnumber:
777
return swift_type_for_nsnumber(self)
778
else:
779
return "NSNumber"
780
else:
781
return None
782
783
# NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
784
def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True):
785
if objc_type == "id":
786
return "AnyObject"
787
elif not objc_type.endswith(" *"):
788
return None
789
790
swift_primitive = self.try_to_convert_objc_primitive_to_swift(
791
objc_type, unpack_nsnumber=unpack_nsnumber
792
)
793
if swift_primitive is not None:
794
return swift_primitive
795
796
array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
797
if array_match is not None:
798
split = array_match.group(2)
799
return (
800
"["
801
+ self.convert_objc_class_to_swift(split, unpack_nsnumber=False)
802
+ "]"
803
)
804
805
dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type)
806
if dict_match is not None:
807
split1 = dict_match.group(2).strip()
808
split2 = dict_match.group(3).strip()
809
return (
810
"["
811
+ self.convert_objc_class_to_swift(split1, unpack_nsnumber=False)
812
+ ": "
813
+ self.convert_objc_class_to_swift(split2, unpack_nsnumber=False)
814
+ "]"
815
)
816
817
ordered_set_match = re.search(r"^NSOrderedSet<(.+)> \*$", objc_type)
818
if ordered_set_match is not None:
819
# swift has no primitive for ordered set, so we lose the element type
820
return "NSOrderedSet"
821
822
swift_type = objc_type[: -len(" *")]
823
824
if "<" in swift_type or "{" in swift_type or "*" in swift_type:
825
fail("Unexpected type:", objc_type)
826
return swift_type
827
828
def try_to_convert_objc_type_to_type_info(self):
829
objc_type = self.objc_type
830
831
if objc_type is None:
832
fail("Missing type")
833
834
elif self.field_override_swift_type():
835
return TypeInfo(
836
self.field_override_swift_type(),
837
objc_type,
838
should_use_blob=self.field_override_should_use_blob(),
839
is_enum=self.field_override_is_enum(),
840
field_override_column_type=self.field_override_column_type(),
841
field_override_record_swift_type=self.field_override_record_swift_type(),
842
)
843
elif objc_type in enum_type_map:
844
enum_type = objc_type
845
return TypeInfo(enum_type, objc_type, is_enum=True)
846
elif objc_type.startswith("enum "):
847
enum_type = objc_type[len("enum ") :]
848
return TypeInfo(enum_type, objc_type, is_enum=True)
849
850
swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type)
851
if swift_primitive is not None:
852
return TypeInfo(swift_primitive, objc_type)
853
854
if objc_type in (
855
"struct CGSize",
856
"struct CGRect",
857
"struct CGPoint",
858
):
859
objc_type = objc_type[len("struct ") :]
860
swift_type = objc_type
861
return TypeInfo(
862
swift_type,
863
objc_type,
864
should_use_blob=True,
865
is_codable=USE_CODABLE_FOR_PRIMITIVES,
866
)
867
868
swift_type = self.convert_objc_class_to_swift(self.objc_type)
869
if swift_type is not None:
870
if self.is_objc_type_codable(objc_type):
871
return TypeInfo(
872
swift_type, objc_type, should_use_blob=True, is_codable=False
873
)
874
return TypeInfo(
875
swift_type, objc_type, should_use_blob=True, is_codable=False
876
)
877
878
fail("Unknown type(3):", self.class_name, self.objc_type, self.name)
879
880
# NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
881
def is_objc_type_codable(self, objc_type):
882
if not USE_CODABLE_FOR_PRIMITIVES:
883
return False
884
885
if objc_type in ("NSString *",):
886
return True
887
elif objc_type in (
888
"struct CGSize",
889
"struct CGRect",
890
"struct CGPoint",
891
):
892
return True
893
elif self.field_override_is_objc_codable() is not None:
894
return self.field_override_is_objc_codable()
895
elif objc_type in enum_type_map:
896
return True
897
elif objc_type.startswith("enum "):
898
return True
899
900
if not USE_CODABLE_FOR_NONPRIMITIVES:
901
return False
902
903
array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
904
if array_match is not None:
905
split = array_match.group(2)
906
return self.is_objc_type_codable(split)
907
908
dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type)
909
if dict_match is not None:
910
split1 = dict_match.group(2).strip()
911
split2 = dict_match.group(3).strip()
912
return self.is_objc_type_codable(split1) and self.is_objc_type_codable(
913
split2
914
)
915
916
return False
917
918
def field_override_swift_type(self):
919
return self._field_override("swift_type")
920
921
def field_override_is_objc_codable(self):
922
return self._field_override("is_objc_codable")
923
924
def field_override_is_enum(self):
925
return self._field_override("is_enum")
926
927
def field_override_column_type(self):
928
return self._field_override("column_type")
929
930
def field_override_record_swift_type(self):
931
return self._field_override("record_swift_type")
932
933
def field_override_serialize_record_invocation(self):
934
return self._field_override("serialize_record_invocation")
935
936
def field_override_should_use_blob(self):
937
return self._field_override("should_use_blob")
938
939
def field_override_objc_initializer_type(self):
940
return self._field_override("objc_initializer_type")
941
942
def _field_override(self, override_field):
943
manually_typed_fields = configuration_json.get("manually_typed_fields")
944
if manually_typed_fields is None:
945
fail("Configuration JSON is missing manually_typed_fields")
946
key = self.class_name + "." + self.name
947
948
if key in manually_typed_fields:
949
return manually_typed_fields[key][override_field]
950
else:
951
return None
952
953
def type_info(self):
954
if self.swift_type is not None:
955
should_use_blob = (
956
self.swift_type.startswith("[")
957
or self.swift_type.startswith("{")
958
or is_swift_class_name(self.swift_type)
959
)
960
return TypeInfo(
961
self.swift_type,
962
objc_type,
963
should_use_blob=should_use_blob,
964
is_codable=USE_CODABLE_FOR_PRIMITIVES,
965
field_override_column_type=self.field_override_column_type,
966
)
967
968
return self.try_to_convert_objc_type_to_type_info()
969
970
def swift_type_safe(self):
971
return self.type_info().swift_type()
972
973
def objc_type_safe(self):
974
if self.field_override_objc_initializer_type() is not None:
975
return self.field_override_objc_initializer_type()
976
977
result = self.type_info().objc_type()
978
979
if result.startswith("enum "):
980
result = result[len("enum ") :]
981
982
return result
983
# if self.objc_type is None:
984
# fail("Don't know Obj-C type for:", self.name)
985
# return self.objc_type
986
987
def database_column_type(self):
988
return self.type_info().database_column_type(self.name)
989
990
def should_ignore_property(self):
991
return should_ignore_property(self)
992
993
def has_aliased_column_name(self):
994
return aliased_column_name_for_property(self) is not None
995
996
def deserialize_record_invocation(self, value_name, did_force_optional):
997
return self.type_info().deserialize_record_invocation(
998
self, value_name, self.is_optional, did_force_optional
999
)
1000
1001
def deep_copy_record_invocation(self, value_name, did_force_optional):
1002
1003
swift_type = self.swift_type_safe()
1004
objc_type = self.objc_type_safe()
1005
is_optional = self.is_optional
1006
model_accessor = accessor_name_for_property(self)
1007
1008
initializer_param_type = swift_type
1009
if is_optional:
1010
initializer_param_type = initializer_param_type + "?"
1011
1012
simple_type_map = {
1013
"NSString *": "String",
1014
"NSNumber *": "NSNumber",
1015
"NSDate *": "Date",
1016
"NSData *": "Data",
1017
"CGSize": "CGSize",
1018
"CGRect": "CGRect",
1019
"CGPoint": "CGPoint",
1020
}
1021
if objc_type in simple_type_map:
1022
initializer_param_type = simple_type_map[objc_type]
1023
if is_optional:
1024
initializer_param_type += "?"
1025
return [
1026
"let %s: %s = modelToCopy.%s"
1027
% (
1028
value_name,
1029
initializer_param_type,
1030
model_accessor,
1031
),
1032
]
1033
1034
can_shallow_copy = False
1035
if self.type_info().is_numeric():
1036
can_shallow_copy = True
1037
elif self.is_enum():
1038
can_shallow_copy = True
1039
1040
if can_shallow_copy:
1041
return [
1042
"let %s: %s = modelToCopy.%s"
1043
% (
1044
value_name,
1045
initializer_param_type,
1046
model_accessor,
1047
),
1048
]
1049
1050
initializer_param_type = initializer_param_type.replace("AnyObject", "Any")
1051
1052
if is_optional:
1053
return [
1054
"let %s: %s"
1055
% (
1056
value_name,
1057
initializer_param_type,
1058
),
1059
"if let %sForCopy = modelToCopy.%s {"
1060
% (
1061
value_name,
1062
model_accessor,
1063
),
1064
" %s = try DeepCopies.deepCopy(%sForCopy)"
1065
% (
1066
value_name,
1067
value_name,
1068
),
1069
"} else {",
1070
" %s = nil" % (value_name,),
1071
"}",
1072
]
1073
else:
1074
return [
1075
"let %s: %s = try DeepCopies.deepCopy(modelToCopy.%s)"
1076
% (
1077
value_name,
1078
initializer_param_type,
1079
model_accessor,
1080
),
1081
]
1082
1083
fail(
1084
"I don't know how to deep copy this type: %s / %s" % (objc_type, swift_type)
1085
)
1086
1087
def possible_class_type_for_property(self):
    """Look up the parsed class whose name matches this property's type.

    The Swift type name is tried first. If that misses, the Obj-C type name
    is tried with any trailing " *" pointer marker removed. Returns the
    matching `global_class_map` entry, or None when neither name is present.
    """
    swift_name = self.swift_type_safe()
    if swift_name in global_class_map:
        return global_class_map[swift_name]

    objc_name = self.objc_type_safe()
    pointer_marker = " *"
    if objc_name.endswith(pointer_marker):
        objc_name = objc_name[: -len(pointer_marker)]
    return global_class_map.get(objc_name)
1097
1098
def serialize_record_invocation(self, value_name, did_force_optional):
    """Delegate serialization-code generation to this property's type info.

    The type info receives this property, `value_name`, the property's
    `is_optional` flag and `did_force_optional`, in that order. Its result
    is returned unchanged.
    """
    type_info = self.type_info()
    return type_info.serialize_record_invocation(
        self,
        value_name,
        self.is_optional,
        did_force_optional,
    )
1102
1103
def record_field_type(self):
    """Return the record field type that this property's type info reports.

    The property's `name` is forwarded to the type info.
    """
    info = self.type_info()
    return info.record_field_type(self.name)
1105
1106
def is_enum(self):
    """Return the `is_enum` attribute of this property's type info, as-is."""
    info = self.type_info()
    return info.is_enum
1108
1109
def swift_identifier(self):
    """Return this property's name converted by `to_swift_identifier_name`."""
    property_name = self.name
    return to_swift_identifier_name(property_name)
1111
1112
def column_name(self):
    """Return the database column name used for this property.

    The first non-None result wins, in this order:
      1. `aliased_column_name_for_property(self)`
      2. `custom_column_name_for_property(self)`
      3. the property's Swift identifier.
    Later resolvers are not called once an earlier one returns a name.
    """
    for resolve_override in (
        aliased_column_name_for_property,
        custom_column_name_for_property,
    ):
        override = resolve_override(self)
        if override is not None:
            return override
    return self.swift_identifier()
1121
1122
1123
def ows_getoutput(cmd):
    """Run `cmd` and capture its exit status and output.

    Args:
        cmd: Command passed to subprocess; no shell is used.

    Returns:
        A `(returncode, stdout, stderr)` tuple; stdout and stderr are bytes.
        A non-zero exit status is reported through `returncode`, not raised.
    """
    # subprocess.run manages the Popen as a context manager, so the pipes are
    # closed and the child is cleaned up even if communication raises
    # (a bare Popen + communicate() leaves that to the garbage collector).
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=False,
    )
    return completed.returncode, completed.stdout, completed.stderr
1132
1133
1134
# ---- Parsing
1135
1136
1137
def properties_and_inherited_properties(clazz):
    """Return a new list of `clazz`'s properties, preceded by its ancestors'.

    The superclass is resolved by name through `global_class_map`, and the
    lookup recurses upward. The walk stops at the first superclass name that
    is not in the map. The root-most ancestor's properties come first and
    `clazz.properties()` comes last.
    """
    parent_name = clazz.super_class_name
    inherited = (
        properties_and_inherited_properties(global_class_map[parent_name])
        if parent_name in global_class_map
        else []
    )
    return inherited + list(clazz.properties())
1144
1145
1146
def generate_swift_extensions_for_model(clazz):
1147
if not clazz.should_generate_extensions():
1148
return
1149
1150
has_sds_superclass = clazz.has_sds_superclass()
1151
has_remove_methods = clazz.name not in ("TSInteraction")
1152
has_grdb_serializer = clazz.name in ("TSInteraction")
1153
1154
swift_filename = os.path.basename(clazz.filepath)
1155
swift_filename = swift_filename[: swift_filename.find(".")] + "+SDS.swift"
1156
swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename)
1157
1158
record_type = get_record_type(clazz)
1159
1160
# TODO: We'll need to import SignalServiceKit for non-SSK models.
1161
1162
swift_body = """//
1163
// Copyright 2022 Signal Messenger, LLC
1164
// SPDX-License-Identifier: AGPL-3.0-only
1165
//
1166
1167
import Foundation
1168
%simport GRDB
1169
1170
// NOTE: This file is generated by %s.
1171
// Do not manually edit it, instead run `sds_codegen.sh`.
1172
""" % (
1173
"" if has_sds_superclass else "public ",
1174
sds_common.pretty_module_path(__file__),
1175
)
1176
1177
if not has_sds_superclass:
1178
1179
# If a property has a custom column source, we don't redundantly create a column for that column
1180
base_properties = [
1181
property
1182
for property in clazz.properties()
1183
if not property.has_aliased_column_name()
1184
]
1185
# If a property has a custom column source, we don't redundantly create a column for that column
1186
subclass_properties = [
1187
property
1188
for property in clazz.database_subclass_properties()
1189
if not property.has_aliased_column_name()
1190
]
1191
1192
swift_body += """
1193
// MARK: - Record
1194
"""
1195
1196
record_name = clazz.record_name()
1197
swift_body += """
1198
public struct %s: SDSRecord {
1199
public weak var delegate: SDSRecordDelegate?
1200
1201
public var tableMetadata: SDSTableMetadata {
1202
%sSerializer.table
1203
}
1204
1205
public static var databaseTableName: String {
1206
%sSerializer.table.tableName
1207
}
1208
1209
public var id: Int64?
1210
1211
// This defines all of the columns used in the table
1212
// where this model (and any subclasses) are persisted.
1213
public let recordType: SDSRecordType?
1214
public let uniqueId: String
1215
1216
""" % (
1217
record_name,
1218
str(clazz.name),
1219
str(clazz.name),
1220
)
1221
1222
def write_record_property(property, force_optional=False):
1223
column_name = property.swift_identifier()
1224
1225
record_field_type = property.record_field_type()
1226
1227
is_optional = property.is_optional or force_optional
1228
optional_split = "?" if is_optional else ""
1229
1230
custom_column_name = custom_column_name_for_property(property)
1231
if custom_column_name is not None:
1232
column_name = custom_column_name
1233
1234
return """ public let %s: %s%s
1235
""" % (
1236
str(column_name),
1237
record_field_type,
1238
optional_split,
1239
)
1240
1241
record_properties = clazz.sorted_record_properties()
1242
1243
# Declare the model properties in the record.
1244
if len(record_properties) > 0:
1245
swift_body += "\n // Properties \n"
1246
for property in record_properties:
1247
swift_body += write_record_property(
1248
property, force_optional=property.force_optional
1249
)
1250
1251
sds_properties = [
1252
ParsedProperty(
1253
{
1254
"name": "id",
1255
"is_optional": False,
1256
"objc_type": "NSInteger",
1257
"class_name": clazz.name,
1258
}
1259
),
1260
ParsedProperty(
1261
{
1262
"name": "recordType",
1263
"is_optional": False,
1264
"objc_type": "NSUInteger",
1265
"class_name": clazz.name,
1266
}
1267
),
1268
ParsedProperty(
1269
{
1270
"name": "uniqueId",
1271
"is_optional": False,
1272
"objc_type": "NSString *",
1273
"class_name": clazz.name,
1274
}
1275
),
1276
]
1277
# We use the pre-sorted collection record_properties so that
1278
# we use the correct property order when generating:
1279
#
1280
# * CodingKeys
1281
# * init(row: Row)
1282
# * The table/column metadata.
1283
persisted_properties = sds_properties + record_properties
1284
1285
swift_body += """
1286
public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable {
1287
"""
1288
1289
for property in persisted_properties:
1290
custom_column_name = custom_column_name_for_property(property)
1291
was_property_renamed = was_property_renamed_for_property(property)
1292
if custom_column_name is not None:
1293
if was_property_renamed:
1294
swift_body += """ case %s
1295
""" % (
1296
custom_column_name,
1297
)
1298
else:
1299
swift_body += """ case %s = "%s"
1300
""" % (
1301
custom_column_name,
1302
property.swift_identifier(),
1303
)
1304
else:
1305
swift_body += """ case %s
1306
""" % (
1307
property.swift_identifier(),
1308
)
1309
1310
swift_body += """ }
1311
"""
1312
swift_body += """
1313
public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String {
1314
fullyQualified ? "\\(databaseTableName).\\(column.rawValue)" : column.rawValue
1315
}
1316
1317
public func didInsert(with rowID: Int64, for column: String?) {
1318
guard let delegate = delegate else {
1319
owsFailDebug("Missing delegate.")
1320
return
1321
}
1322
delegate.updateRowId(rowID)
1323
}
1324
}
1325
""" % (
1326
record_name,
1327
)
1328
1329
swift_body += """
1330
// MARK: - Row Initializer
1331
1332
public extension %s {
1333
static var databaseSelection: [SQLSelectable] {
1334
CodingKeys.allCases
1335
}
1336
1337
init(row: Row) {""" % (
1338
record_name
1339
)
1340
1341
for index, property in enumerate(persisted_properties):
1342
type_info = property.type_info()
1343
property_name = property.column_name()
1344
swift_type = type_info.swift_type()
1345
1346
did_force_optional = type_info.is_enum
1347
1348
if property_name == "recordType":
1349
# recordType is an enum, but its property info here doesn't
1350
# reflect that, so special-case it.
1351
swift_body += """
1352
%s = row[%s].flatMap { SDSRecordType(rawValue: $0) }""" % (
1353
property_name,
1354
index,
1355
)
1356
elif did_force_optional:
1357
swift_body += """
1358
%s = row[%s].flatMap { %s(rawValue: $0) }""" % (
1359
property_name,
1360
index,
1361
swift_type,
1362
)
1363
else:
1364
swift_body += """
1365
%s = row[%s]""" % (
1366
property_name,
1367
index,
1368
)
1369
1370
swift_body += """
1371
}
1372
}
1373
"""
1374
1375
swift_body += """
1376
// MARK: - StringInterpolation
1377
1378
public extension String.StringInterpolation {
1379
mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) {
1380
appendLiteral(%(record_name)s.columnName(column))
1381
}
1382
mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) {
1383
appendLiteral(%(record_name)s.columnName(column, fullyQualified: true))
1384
}
1385
}
1386
""" % {
1387
"record_identifier": record_identifier(clazz.name),
1388
"record_name": record_name,
1389
}
1390
1391
# TODO: Rework metadata to not include, for example, columns, column indices.
1392
swift_body += """
1393
// MARK: - Deserialization
1394
1395
extension %s {
1396
// This method defines how to deserialize a model, given a
1397
// database row. The recordType column is used to determine
1398
// the corresponding model class.
1399
class func fromRecord(_ record: %s) throws -> %s {
1400
""" % (
1401
str(clazz.name),
1402
record_name,
1403
str(clazz.name),
1404
)
1405
swift_body += """
1406
1407
guard let recordId = record.id else { throw SDSError.missingRequiredField(fieldName: "id") }
1408
guard let recordType = record.recordType else { throw SDSError.missingRequiredField(fieldName: "recordType") }
1409
1410
switch recordType {
1411
"""
1412
1413
deserialize_classes = all_descendents_of_class(clazz) + [clazz]
1414
deserialize_classes.sort(key=lambda value: value.name)
1415
1416
for deserialize_class in deserialize_classes:
1417
if should_ignore_class(deserialize_class):
1418
continue
1419
1420
initializer_params = []
1421
objc_initializer_params = []
1422
objc_super_initializer_args = []
1423
objc_initializer_assigns = []
1424
deserialize_record_type = get_record_type_enum_name(deserialize_class.name)
1425
swift_body += """ case .%s:
1426
""" % (
1427
str(deserialize_record_type),
1428
)
1429
1430
swift_body += """
1431
let uniqueId: String = record.uniqueId
1432
"""
1433
1434
base_property_names = set()
1435
for property in base_properties:
1436
base_property_names.add(property.name)
1437
1438
deserialize_properties = properties_and_inherited_properties(
1439
deserialize_class
1440
)
1441
has_local_properties = False
1442
for property in deserialize_properties:
1443
value_name = "%s" % property.name
1444
1445
if property.name not in ("uniqueId",):
1446
did_force_optional = property.name not in base_property_names
1447
did_force_optional = did_force_optional and not property.is_optional
1448
did_force_optional = did_force_optional or property.type_info().is_enum
1449
for statement in property.deserialize_record_invocation(
1450
value_name, did_force_optional
1451
):
1452
swift_body += " %s\n" % (str(statement),)
1453
1454
initializer_params.append(
1455
"%s: %s"
1456
% (
1457
str(property.name),
1458
value_name,
1459
)
1460
)
1461
objc_initializer_type = str(property.objc_type_safe())
1462
if objc_initializer_type.startswith("NSMutable"):
1463
objc_initializer_type = (
1464
"NS" + objc_initializer_type[len("NSMutable") :]
1465
)
1466
if property.is_optional:
1467
objc_initializer_type = "nullable " + objc_initializer_type
1468
objc_initializer_params.append(
1469
"%s:(%s)%s"
1470
% (
1471
str(property.name),
1472
objc_initializer_type,
1473
str(property.name),
1474
)
1475
)
1476
1477
is_superclass_property = property.class_name != deserialize_class.name
1478
if is_superclass_property:
1479
objc_super_initializer_args.append(
1480
"%s:%s"
1481
% (
1482
str(property.name),
1483
str(property.name),
1484
)
1485
)
1486
else:
1487
has_local_properties = True
1488
if str(property.objc_type_safe()).startswith("NSMutableArray"):
1489
objc_initializer_assigns.append(
1490
"_%s = %s ? [%s mutableCopy] : [NSMutableArray new];"
1491
% (
1492
str(property.name),
1493
str(property.name),
1494
str(property.name),
1495
)
1496
)
1497
elif str(property.objc_type_safe()).startswith(
1498
"NSMutableDictionary"
1499
):
1500
objc_initializer_assigns.append(
1501
"_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];"
1502
% (
1503
str(property.name),
1504
str(property.name),
1505
str(property.name),
1506
)
1507
)
1508
elif (
1509
deserialize_class.name == "TSIncomingMessage"
1510
and property.name in ("authorUUID", "authorPhoneNumber")
1511
):
1512
pass
1513
else:
1514
objc_initializer_assigns.append(
1515
"_%s = %s;"
1516
% (
1517
str(property.name),
1518
str(property.name),
1519
)
1520
)
1521
1522
# --- Initializer Snippets
1523
1524
h_snippet = ""
1525
h_snippet += """
1526
// clang-format off
1527
1528
- (instancetype)initWithGrdbId:(int64_t)grdbId
1529
uniqueId:(NSString *)uniqueId
1530
"""
1531
for objc_initializer_param in objc_initializer_params[1:]:
1532
alignment = max(
1533
0,
1534
len("- (instancetype)initWithUniqueId")
1535
- objc_initializer_param.index(":"),
1536
)
1537
h_snippet += (" " * alignment) + objc_initializer_param + "\n"
1538
1539
h_snippet += (
1540
"NS_DESIGNATED_INITIALIZER NS_SWIFT_NAME(init(grdbId:%s:));\n"
1541
% ":".join([str(property.name) for property in deserialize_properties])
1542
)
1543
h_snippet += """
1544
// clang-format on
1545
"""
1546
1547
m_snippet = ""
1548
m_snippet += """
1549
// clang-format off
1550
1551
- (instancetype)initWithGrdbId:(int64_t)grdbId
1552
uniqueId:(NSString *)uniqueId
1553
"""
1554
for objc_initializer_param in objc_initializer_params[1:]:
1555
alignment = max(
1556
0,
1557
len("- (instancetype)initWithUniqueId")
1558
- objc_initializer_param.index(":"),
1559
)
1560
m_snippet += (" " * alignment) + objc_initializer_param + "\n"
1561
1562
if len(objc_super_initializer_args) == 1:
1563
suffix = "];"
1564
else:
1565
suffix = ""
1566
m_snippet += """{
1567
self = [super initWithGrdbId:grdbId
1568
uniqueId:uniqueId%s
1569
""" % (
1570
suffix
1571
)
1572
for index, objc_super_initializer_arg in enumerate(
1573
objc_super_initializer_args[1:]
1574
):
1575
alignment = max(
1576
0,
1577
len(" self = [super initWithUniqueId")
1578
- objc_super_initializer_arg.index(":"),
1579
)
1580
if index == len(objc_super_initializer_args) - 2:
1581
suffix = "];"
1582
else:
1583
suffix = ""
1584
m_snippet += (
1585
(" " * alignment) + objc_super_initializer_arg + suffix + "\n"
1586
)
1587
m_snippet += """
1588
if (!self) {
1589
return self;
1590
}
1591
1592
"""
1593
1594
if deserialize_class.name == "TSIncomingMessage":
1595
m_snippet += """
1596
if (authorUUID != nil) {
1597
_authorUUID = authorUUID;
1598
} else if (authorPhoneNumber != nil) {
1599
_authorPhoneNumber = authorPhoneNumber;
1600
}
1601
"""
1602
1603
for objc_initializer_assign in objc_initializer_assigns:
1604
m_snippet += (" " * 4) + objc_initializer_assign + "\n"
1605
1606
if deserialize_class.finalize_method_name is not None:
1607
m_snippet += """
1608
[self %s];
1609
""" % (
1610
str(deserialize_class.finalize_method_name),
1611
)
1612
1613
m_snippet += """
1614
return self;
1615
}
1616
1617
// clang-format on
1618
"""
1619
1620
# Skip initializer generation for classes without any properties.
1621
if not has_local_properties:
1622
h_snippet = ""
1623
m_snippet = ""
1624
1625
if deserialize_class.filepath.endswith(".m"):
1626
m_filepath = deserialize_class.filepath
1627
h_filepath = m_filepath[:-2] + ".h"
1628
update_objc_snippet(h_filepath, h_snippet)
1629
update_objc_snippet(m_filepath, m_snippet)
1630
1631
swift_body += """
1632
"""
1633
1634
# --- Invoke Initializer
1635
1636
initializer_invocation = " return %s(" % str(
1637
deserialize_class.name
1638
)
1639
swift_body += initializer_invocation
1640
initializer_params = [
1641
"grdbId: recordId",
1642
] + initializer_params
1643
swift_body += (",\n" + " " * len(initializer_invocation)).join(
1644
initializer_params
1645
)
1646
swift_body += ")"
1647
swift_body += """
1648
1649
"""
1650
1651
# TODO: We could generate a comment with the Obj-C (or Swift) model initializer
1652
# that this deserialization code expects.
1653
1654
swift_body += """ default:
1655
owsFailDebug("Unexpected record type: \\(recordType)")
1656
throw SDSError.invalidValue()
1657
"""
1658
swift_body += """ }
1659
"""
1660
swift_body += """ }
1661
"""
1662
swift_body += """}
1663
"""
1664
1665
# TODO: Remove the serialization glue below.
1666
1667
if not has_sds_superclass:
1668
swift_body += """
1669
// MARK: - SDSModel
1670
1671
extension %s: SDSModel {
1672
public var serializer: SDSSerializer {
1673
// Any subclass can be cast to it's superclass,
1674
// so the order of this switch statement matters.
1675
// We need to do a "depth first" search by type.
1676
switch self {""" % str(
1677
clazz.name
1678
)
1679
1680
for subclass in reversed(all_descendents_of_class(clazz)):
1681
if should_ignore_class(subclass):
1682
continue
1683
1684
swift_body += """
1685
case let model as %s:
1686
assert(type(of: model) == %s.self)
1687
return %sSerializer(model: model)""" % (
1688
str(subclass.name),
1689
str(subclass.name),
1690
str(subclass.name),
1691
)
1692
1693
swift_body += """
1694
default:
1695
return %sSerializer(model: self)
1696
}
1697
}
1698
1699
public func asRecord() -> SDSRecord {
1700
serializer.asRecord()
1701
}
1702
1703
public var sdsTableName: String {
1704
%s.databaseTableName
1705
}
1706
1707
public static var table: SDSTableMetadata {
1708
%sSerializer.table
1709
}
1710
}
1711
""" % (
1712
str(clazz.name),
1713
record_name,
1714
str(clazz.name),
1715
)
1716
1717
if not has_sds_superclass:
1718
swift_body += """
1719
// MARK: - DeepCopyable
1720
1721
extension %(class_name)s: DeepCopyable {
1722
1723
public func deepCopy() throws -> AnyObject {
1724
guard let id = self.grdbId?.int64Value else {
1725
throw OWSAssertionError("Model missing grdbId.")
1726
}
1727
1728
// Any subclass can be cast to its superclass, so the order of these if
1729
// statements matters. We need to do a "depth first" search by type.
1730
""" % {
1731
"class_name": str(clazz.name)
1732
}
1733
1734
classes_to_copy = list(reversed(all_descendents_of_class(clazz))) + [
1735
clazz,
1736
]
1737
for class_to_copy in classes_to_copy:
1738
if should_ignore_class(class_to_copy):
1739
continue
1740
1741
if class_to_copy == clazz:
1742
swift_body += """
1743
do {
1744
let modelToCopy = self
1745
assert(type(of: modelToCopy) == %(class_name)s.self)
1746
""" % {
1747
"class_name": str(class_to_copy.name)
1748
}
1749
else:
1750
swift_body += """
1751
if let modelToCopy = self as? %(class_name)s {
1752
assert(type(of: modelToCopy) == %(class_name)s.self)
1753
""" % {
1754
"class_name": str(class_to_copy.name)
1755
}
1756
1757
initializer_params = []
1758
base_property_names = set()
1759
for property in base_properties:
1760
base_property_names.add(property.name)
1761
1762
deserialize_properties = properties_and_inherited_properties(class_to_copy)
1763
for property in deserialize_properties:
1764
value_name = "%s" % property.name
1765
1766
did_force_optional = property.name not in base_property_names
1767
did_force_optional = did_force_optional and not property.is_optional
1768
did_force_optional = did_force_optional or property.type_info().is_enum
1769
for statement in property.deep_copy_record_invocation(
1770
value_name, did_force_optional
1771
):
1772
swift_body += " %s\n" % (str(statement),)
1773
1774
initializer_params.append(
1775
"%s: %s"
1776
% (
1777
str(property.name),
1778
value_name,
1779
)
1780
)
1781
1782
swift_body += """
1783
"""
1784
1785
# --- Invoke Initializer
1786
1787
initializer_invocation = " return %s(" % str(class_to_copy.name)
1788
swift_body += initializer_invocation
1789
initializer_params = [
1790
"grdbId: id",
1791
] + initializer_params
1792
swift_body += (",\n" + " " * len(initializer_invocation)).join(
1793
initializer_params
1794
)
1795
swift_body += ")"
1796
swift_body += """
1797
}
1798
"""
1799
1800
swift_body += """
1801
}
1802
}
1803
"""
1804
1805
if has_grdb_serializer:
1806
swift_body += """
1807
// MARK: - Table Metadata
1808
1809
extension %sRecord {
1810
1811
// This defines all of the columns used in the table
1812
// where this model (and any subclasses) are persisted.
1813
internal func asValues() -> [DatabaseValueConvertible?] {
1814
return [
1815
""" % str(
1816
remove_prefix_from_class_name(clazz.name)
1817
)
1818
1819
def write_grdb_column_metadata(metadata):
1820
return """ %s,
1821
""" % (
1822
str(metadata)
1823
)
1824
1825
for property in sds_properties:
1826
column_name = property.column_name()
1827
1828
if column_name == "recordType" or property.type_info().is_enum:
1829
swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name))
1830
elif property.name != "id":
1831
swift_body += write_grdb_column_metadata(column_name)
1832
1833
for property in record_properties:
1834
column_name = property.column_name()
1835
1836
if property.type_info().is_enum:
1837
swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name))
1838
else:
1839
swift_body += write_grdb_column_metadata(column_name)
1840
1841
swift_body += """
1842
]
1843
}
1844
1845
internal func asArguments() -> StatementArguments {
1846
return StatementArguments(asValues())
1847
}
1848
}
1849
"""
1850
1851
if not has_sds_superclass:
1852
swift_body += """
1853
// MARK: - Table Metadata
1854
1855
extension %sSerializer {
1856
1857
// This defines all of the columns used in the table
1858
// where this model (and any subclasses) are persisted.
1859
""" % str(
1860
clazz.name
1861
)
1862
1863
# Eventually we need a (persistent?) mechanism for guaranteeing
1864
# consistency of column ordering, that is robust to schema
1865
# changes, class hierarchy changes, etc.
1866
column_property_names = []
1867
1868
def write_column_metadata(property, force_optional=False):
1869
column_name = property.swift_identifier()
1870
column_property_names.append(column_name)
1871
1872
is_optional = property.is_optional or force_optional
1873
optional_split = ", isOptional: true" if is_optional else ""
1874
1875
is_unique = column_name == str("uniqueId")
1876
is_unique_split = ", isUnique: true" if is_unique else ""
1877
1878
database_column_type = property.database_column_type()
1879
if property.name == "id":
1880
database_column_type = ".primaryKey"
1881
1882
# TODO: Use skipSelect.
1883
return """ static var %sColumn: SDSColumnMetadata { SDSColumnMetadata(columnName: "%s", columnType: %s%s%s) }
1884
""" % (
1885
str(column_name),
1886
str(column_name),
1887
database_column_type,
1888
optional_split,
1889
is_unique_split,
1890
)
1891
1892
for property in sds_properties:
1893
swift_body += write_column_metadata(property)
1894
1895
if len(record_properties) > 0:
1896
swift_body += " // Properties \n"
1897
for property in record_properties:
1898
swift_body += write_column_metadata(
1899
property, force_optional=property.force_optional
1900
)
1901
1902
database_table_name = "model_%s" % str(clazz.name)
1903
swift_body += """
1904
public static var table: SDSTableMetadata {
1905
SDSTableMetadata(
1906
tableName: "%s",
1907
columns: [
1908
""" % (
1909
database_table_name,
1910
)
1911
swift_body += "\n".join(
1912
[
1913
" %sColumn," % str(column_property_name)
1914
for column_property_name in column_property_names
1915
]
1916
)
1917
swift_body += """
1918
]
1919
)
1920
}
1921
}
1922
"""
1923
1924
# ---- Fetch ----
1925
1926
cached_method = "anyFetch"
1927
uncached_method = "anyFetch"
1928
if cache_get_code_for_class(clazz) is not None:
1929
cached_method = "fetchViaCache"
1930
1931
swift_body += """
1932
// MARK: - Save/Remove/Update
1933
1934
@objc
1935
public extension %(class_name)s {
1936
func anyInsert(transaction: DBWriteTransaction) {
1937
sdsSave(saveMode: .insert, transaction: transaction)
1938
}
1939
1940
// Avoid this method whenever feasible.
1941
//
1942
// If the record has previously been saved, this method does an overwriting
1943
// update of the corresponding row, otherwise if it's a new record, this
1944
// method inserts a new row.
1945
//
1946
// For performance, when possible, you should explicitly specify whether
1947
// you are inserting or updating rather than calling this method.
1948
func anyUpsert(transaction: DBWriteTransaction) {
1949
let isInserting: Bool
1950
if %(class_name)s.%(cached_method)s(uniqueId: uniqueId, transaction: transaction) != nil {
1951
isInserting = false
1952
} else {
1953
isInserting = true
1954
}
1955
sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction)
1956
}
1957
1958
// This method is used by "updateWith..." methods.
1959
//
1960
// This model may be updated from many threads. We don't want to save
1961
// our local copy (this instance) since it may be out of date. We also
1962
// want to avoid re-saving a model that has been deleted. Therefore, we
1963
// use "updateWith..." methods to:
1964
//
1965
// a) Update a property of this instance.
1966
// b) If a copy of this model exists in the database, load an up-to-date copy,
1967
// and update and save that copy.
1968
// b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save
1969
// this local instance.
1970
//
1971
// After "updateWith...":
1972
//
1973
// a) Any copy of this model in the database will have been updated.
1974
// b) The local property on this instance will always have been updated.
1975
// c) Other properties on this instance may be out of date.
1976
//
1977
// All mutable properties of this class have been made read-only to
1978
// prevent accidentally modifying them directly.
1979
//
1980
// This isn't a perfect arrangement, but in practice this will prevent
1981
// data loss and will resolve all known issues.
1982
func anyUpdate(transaction: DBWriteTransaction, block: (%(class_name)s) -> Void) {
1983
1984
block(self)
1985
1986
// If it's not saved, we don't expect to find it in the database, and we
1987
// won't save any changes we make back into the database.
1988
guard shouldBeSaved else {
1989
return
1990
}
1991
1992
guard let dbCopy = type(of: self).%(uncached_method)s(uniqueId: uniqueId, transaction: transaction) else {
1993
return
1994
}
1995
1996
// Don't apply the block twice to the same instance.
1997
// It's at least unnecessary and actually wrong for some blocks.
1998
// e.g. `block: { $0 in $0.someField++ }`
1999
if dbCopy !== self {
2000
block(dbCopy)
2001
}
2002
2003
dbCopy.sdsSave(saveMode: .update, transaction: transaction)
2004
}
2005
2006
// This method is an alternative to `anyUpdate(transaction:block:)` methods.
2007
//
2008
// We should generally use `anyUpdate` to ensure we're not unintentionally
2009
// clobbering other columns in the database when another concurrent update
2010
// has occurred.
2011
//
2012
// There are cases when this doesn't make sense, e.g. when we know we've
2013
// just loaded the model in the same transaction. In those cases it is
2014
// safe and faster to do a "overwriting" update
2015
func anyOverwritingUpdate(transaction: DBWriteTransaction) {
2016
sdsSave(saveMode: .update, transaction: transaction)
2017
}
2018
""" % {
2019
"class_name": str(clazz.name),
2020
"cached_method": cached_method,
2021
"uncached_method": uncached_method,
2022
}
2023
2024
if has_remove_methods:
2025
swift_body += """
2026
func anyRemove(transaction: DBWriteTransaction) {
2027
sdsRemove(transaction: transaction)
2028
}
2029
"""
2030
2031
swift_body += """}
2032
"""
2033
2034
# ---- Cursor ----
2035
2036
swift_body += """
2037
// MARK: - %sCursor
2038
2039
@objc
2040
public class %sCursor: NSObject, SDSCursor {
2041
private let transaction: DBReadTransaction
2042
private let cursor: RecordCursor<%s>
2043
2044
init(transaction: DBReadTransaction, cursor: RecordCursor<%s>) {
2045
self.transaction = transaction
2046
self.cursor = cursor
2047
}
2048
2049
public func next() throws -> %s? {
2050
guard let record = try cursor.next() else {
2051
return nil
2052
}""" % (
2053
str(clazz.name),
2054
str(clazz.name),
2055
record_name,
2056
record_name,
2057
str(clazz.name),
2058
)
2059
2060
cache_code = cache_set_code_for_class(clazz)
2061
if cache_code is not None:
2062
swift_body += """
2063
let value = try %s.fromRecord(record)
2064
%s(value, transaction: transaction)
2065
return value""" % (
2066
str(clazz.name),
2067
cache_code,
2068
)
2069
else:
2070
swift_body += """
2071
return try %s.fromRecord(record)""" % (
2072
str(clazz.name),
2073
)
2074
2075
swift_body += """
2076
}
2077
2078
public func all() throws -> [%s] {
2079
var result = [%s]()
2080
while true {
2081
guard let model = try next() else {
2082
break
2083
}
2084
result.append(model)
2085
}
2086
return result
2087
}
2088
}
2089
""" % (
2090
str(clazz.name),
2091
str(clazz.name),
2092
)
2093
2094
# ---- Fetch ----
2095
2096
swift_body += """
2097
// MARK: - Obj-C Fetch
2098
2099
@objc
2100
public extension %(class_name)s {
2101
@nonobjc
2102
class func grdbFetchCursor(transaction: DBReadTransaction) -> %(class_name)sCursor {
2103
let database = transaction.database
2104
return failIfThrows {
2105
let cursor = try %(record_name)s.fetchCursor(database)
2106
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
2107
}
2108
}
2109
""" % {
2110
"class_name": str(clazz.name),
2111
"record_name": record_name,
2112
}
2113
2114
cache_code = cache_get_code_for_class(clazz)
2115
assert cache_code is not None
2116
swift_body += """
2117
// Fetches a single model by "unique id".
2118
class func fetchViaCache(uniqueId: String, transaction: DBReadTransaction) -> %(class_name)s? {
2119
assert(!uniqueId.isEmpty)
2120
2121
if let cachedCopy = %(cache_code)s {
2122
return cachedCopy
2123
}
2124
2125
return anyFetch(uniqueId: uniqueId, transaction: transaction)
2126
}
2127
""" % {
2128
"class_name": str(clazz.name),
2129
"cache_code": str(cache_code),
2130
}
2131
2132
swift_body += """
2133
2134
// Fetches a single model by "unique id".
2135
class func anyFetch(uniqueId: String, transaction: DBReadTransaction) -> %(class_name)s? {
2136
assert(!uniqueId.isEmpty)
2137
2138
""" % {
2139
"class_name": str(clazz.name),
2140
}
2141
2142
swift_body += """
2143
let sql = "SELECT * FROM \\(%(record_name)s.databaseTableName) WHERE \\(%(record_identifier)sColumn: .uniqueId) = ?"
2144
return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: transaction)
2145
}
2146
""" % {
2147
"record_name": record_name,
2148
"record_identifier": record_identifier(clazz.name),
2149
}
2150
2151
swift_body += """
2152
// Traverses all records.
2153
// Records are not visited in any particular order.
2154
class func anyEnumerate(
2155
transaction: DBReadTransaction,
2156
block: (%s) -> Void,
2157
) {
2158
let cursor = %s.grdbFetchCursor(transaction: transaction)
2159
do {
2160
while let value = try cursor.next() {
2161
block(value)
2162
}
2163
} catch let error {
2164
owsFailDebug("Couldn't fetch model: \\(error)")
2165
}
2166
}
2167
""" % (
2168
(str(clazz.name),) * 2
2169
)
2170
2171
swift_body += """
2172
// Does not order the results.
2173
class func anyFetchAll(transaction: DBReadTransaction) -> [%s] {
2174
var result = [%s]()
2175
anyEnumerate(transaction: transaction) { model in
2176
result.append(model)
2177
}
2178
return result
2179
}
2180
""" % (
2181
(str(clazz.name),) * 2
2182
)
2183
2184
# ---- Count ----
2185
2186
swift_body += """
2187
class func anyCount(transaction: DBReadTransaction) -> UInt {
2188
return %s.ows_fetchCount(transaction.database)
2189
}
2190
}
2191
""" % (
2192
record_name,
2193
)
2194
2195
# ---- Fetch ----
2196
2197
swift_body += """
2198
// MARK: - Swift Fetch
2199
2200
public extension %(class_name)s {
2201
class func grdbFetchCursor(sql: String,
2202
arguments: StatementArguments = StatementArguments(),
2203
transaction: DBReadTransaction) -> %(class_name)sCursor {
2204
return failIfThrows {
2205
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
2206
let cursor = try %(record_name)s.fetchCursor(transaction.database, sqlRequest)
2207
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
2208
}
2209
}
2210
""" % {
2211
"class_name": str(clazz.name),
2212
"record_name": record_name,
2213
}
2214
2215
string_interpolation_name = remove_prefix_from_class_name(clazz.name)
2216
swift_body += """
2217
class func grdbFetchOne(sql: String,
2218
arguments: StatementArguments = StatementArguments(),
2219
transaction: DBReadTransaction) -> %s? {
2220
assert(!sql.isEmpty)
2221
2222
do {
2223
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
2224
guard let record = try %s.fetchOne(transaction.database, sqlRequest) else {
2225
return nil
2226
}
2227
""" % (
2228
str(clazz.name),
2229
record_name,
2230
)
2231
2232
cache_code = cache_set_code_for_class(clazz)
2233
if cache_code is not None:
2234
swift_body += """
2235
let value = try %s.fromRecord(record)
2236
%s(value, transaction: transaction)
2237
return value""" % (
2238
str(clazz.name),
2239
cache_code,
2240
)
2241
else:
2242
swift_body += """
2243
return try %s.fromRecord(record)""" % (
2244
str(clazz.name),
2245
)
2246
2247
swift_body += """
2248
} catch {
2249
owsFailDebug("error: \\(error)")
2250
return nil
2251
}
2252
}
2253
}
2254
"""
2255
2256
# ---- Typed Convenience Methods ----
2257
2258
if has_sds_superclass:
2259
swift_body += """
2260
// MARK: - Typed Convenience Methods
2261
2262
@objc
2263
public extension %s {
2264
// NOTE: This method will fail if the object has unexpected type.
2265
class func fetch%sViaCache(
2266
uniqueId: String,
2267
transaction: DBReadTransaction
2268
) -> %s? {
2269
assert(!uniqueId.isEmpty)
2270
2271
guard let object = fetchViaCache(uniqueId: uniqueId, transaction: transaction) else {
2272
return nil
2273
}
2274
guard let instance = object as? %s else {
2275
owsFailDebug("Object has unexpected type: \\(type(of: object))")
2276
return nil
2277
}
2278
return instance
2279
}
2280
2281
// NOTE: This method will fail if the object has unexpected type.
2282
func anyUpdate%s(transaction: DBWriteTransaction, block: (%s) -> Void) {
2283
anyUpdate(transaction: transaction) { (object) in
2284
guard let instance = object as? %s else {
2285
owsFailDebug("Object has unexpected type: \\(type(of: object))")
2286
return
2287
}
2288
block(instance)
2289
}
2290
}
2291
}
2292
""" % (
2293
str(clazz.name),
2294
str(remove_prefix_from_class_name(clazz.name)),
2295
str(clazz.name),
2296
str(clazz.name),
2297
str(remove_prefix_from_class_name(clazz.name)),
2298
str(clazz.name),
2299
str(clazz.name),
2300
)
2301
2302
# ---- SDSModel ----
2303
2304
table_superclass = clazz.table_superclass()
2305
table_class_name = str(table_superclass.name)
2306
has_serializable_superclass = table_superclass.name != clazz.name
2307
2308
override_keyword = ""
2309
2310
swift_body += """
2311
// MARK: - SDSSerializer
2312
2313
// The SDSSerializer protocol specifies how to insert and update the
2314
// row that corresponds to this model.
2315
class %sSerializer: SDSSerializer {
2316
2317
private let model: %s
2318
public init(model: %s) {
2319
self.model = model
2320
}
2321
""" % (
2322
str(clazz.name),
2323
str(clazz.name),
2324
str(clazz.name),
2325
)
2326
2327
# --- To Record
2328
2329
root_class = clazz.table_superclass()
2330
root_record_name = remove_prefix_from_class_name(root_class.name) + "Record"
2331
2332
record_id_source = "model.grdbId?.int64Value"
2333
if root_class.record_id_source() is not None:
2334
record_id_source = (
2335
"model.%(source)s > 0 ? Int64(model.%(source)s) : %(default_source)s"
2336
% {
2337
"source": root_class.record_id_source(),
2338
"default_source": record_id_source,
2339
}
2340
)
2341
2342
swift_body += """
2343
// MARK: - Record
2344
2345
func asRecord() -> SDSRecord {
2346
let id: Int64? = %(record_id_source)s
2347
2348
let recordType: SDSRecordType = .%(record_type)s
2349
let uniqueId: String = model.uniqueId
2350
""" % {
2351
"record_type": get_record_type_enum_name(clazz.name),
2352
"record_id_source": record_id_source,
2353
}
2354
2355
initializer_args = [
2356
"id",
2357
"recordType",
2358
"uniqueId",
2359
]
2360
2361
inherited_property_map = {}
2362
for property in properties_and_inherited_properties(clazz):
2363
inherited_property_map[property.column_name()] = property
2364
2365
def write_record_property(property, force_optional=False):
2366
optional_value = ""
2367
2368
if property.column_name() in inherited_property_map:
2369
inherited_property = inherited_property_map[property.column_name()]
2370
did_force_optional = property.force_optional
2371
model_accessor = accessor_name_for_property(inherited_property)
2372
value_expr = inherited_property.serialize_record_invocation(
2373
"model.%s" % (model_accessor,), did_force_optional
2374
)
2375
2376
optional_value = " = %s" % (value_expr,)
2377
else:
2378
optional_value = " = nil"
2379
2380
record_field_type = property.record_field_type()
2381
2382
is_optional = property.is_optional or force_optional
2383
optional_split = "?" if is_optional else ""
2384
2385
initializer_args.append(property.column_name())
2386
2387
return """ let %s: %s%s%s
2388
""" % (
2389
str(property.column_name()),
2390
record_field_type,
2391
optional_split,
2392
optional_value,
2393
)
2394
2395
root_record_properties = root_class.sorted_record_properties()
2396
2397
if len(root_record_properties) > 0:
2398
swift_body += "\n // Properties \n"
2399
for property in root_record_properties:
2400
swift_body += write_record_property(
2401
property, force_optional=property.force_optional
2402
)
2403
2404
initializer_args = [
2405
"%s: %s"
2406
% (
2407
arg,
2408
arg,
2409
)
2410
for arg in initializer_args
2411
]
2412
swift_body += """
2413
return %s(delegate: model, %s)
2414
}
2415
""" % (
2416
root_record_name,
2417
", ".join(initializer_args),
2418
)
2419
2420
swift_body += """}
2421
"""
2422
2423
print(f"Writing {swift_filename}")
2424
2425
swift_body = sds_common.clean_up_generated_swift(swift_body)
2426
2427
sds_common.write_text_file_if_changed(swift_filepath, swift_body)
2428
2429
2430
def process_class_map(class_map):
    """Generate Swift extensions for each parsed class in `class_map`.

    `class_map` is the name -> parsed-class dict returned by
    `find_sds_intermediary_files_in_path`; only its values are used.
    """
    for model_class in class_map.values():
        generate_swift_extensions_for_model(model_class)
2433
2434
2435
# ---- Record Type Map
2436
2437
# Class name -> persisted integer record-type id. Keys starting with "#" hold
# comments (e.g. "#comment") rather than class names. Populated and persisted
# by update_record_type_map.
record_type_map = {}
2438
2439
2440
# It's critical that our "record type" values are consistent, even if we add/remove/rename model classes.
2441
# Therefore we persist the mapping of known classes in a JSON file that is under source control.
2442
def update_record_type_map(record_type_swift_path, record_type_json_path):
    """Assign stable record-type ids and regenerate the SDSRecordType enum.

    Steps, in order:
    * Load the persisted class-name -> record-type map from
      `record_type_json_path` (if the file exists) into the module-level
      `record_type_map`.
    * Give every class in `global_class_map` that has no id yet, and for which
      `should_generate_extensions()` is true, the next id after the current
      maximum. Existing ids are never changed.
    * Write the map back to `record_type_json_path`.
    * Write a Swift `SDSRecordType` enum to `record_type_swift_path`, with one
      case per class, ordered by id.

    Args:
        record_type_swift_path: destination of the generated Swift enum file.
        record_type_json_path: JSON file holding the persisted id mapping.
    """
    record_type_map_filepath = record_type_json_path

    if os.path.exists(record_type_map_filepath):
        with open(record_type_map_filepath, "rt") as f:
            json_string = f.read()
        json_data = json.loads(json_string)
        record_type_map.update(json_data)

    # Find the largest id in use. Keys starting with "#" are comments, not
    # class names, so they are skipped.
    max_record_type = 0
    for class_name in record_type_map:
        if class_name.startswith("#"):
            continue
        record_type = record_type_map[class_name]
        max_record_type = max(max_record_type, record_type)

    # Append ids for newly seen classes. Ids are only ever added, so values
    # already persisted in the JSON file stay stable.
    for clazz in global_class_map.values():
        if clazz.name not in record_type_map:

            if not clazz.should_generate_extensions():
                continue

            max_record_type = int(max_record_type) + 1
            record_type = max_record_type
            record_type_map[clazz.name] = record_type

    record_type_map["#comment"] = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )

    json_string = json.dumps(record_type_map, sort_keys=True, indent=4)

    sds_common.write_text_file_if_changed(record_type_map_filepath, json_string)

    # TODO: We'll need to import SignalServiceKit for non-SSK classes.

    swift_body = """//
// Copyright 2022 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//

import Foundation
import GRDB

// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.

@objc
public enum SDSRecordType: UInt, CaseIterable {
""" % (
        sds_common.pretty_module_path(__file__),
    )

    # Collect (enum case name, id) pairs, skipping "#" comment entries.
    record_type_pairs = []
    for key in record_type_map.keys():
        if key.startswith("#"):
            # Ignore comments
            continue
        enum_name = get_record_type_enum_name(key)
        record_type_pairs.append((str(enum_name), record_type_map[key]))

    # Emit the enum cases ordered by their numeric id.
    record_type_pairs.sort(key=lambda value: value[1])
    for enum_name, record_type_id in record_type_pairs:
        swift_body += """ case %s = %s
""" % (
            enum_name,
            str(record_type_id),
        )

    swift_body += """}
"""

    swift_body = sds_common.clean_up_generated_swift(swift_body)

    sds_common.write_text_file_if_changed(record_type_swift_path, swift_body)
2518
2519
2520
def get_record_type(clazz):
    """Return the record-type id stored for `clazz` in `record_type_map`.

    Raises KeyError when no id has been assigned for the class name.
    """
    class_name = clazz.name
    return record_type_map[class_name]
2522
2523
2524
def remove_prefix_from_class_name(class_name):
    """Strip one leading "TS", "OWS" or "SSK" prefix from `class_name`.

    Only the first matching prefix (checked in that order) is removed.
    Names without one of these prefixes are returned unchanged.
    """
    for prefix in ("TS", "OWS", "SSK"):
        if class_name.startswith(prefix):
            return class_name[len(prefix) :]
    return class_name
2533
2534
2535
def get_record_type_enum_name(class_name):
    """Return the SDSRecordType case name for `class_name`.

    The "TS"/"OWS"/"SSK" prefix is removed. A leading "_" is added when the
    remainder starts with a numeric character, and the result is passed
    through `to_swift_identifier_name`.
    """
    stripped = remove_prefix_from_class_name(class_name)
    # Swift identifiers may not start with a digit.
    candidate = "_" + stripped if stripped[0].isnumeric() else stripped
    return to_swift_identifier_name(candidate)
2540
2541
2542
def record_identifier(class_name):
    """Return the Swift identifier form of `class_name` without its
    "TS"/"OWS"/"SSK" prefix."""
    return to_swift_identifier_name(remove_prefix_from_class_name(class_name))
2545
2546
2547
# ---- Column Ordering
2548
2549
2550
column_ordering_map = {}
2551
has_loaded_column_ordering_map = False
2552
2553
2554
# ---- Parsing
2555
2556
# Enum name -> underlying ObjC type name (e.g. "NSInteger"). Merged from the
# "enums" section of each parsed SDS JSON file by parse_sds_json and read by
# objc_type_for_enum.
enum_type_map = {}
2557
2558
2559
def objc_type_for_enum(enum_name):
    """Return the underlying ObjC type recorded for enum `enum_name`.

    Unknown enums are reported via `fail`, after the whole `enum_type_map`
    is printed to help diagnose the problem.
    """
    if enum_name in enum_type_map:
        return enum_type_map[enum_name]

    print("enum_type_map", enum_type_map)
    fail("Enum has unknown type:", enum_name)
    return enum_type_map[enum_name]
2565
2566
2567
def swift_type_for_enum(enum_name):
    """Return the Swift type name used for the raw value of enum `enum_name`.

    The enum's underlying ObjC type is looked up with `objc_type_for_enum`
    and translated through the table below. Unknown ObjC types are reported
    via `fail`.
    """
    objc_type = objc_type_for_enum(enum_name)

    # Every entry uses the Swift spelling of the type. "unsigned long long"
    # maps to "UInt64"; the old if/elif chain listed it twice, and the first
    # (C-spelled "uint64_t") branch made the "UInt64" branch unreachable.
    # NOTE(review): Swift imports C `unsigned long` as `UInt` and
    # `unsigned int` as `UInt32`; confirm the last two mappings are intended.
    objc_to_swift = {
        "NSInteger": "Int",
        "NSUInteger": "UInt",
        "int32_t": "Int32",
        "unsigned long long": "UInt64",
        "unsigned long": "UInt64",
        "unsigned int": "UInt",
    }
    if objc_type not in objc_to_swift:
        fail("Unknown objc type:", objc_type)
    return objc_to_swift.get(objc_type)
2586
2587
2588
def parse_sds_json(file_path):
    """Parse one SDS intermediary JSON file.

    Every entry of the file's "classes" list becomes a `ParsedClass`. The
    file's "enums" dict is merged into the module-level `enum_type_map`.

    Returns:
        dict mapping class name -> ParsedClass for the classes in this file.
    """
    with open(file_path, "rt") as json_file:
        document = json.load(json_file)

    parsed_classes = (ParsedClass(entry) for entry in document["classes"])
    class_map = {parsed.name: parsed for parsed in parsed_classes}

    enum_type_map.update(document["enums"])
    return class_map
2603
2604
2605
def try_to_parse_file(file_path):
    """Parse `file_path` if it is an SDS intermediary JSON file.

    Returns:
        The class map from `parse_sds_json` when the file name ends with
        `sds_common.SDS_JSON_FILE_EXTENSION`; otherwise an empty dict.
    """
    # The unused os.path.splitext() result was removed; the suffix check
    # below is the only filter applied.
    filename = os.path.basename(file_path)
    if not filename.endswith(sds_common.SDS_JSON_FILE_EXTENSION):
        return {}
    return parse_sds_json(file_path)
2612
2613
2614
def find_sds_intermediary_files_in_path(path):
    """Collect parsed classes from `path`.

    `path` may be a single file or a directory. A directory is walked
    recursively, and every file found is handed to `try_to_parse_file`.
    Later results overwrite earlier ones that have the same class name.

    Returns:
        dict mapping class name -> parsed class.
    """
    if os.path.isfile(path):
        candidate_paths = [path]
    else:
        candidate_paths = (
            os.path.abspath(os.path.join(directory, name))
            for directory, _subdirectories, names in os.walk(path)
            for name in names
        )

    class_map = {}
    for candidate in candidate_paths:
        class_map.update(try_to_parse_file(candidate))
    return class_map
2624
2625
2626
def update_subclass_map():
    """Record each class of `global_class_map` under its superclass name.

    Entries are appended to `global_subclass_map[super_class_name]`. Classes
    without a superclass name are skipped.
    """
    for clazz in global_class_map.values():
        parent_name = clazz.super_class_name
        if parent_name is None:
            continue
        global_subclass_map.setdefault(parent_name, []).append(clazz)
2632
2633
2634
def all_descendents_of_class(clazz):
    """Return all transitive subclasses of `clazz` in depth-first pre-order.

    Siblings are visited in name order. The sibling lists stored in
    `global_subclass_map` are sorted in place as a side effect.
    """
    descendants = []
    children = global_subclass_map.get(clazz.name, [])
    children.sort(key=lambda child: child.name)
    for child in children:
        descendants += [child] + all_descendents_of_class(child)
    return descendants
2644
2645
2646
def is_swift_class_name(swift_type):
    """Return True if `swift_type` names a class present in `global_class_map`
    (with a non-None entry)."""
    known_class = global_class_map.get(swift_type)
    return known_class is not None
2648
2649
2650
# ---- Config JSON
2651
2652
# Code-generation configuration. parse_config_json replaces it with the parsed
# contents of --config-json-path; the config accessor functions read it.
configuration_json = {}
2653
2654
2655
def parse_config_json(config_json_path):
    """Load the JSON at `config_json_path` into the module-level
    `configuration_json`, replacing its previous value."""
    global configuration_json
    with open(config_json_path, "rt") as config_file:
        configuration_json = json.loads(config_file.read())
2662
2663
2664
# We often use nullable NSNumber * for optional numerics (bool, int, int64, double, etc.).
2665
# There's no way to infer which type we're boxing in NSNumber.
2666
# Therefore, we need to specify that in the configuration JSON.
2667
def swift_type_for_nsnumber(property):
    """Return the configured Swift type for an NSNumber-typed `property`.

    Looks up "<class_name>.<name>" in the config's "nsnumber_types" dict.
    A missing dict or missing key is reported via `fail`, with a hint naming
    the config file to update.
    """
    type_by_property = configuration_json.get("nsnumber_types")
    if type_by_property is None:
        print("Suggestion: update: %s" % (str(global_args.config_json_path),))
        fail("Configuration JSON is missing mapping for properties of type NSNumber.")

    property_key = ".".join((property.class_name, property.name))
    mapped_type = type_by_property.get(property_key)
    if mapped_type is not None:
        return mapped_type

    print("Suggestion: update: %s" % (str(global_args.config_json_path),))
    fail(
        "Configuration JSON is missing mapping for properties of type NSNumber:",
        property_key,
    )
    return mapped_type
2681
2682
2683
# Some properties shouldn't get serialized.
2684
# For now, there's just one: TSGroupModel.groupImage which is a UIImage.
2685
# We might end up extending the serialization to handle images.
2686
# Or we might store these as Data/NSData/blob.
2687
# TODO:
2688
def should_ignore_property(property):
    """Return True if "<class_name>.<name>" of `property` appears in the
    config's "properties_to_ignore" collection.

    A missing collection is reported via `fail`.
    """
    ignored_keys = configuration_json.get("properties_to_ignore")
    if ignored_keys is None:
        fail(
            "Configuration JSON is missing list of properties to ignore during serialization."
        )
    return ".".join((property.class_name, property.name)) in ignored_keys
2696
2697
2698
def cache_get_code_for_class(clazz):
    """Return the configured cache-lookup snippet for `clazz`, or None.

    Reads the config's "class_cache_get_code" dict, keyed by class name.
    A missing dict is reported via `fail`.
    """
    get_code_by_class = configuration_json.get("class_cache_get_code")
    if get_code_by_class is None:
        fail("Configuration JSON is missing dict of class_cache_get_code.")
    return get_code_by_class.get(clazz.name)
2704
2705
2706
def cache_set_code_for_class(clazz):
    """Return the configured cache-store snippet for `clazz`, or None.

    Reads the config's "class_cache_set_code" dict, keyed by class name. The
    generated fetch code calls the snippet as
    `<snippet>(value, transaction: transaction)`. A missing dict is reported
    via `fail`.
    """
    set_code_by_class = configuration_json.get("class_cache_set_code")
    if set_code_by_class is None:
        fail("Configuration JSON is missing dict of class_cache_set_code.")
    return set_code_by_class.get(clazz.name)
2712
2713
2714
def should_ignore_class(clazz):
    """Return True if `clazz` or any known ancestor is configured to be skipped.

    The superclass chain is followed through `global_class_map` for as long
    as superclass names resolve. Each class name is checked against the
    config's "class_to_skip_serialization" collection. A missing collection
    is reported via `fail`.
    """
    skipped_names = configuration_json.get("class_to_skip_serialization")
    if skipped_names is None:
        fail(
            "Configuration JSON is missing list of classes to ignore during serialization."
        )

    current = clazz
    while True:
        if current.name in skipped_names:
            return True
        parent_name = current.super_class_name
        # Stop at the root of the hierarchy or at a superclass we didn't parse.
        if parent_name is None or parent_name not in global_class_map:
            return False
        current = global_class_map[parent_name]
2729
2730
2731
def accessor_name_for_property(property):
    """Return the accessor used to read `property` from a model.

    This is the config's "custom_accessors" entry for
    "<class_name>.<name>", or the property's own name when no entry exists.
    A missing "custom_accessors" dict is reported via `fail`.
    """
    accessor_overrides = configuration_json.get("custom_accessors")
    if accessor_overrides is None:
        fail("Configuration JSON is missing list of custom property accessors.")
    property_key = ".".join((property.class_name, property.name))
    return accessor_overrides.get(property_key, property.name)
2737
2738
2739
# include_renamed_columns
2740
def custom_column_name_for_property(property):
    """Return the config's "custom_column_names" entry for
    "<class_name>.<name>" of `property`, or None when there is none.

    A missing "custom_column_names" dict is reported via `fail`.
    """
    column_overrides = configuration_json.get("custom_column_names")
    if column_overrides is None:
        fail("Configuration JSON is missing list of custom column names.")
    return column_overrides.get(".".join((property.class_name, property.name)))
2746
2747
2748
def aliased_column_name_for_property(property):
    """Return the config's "aliased_column_names" entry for
    "<class_name>.<name>" of `property`, or None when there is none.

    A missing "aliased_column_names" dict is reported via `fail`.
    """
    aliases = configuration_json.get("aliased_column_names")
    if aliases is None:
        fail("Configuration JSON is missing dict of aliased_column_names.")
    property_key = ".".join((property.class_name, property.name))
    return aliases.get(property_key)
2755
2756
2757
def was_property_renamed_for_property(property):
    """Return True if the config's "renamed_column_names" has a non-None
    entry for "<class_name>.<name>" of `property`.

    A missing "renamed_column_names" dict is reported via `fail`.
    """
    renamed = configuration_json.get("renamed_column_names")
    if renamed is None:
        fail("Configuration JSON is missing list of renamed column names.")
    previous_name = renamed.get(".".join((property.class_name, property.name)))
    return previous_name is not None
2763
2764
2765
# ---- Property Order JSON
2766
2767
# "<RecordName>.<property name>" -> persisted property-order value, plus a
# "#comment" entry. Loaded by parse_property_order_json and written back by
# update_property_order_json.
property_order_json = {}
2768
2769
2770
def parse_property_order_json(property_order_json_path):
    """Load the JSON at `property_order_json_path` into the module-level
    `property_order_json`, replacing its previous value."""
    global property_order_json
    with open(property_order_json_path, "rt") as order_file:
        property_order_json = json.loads(order_file.read())
2777
2778
2779
# It's critical that our "property order" is consistent, even if we add columns.
2780
# Therefore we persist the "property order" for all known properties in a JSON file that is under source control.
2781
def update_property_order_json(property_order_json_path):
    """Stamp the "generated file" note into `property_order_json` and write
    it, as sorted, 4-space-indented JSON, to `property_order_json_path` via
    `sds_common.write_text_file_if_changed`."""
    generated_note = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    property_order_json["#comment"] = generated_note

    serialized = json.dumps(property_order_json, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(property_order_json_path, serialized)
2790
2791
2792
def property_order_key(property, record_name):
    """Return the property-order key "<record_name>.<property.name>"."""
    return ".".join((record_name, property.name))
2794
2795
2796
def property_order_for_property(property, record_name):
    """Return the persisted property order for `property` within
    `record_name`, or None if none has been stored."""
    return property_order_json.get(property_order_key(property, record_name))
2800
2801
2802
def set_property_order_for_property(property, record_name, value):
    """Store `value` as the property order for `property` within
    `record_name` in `property_order_json`."""
    order_key = property_order_key(property, record_name)
    property_order_json[order_key] = value
2805
2806
2807
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Generate Swift extensions.")
    # Both paths may name a single file or a directory that is walked
    # recursively for SDS intermediary JSON files
    # (see find_sds_intermediary_files_in_path). The help text distinguishes
    # their roles; the two options previously shared identical help.
    parser.add_argument(
        "--src-path",
        required=True,
        help="path (file or directory) of the SDS intermediary files to generate Swift extensions for.",
    )
    parser.add_argument(
        "--search-path",
        required=True,
        help="path (file or directory) searched for SDS intermediary files describing the full model class hierarchy.",
    )
    parser.add_argument(
        "--record-type-swift-path",
        required=True,
        help="path of the record type enum swift file.",
    )
    parser.add_argument(
        "--record-type-json-path",
        required=True,
        help="path of the record type map json file.",
    )
    parser.add_argument(
        "--config-json-path",
        required=True,
        help="path of the json file with code generation config info.",
    )
    parser.add_argument(
        "--property-order-json-path",
        required=True,
        help="path of the json file with property ordering cache.",
    )
    args = parser.parse_args()

    # Module-level handle to the parsed arguments; other functions read it
    # (e.g. swift_type_for_nsnumber when reporting config problems).
    global_args = args

    src_path = os.path.abspath(args.src_path)
    search_path = os.path.abspath(args.search_path)
    record_type_swift_path = os.path.abspath(args.record_type_swift_path)
    record_type_json_path = os.path.abspath(args.record_type_json_path)
    config_json_path = os.path.abspath(args.config_json_path)
    property_order_json_path = os.path.abspath(args.property_order_json_path)

    # We control the code generation process using a JSON config file.
    parse_config_json(config_json_path)
    parse_property_order_json(property_order_json_path)

    # The code generation needs to understand the class hierarchy so that
    # it can:
    #
    # * Define table schemas that include the superset of properties in
    #   the model class hierarchies.
    # * Generate deserialization methods that handle all subclasses.
    # * etc.
    global_class_map.update(find_sds_intermediary_files_in_path(search_path))
    update_subclass_map()
    update_record_type_map(record_type_swift_path, record_type_json_path)
    process_class_map(find_sds_intermediary_files_in_path(src_path))

    # Persist updated property order
    update_property_order_json(property_order_json_path)
2865
2866