import os
import subprocess
import argparse
import re
import json
import sds_common
from sds_common import fail
import random
# Model-hierarchy roots: legacy models derive from TSYapDatabaseObject,
# newer ones from BaseModel.  Both are treated as the base of an SDS model
# class tree by the generator below.
OLD_BASE_MODEL_CLASS_NAME = "TSYapDatabaseObject"
NEW_BASE_MODEL_CLASS_NAME = "BaseModel"
# Marker comment delimiting the generated region inside Obj-C source files;
# the generator replaces everything between the first and last occurrence.
CODE_GEN_SNIPPET_MARKER_OBJC = "// --- CODE GENERATION MARKER"
# Feature flags: when True, (de)serialize the corresponding property kinds
# via Codable instead of blob archiving.  Currently disabled.
USE_CODABLE_FOR_PRIMITIVES = False
USE_CODABLE_FOR_NONPRIMITIVES = False
def update_generated_snippet(file_path, marker, snippet):
    """Replace the region between the first and last `marker` in `file_path`
    with `snippet`, normalizing the surrounding blank lines.

    Fails hard if the file is missing or does not contain two distinct
    occurrences of `marker`.  The file is only rewritten when its content
    actually changes.
    """
    if not os.path.exists(file_path):
        fail("Missing file:", file_path)

    with open(file_path, "rt") as f:
        text = f.read()

    first = text.find(marker)
    last = text.rfind(marker)
    # The marker must appear at least twice (first strictly before last).
    if first < 0 or last < 0 or first >= last:
        fail(f"Could not find markers ('{marker}'): {file_path}")

    head = text[:first].strip()
    tail = text[last + len(marker):].lstrip()
    rebuilt = "\n\n".join([head, marker, snippet, marker, tail])
    sds_common.write_text_file_if_changed(file_path, rebuilt)
def update_objc_snippet(file_path, snippet):
    """Insert a generated Obj-C snippet (with a provenance header) between
    the code-generation markers of `file_path`.

    Empty snippets (after clean-up) are a no-op.
    """
    cleaned = sds_common.clean_up_generated_objc(snippet).strip()
    if not cleaned:
        return
    header = (
        "// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    update_generated_snippet(
        file_path, CODE_GEN_SNIPPET_MARKER_OBJC, header + "\n\n" + cleaned
    )
# Registry of every parsed class, keyed by class name (see ParsedClass usage
# below, e.g. global_class_map[self.super_class_name]).
global_class_map = {}
# NOTE(review): presumably maps a class name to its subclasses; it is
# populated outside this chunk — confirm against the parsing code.
global_subclass_map = {}
# Parsed command-line arguments; assigned elsewhere (None until then).
global_args = None
def to_swift_identifier_name(identifier_name):
    """Lower-case the first character of `identifier_name` (Obj-C class-style
    name -> Swift identifier style), e.g. "FooBar" -> "fooBar".

    Returns the input unchanged when it is empty, instead of raising
    IndexError as the previous implementation did.
    """
    if not identifier_name:
        return identifier_name
    return identifier_name[0].lower() + identifier_name[1:]
class ParsedClass:
    """One Obj-C model class as described by the parser's JSON output.

    Instances are looked up through the module-level global_class_map and
    queried by the Swift/GRDB code generator.
    """

    def __init__(self, json_dict):
        self.name = json_dict.get("name")
        self.super_class_name = json_dict.get("super_class_name")
        # Resolve the parser's relative path against the repo root.
        self.filepath = sds_common.sds_from_relative_path(json_dict.get("filepath"))
        self.finalize_method_name = json_dict.get("finalize_method_name")
        # property name -> ParsedProperty, with ignored properties filtered out.
        self.property_map = {}
        for property_dict in json_dict.get("properties"):
            property = ParsedProperty(property_dict)
            property.class_name = self.name
            if property.should_ignore_property():
                continue
            self.property_map[property.name] = property

    def properties(self):
        """Return this class's own (non-ignored) properties, sorted by name."""
        result = []
        for name in sorted(self.property_map.keys()):
            result.append(self.property_map[name])
        return result

    def database_subclass_properties(self):
        """Return the properties that exist only on (non-ignored) subclasses,
        deduplicated and sorted by name.

        Fails hard when a subclass redeclares a property with a different
        Swift type, or with different optionality than a property of this
        (root) class.  When two subclasses disagree only on optionality,
        the optional variant wins.
        """
        all_property_map = {}
        subclass_property_map = {}
        root_property_names = set()
        for property in self.properties():
            all_property_map[property.name] = property
            root_property_names.add(property.name)
        for subclass in all_descendents_of_class(self):
            if should_ignore_class(subclass):
                continue
            for property in subclass.properties():
                duplicate_property = all_property_map.get(property.name)
                if duplicate_property is not None:
                    # Same property name seen twice: the Swift types must agree.
                    if (
                        property.swift_type_safe()
                        != duplicate_property.swift_type_safe()
                    ):
                        print(
                            "property:",
                            property.class_name,
                            property.name,
                            property.swift_type_safe(),
                            property.is_optional,
                        )
                        print(
                            "duplicate_property:",
                            duplicate_property.class_name,
                            duplicate_property.name,
                            duplicate_property.swift_type_safe(),
                            duplicate_property.is_optional,
                        )
                        fail("Duplicate property doesn't match:", property.name)
                    elif property.is_optional != duplicate_property.is_optional:
                        # Optionality conflicts with a root property are fatal.
                        if property.name in root_property_names:
                            print(
                                "property:",
                                property.class_name,
                                property.name,
                                property.swift_type_safe(),
                                property.is_optional,
                            )
                            print(
                                "duplicate_property:",
                                duplicate_property.class_name,
                                duplicate_property.name,
                                duplicate_property.swift_type_safe(),
                                duplicate_property.is_optional,
                            )
                            fail("Duplicate property doesn't match:", property.name)
                        # Between subclasses, keep the optional variant: skip
                        # the non-optional one, overwrite with the optional one.
                        if not property.is_optional:
                            continue
                    else:
                        # Exact duplicate (type and optionality); already recorded.
                        continue
                all_property_map[property.name] = property
                subclass_property_map[property.name] = property
        result = []
        for name in sorted(subclass_property_map.keys()):
            result.append(subclass_property_map[name])
        return result

    def record_id_source(self):
        """Return "sortId" if this class declares that property, else None."""
        for property in self.properties():
            if property.name == "sortId":
                return property.name
        return None

    def is_sds_model(self):
        """True when this class (transitively) derives from one of the known
        base model classes via classes visible in global_class_map."""
        if self.super_class_name is None:
            return False
        if not self.super_class_name in global_class_map:
            return False
        if self.super_class_name in (
            OLD_BASE_MODEL_CLASS_NAME,
            NEW_BASE_MODEL_CLASS_NAME,
        ):
            return True
        super_class = global_class_map[self.super_class_name]
        return super_class.is_sds_model()

    def has_sds_superclass(self):
        """True when the direct superclass is a known, parsed class other than
        one of the base model classes (i.e. another SDS model)."""
        return (
            self.super_class_name
            and self.super_class_name in global_class_map
            and self.super_class_name != OLD_BASE_MODEL_CLASS_NAME
            and self.super_class_name != NEW_BASE_MODEL_CLASS_NAME
        )

    def table_superclass(self):
        """Return the root class of this model's table hierarchy: the topmost
        ancestor whose superclass is unknown or a base model class."""
        if self.super_class_name is None:
            return self
        if not self.super_class_name in global_class_map:
            return self
        if self.super_class_name == OLD_BASE_MODEL_CLASS_NAME:
            return self
        if self.super_class_name == NEW_BASE_MODEL_CLASS_NAME:
            return self
        super_class = global_class_map[self.super_class_name]
        return super_class.table_superclass()

    def all_superclass_names(self):
        """Return this class's name followed by all known ancestor names
        (despite the name, the list includes self)."""
        result = [self.name]
        if self.super_class_name is not None:
            if self.super_class_name in global_class_map:
                super_class = global_class_map[self.super_class_name]
                result += super_class.all_superclass_names()
        return result

    def has_any_superclass_with_name(self, name):
        """True when `name` is this class or any known ancestor."""
        return name in self.all_superclass_names()

    def should_generate_extensions(self):
        """Decide whether +SDS Swift extensions should be generated for this
        class: it must be a non-ignored SDS model and not one of the
        migration classes (or a direct subclass of one)."""
        if self.name in (
            OLD_BASE_MODEL_CLASS_NAME,
            NEW_BASE_MODEL_CLASS_NAME,
        ):
            return False
        if should_ignore_class(self):
            return False
        if not self.is_sds_model():
            return False
        if self.name in (
            "OWSDatabaseMigration",
            "YDBDatabaseMigration",
            "OWSResaveCollectionDBMigration",
        ):
            return False
        if self.super_class_name in (
            "OWSDatabaseMigration",
            "YDBDatabaseMigration",
            "OWSResaveCollectionDBMigration",
        ):
            return False
        return True

    def record_name(self):
        """Name of the generated GRDB record struct, e.g. TSThread -> ThreadRecord."""
        return remove_prefix_from_class_name(self.name) + "Record"

    def sorted_record_properties(self):
        """Return base + subclass properties in their persisted column order.

        Enum-typed base properties and all subclass properties are marked
        force_optional (they may be absent for some record types).  Column
        order is looked up via property_order_for_property; properties never
        seen before get the next free order number, which is persisted via
        set_property_order_for_property — presumably to keep the generated
        column order stable across runs (TODO confirm).
        """
        record_name = self.record_name()
        # Aliased columns are handled elsewhere and excluded here.
        base_properties = [
            property
            for property in self.properties()
            if not property.has_aliased_column_name()
        ]
        subclass_properties = [
            property
            for property in self.database_subclass_properties()
            if not property.has_aliased_column_name()
        ]
        record_properties = []
        for property in base_properties:
            force_optional = property.type_info().is_enum
            property.force_optional = force_optional
            record_properties.append(property)
        for property in subclass_properties:
            property.force_optional = True
            record_properties.append(property)
        for property in record_properties:
            property.property_order = property_order_for_property(property, record_name)
        all_property_orders = [
            property.property_order
            for property in record_properties
            if property.property_order
        ]
        next_property_order = 1 + (
            max(all_property_orders) if len(all_property_orders) > 0 else 0
        )
        # Sort by name so newly numbered properties get deterministic orders.
        record_properties.sort(key=lambda value: value.name)
        for property in record_properties:
            if property.property_order is None:
                property.property_order = next_property_order
                set_property_order_for_property(
                    property, record_name, next_property_order
                )
                next_property_order = next_property_order + 1
        record_properties.sort(key=lambda value: value.property_order)
        return record_properties
class TypeInfo:
    """Describes how one property type maps to Swift/Obj-C and how it is
    (de)serialized to a GRDB database column.

    Fixes relative to the previous version:
    * database_column_type: removed the dead "Boolouble" typo from the
      Bool check (it could never match a real Swift type).
    * serialize_record_invocation: look up the NSNumber conversion with
      dict.get() so the "Could not convert" failure path is reachable
      instead of raising a bare KeyError first.
    """

    def __init__(
        self,
        swift_type,
        objc_type,
        should_use_blob=False,
        is_codable=False,
        is_enum=False,
        field_override_column_type=None,
        field_override_record_swift_type=None,
    ):
        self._swift_type = swift_type
        self._objc_type = objc_type
        # True when the value is persisted as an archived-data blob.
        self.should_use_blob = should_use_blob
        self.is_codable = is_codable
        self.is_enum = is_enum
        # Optional per-field overrides (from the configuration JSON).
        self.field_override_column_type = field_override_column_type
        self.field_override_record_swift_type = field_override_record_swift_type

    def swift_type(self):
        """Return the Swift type name as a string."""
        return str(self._swift_type)

    def objc_type(self):
        """Return the Obj-C type name as a string."""
        return str(self._objc_type)

    def database_column_type(self, value_name):
        """Return the GRDB column type literal (e.g. ".int64") for this type."""
        if self.field_override_column_type is not None:
            return self.field_override_column_type
        elif self.should_use_blob or self.is_codable:
            return ".blob"
        elif self.is_enum:
            return ".int"
        elif self._swift_type == "String":
            return ".unicodeString"
        elif self._objc_type == "NSDate *":
            # NSDates are always persisted as epoch doubles.
            return ".double"
        elif self._swift_type == "Date":
            fail(
                'We should not use `Date` as a "swift type" since all NSDates are serialized as doubles.',
                self._swift_type,
            )
        elif self._swift_type == "Data":
            return ".blob"
        elif self._swift_type == "Bool":
            return ".int"
        elif self._swift_type in ("Double", "Float"):
            return ".double"
        elif self.is_numeric():
            return ".int64"
        else:
            fail("Unknown type(1):", self._swift_type)

    def is_numeric(self):
        """True for Swift types stored as numeric columns (incl. Bool)."""
        return self._swift_type in (
            "Bool",
            "UInt64",
            "UInt",
            "Int64",
            "Int",
            "Int32",
            "UInt32",
            "Double",
            "Float",
        )

    def should_cast_to_swift(self):
        """True for numeric types that need an explicit cast in generated
        Swift (Bool/Int64/UInt64 map directly and need none)."""
        if self._swift_type in (
            "Bool",
            "Int64",
            "UInt64",
        ):
            return False
        return self.is_numeric()

    def deserialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift source lines that extract `property` from a
        record into a local named `value_name`.

        `did_force_optional` means the record column is optional even though
        the model property is not (subclass-only or enum columns), so the
        generated code must unwrap (and may throw).
        """
        value_expr = "record.%s" % (property.column_name(),)

        # Pick the SDSDeserialization helper names for this type, if any.
        deserialization_optional = None
        deserialization_not_optional = None
        deserialization_conversion = ""
        if self._swift_type == "String":
            deserialization_not_optional = "required"
        elif self._objc_type == "NSDate *":
            # Dates are handled by a dedicated branch below.
            pass
        elif self._swift_type == "Date":
            fail("Unknown type(0):", self._swift_type)
        elif self.is_codable:
            deserialization_not_optional = "required"
        elif self._swift_type == "Data":
            deserialization_optional = "optionalData"
            deserialization_not_optional = "required"
        elif self.is_numeric():
            deserialization_optional = "optionalNumericAsNSNumber"
            deserialization_not_optional = "required"
            deserialization_conversion = ", conversion: { NSNumber(value: $0) }"

        initializer_param_type = self.swift_type()
        if is_optional:
            initializer_param_type = initializer_param_type + "?"

        # "id" is special-cased: it comes from the already-unwrapped recordId.
        if value_expr == "record.id":
            value_expr = "%s(recordId)" % (initializer_param_type,)
        elif is_optional:
            if deserialization_optional is not None:
                value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % (
                    deserialization_optional,
                    value_expr,
                    value_name,
                    deserialization_conversion,
                )
        elif did_force_optional:
            if deserialization_not_optional is not None:
                value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % (
                    deserialization_not_optional,
                    value_expr,
                    value_name,
                )
        else:
            pass

        if self.is_codable:
            value_statement = "let %s: %s = %s" % (
                value_name,
                initializer_param_type,
                value_expr,
            )
        elif self.should_use_blob:
            # Blob values are fetched as Data and then unarchived.
            blob_name = "%sSerialized" % (str(value_name),)
            if is_optional:
                serialized_statement = "let %s: Data? = %s" % (
                    blob_name,
                    value_expr,
                )
            elif did_force_optional:
                serialized_statement = f'let {blob_name}: Data = try {value_expr} ?? {{ () -> Data in throw SDSError.missingRequiredField(fieldName: "{value_name}") }}()'
            else:
                serialized_statement = "let %s: Data = %s" % (
                    blob_name,
                    value_expr,
                )
            from_name = "$0" if is_optional else blob_name
            swift_type = self._swift_type
            if swift_type == "[InfoMessageUserInfoKey: AnyObject]":
                # Special-cased dictionary with a bespoke unarchiver.
                decode_statement = (
                    'try SDSDeserialization.unarchivedInfoDictionary(from: %s)'
                    % (
                        from_name,
                    )
                )
            elif ": " in swift_type:
                # Generic Swift dictionary type, e.g. "[Key: Value]".
                assert swift_type.startswith("[")
                assert swift_type.endswith("]")
                divider_index = swift_type.index(": ")
                key_type = swift_type[1:divider_index]
                value_type = swift_type[divider_index + 2:-1]
                decode_statement = (
                    'try SDSDeserialization.unarchivedDictionary(ofKeyClass: %s.self, objectClass: %s.self, from: %s)'
                    % (
                        key_type,
                        value_type,
                        from_name,
                    )
                )
            elif swift_type.startswith("["):
                # Swift array type, e.g. "[String]"; unarchive with the
                # Obj-C element class and cast back when bridged.
                assert swift_type.endswith("]")
                array_type = self._swift_type[1:-1]
                objc_types = {
                    "String": "NSString",
                }
                objc_type = objc_types.get(array_type, array_type)
                decode_statement = (
                    'try SDSDeserialization.unarchivedArrayOfObjects(ofClass: %s.self, from: %s)'
                    % (
                        objc_type,
                        from_name,
                    )
                )
                if array_type in objc_types:
                    decode_statement += ' as ' + self._swift_type
            else:
                decode_statement = (
                    'try SDSDeserialization.unarchivedObject(ofClass: %s.self, from: %s)'
                    % (
                        self._swift_type,
                        from_name,
                    )
                )
            if is_optional:
                value_statement = (
                    'let %s: %s? = try %s.map({ %s })'
                    % (
                        value_name,
                        self._swift_type,
                        blob_name,
                        decode_statement,
                    )
                )
            else:
                value_statement = (
                    'let %s: %s = %s'
                    % (
                        value_name,
                        self._swift_type,
                        decode_statement,
                    )
                )
            return [
                serialized_statement,
                value_statement,
            ]
        elif self.is_enum and did_force_optional and not is_optional:
            return [
                "guard let %s: %s = %s else {"
                % (
                    value_name,
                    initializer_param_type,
                    value_expr,
                ),
                "   throw SDSError.missingRequiredField()",
                "}",
            ]
        elif is_optional and self._objc_type == "NSNumber *":
            return [
                "let %s: %s = %s"
                % (
                    value_name,
                    "NSNumber?",
                    value_expr,
                ),
            ]
        elif self._objc_type == "NSDate *":
            # Dates are stored as Double intervals; convert in two steps.
            interval_name = "%sInterval" % (str(value_name),)
            if did_force_optional:
                serialized_statements = [
                    "guard let %s: Double = %s else {"
                    % (
                        interval_name,
                        value_expr,
                    ),
                    "   throw SDSError.missingRequiredField()",
                    "}",
                ]
            elif is_optional:
                serialized_statements = [
                    "let %s: Double? = %s"
                    % (
                        interval_name,
                        value_expr,
                    ),
                ]
            else:
                serialized_statements = [
                    "let %s: Double = %s"
                    % (
                        interval_name,
                        value_expr,
                    ),
                ]
            if is_optional:
                value_statement = (
                    'let %s: Date? = SDSDeserialization.optionalDoubleAsDate(%s, name: "%s")'
                    % (
                        value_name,
                        interval_name,
                        value_name,
                    )
                )
            else:
                value_statement = (
                    'let %s: Date = SDSDeserialization.requiredDoubleAsDate(%s, name: "%s")'
                    % (
                        value_name,
                        interval_name,
                        value_name,
                    )
                )
            return serialized_statements + [
                value_statement,
            ]
        else:
            value_statement = "let %s: %s = %s" % (
                value_name,
                initializer_param_type,
                value_expr,
            )
            return [
                value_statement,
            ]

    def serialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift expression that converts the model value named
        `value_name` into its database representation."""
        value_expr = value_name
        if property.field_override_serialize_record_invocation() is not None:
            return property.field_override_serialize_record_invocation() % (value_expr,)
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            if is_optional or did_force_optional:
                return "optionalArchive(%s)" % (value_expr,)
            else:
                return "requiredArchive(%s)" % (value_expr,)
        elif self._objc_type == "NSDate *":
            if is_optional or did_force_optional:
                return "archiveOptionalDate(%s)" % (value_expr,)
            else:
                return "archiveDate(%s)" % (value_expr,)
        elif self._objc_type == "NSNumber *":
            conversion_map = {
                "Int8": "int8Value",
                "UInt8": "uint8Value",
                "Int16": "int16Value",
                "UInt16": "uint16Value",
                "Int32": "int32Value",
                "UInt32": "uint32Value",
                "Int64": "int64Value",
                "UInt64": "uint64Value",
                "Float": "floatValue",
                "Double": "doubleValue",
                "Bool": "boolValue",
                "Int": "intValue",
                "UInt": "uintValue",
            }
            # .get() (not []) so an unknown type reaches fail() instead of
            # raising an uninformative KeyError.
            conversion_method = conversion_map.get(self.swift_type())
            if conversion_method is None:
                fail("Could not convert:", self.swift_type())
            serialization_conversion = "{ $0.%s }" % (conversion_method,)
            if is_optional or did_force_optional:
                return "archiveOptionalNSNumber(%s, conversion: %s)" % (
                    value_expr,
                    serialization_conversion,
                )
            else:
                return "archiveNSNumber(%s, conversion: %s)" % (
                    value_expr,
                    serialization_conversion,
                )
        return value_expr

    def record_field_type(self, value_name):
        """Return the Swift type used for this field in the record struct
        (blobs are stored as Data; overrides win)."""
        if self.field_override_record_swift_type is not None:
            return self.field_override_record_swift_type
        elif self.is_codable:
            pass
        elif self.should_use_blob:
            return "Data"
        return self.swift_type()
class ParsedProperty:
    """One property of a parsed Obj-C model class.

    Fixes relative to the previous version (both in type_info()):
    * referenced the undefined bare name `objc_type` (NameError whenever
      self.swift_type was set) — now uses self.objc_type;
    * passed the bound method self.field_override_column_type instead of
      its result, so TypeInfo's `is not None` check always succeeded and
      the "override" was a method object — now the method is called.
    """

    def __init__(self, json_dict):
        self.name = json_dict.get("name")
        self.is_optional = json_dict.get("is_optional")
        self.objc_type = json_dict.get("objc_type")
        self.class_name = json_dict.get("class_name")
        # Optional explicit Swift type; None here, may be assigned externally.
        self.swift_type = None

    def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True):
        """Map a primitive/bridged Obj-C type to its Swift equivalent, or
        return None when `objc_type` is not a known primitive.

        NSDate * maps to Double because dates are persisted as doubles.
        """
        if objc_type is None:
            fail("Missing type")
        elif objc_type == "NSString *":
            return "String"
        elif objc_type == "NSDate *":
            return "Double"
        elif objc_type == "NSData *":
            return "Data"
        elif objc_type == "BOOL":
            return "Bool"
        elif objc_type == "NSInteger":
            return "Int"
        elif objc_type == "NSUInteger":
            return "UInt"
        elif objc_type == "int32_t":
            return "Int32"
        elif objc_type == "uint32_t":
            return "UInt32"
        elif objc_type == "int64_t":
            return "Int64"
        elif objc_type == "long long":
            return "Int64"
        elif objc_type == "unsigned long long":
            return "UInt64"
        elif objc_type == "uint64_t":
            return "UInt64"
        elif objc_type == "unsigned long":
            return "UInt64"
        elif objc_type == "unsigned int":
            return "UInt32"
        elif objc_type == "double":
            return "Double"
        elif objc_type == "float":
            return "Float"
        elif objc_type == "CGFloat":
            return "Double"
        elif objc_type == "NSNumber *":
            if unpack_nsnumber:
                return swift_type_for_nsnumber(self)
            else:
                return "NSNumber"
        else:
            return None

    def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True):
        """Convert an Obj-C class/generic type to a Swift type string,
        recursing into NSArray/NSDictionary element types.  Returns None
        for non-pointer types it does not recognize."""
        if objc_type == "id":
            return "AnyObject"
        elif not objc_type.endswith(" *"):
            return None
        swift_primitive = self.try_to_convert_objc_primitive_to_swift(
            objc_type, unpack_nsnumber=unpack_nsnumber
        )
        if swift_primitive is not None:
            return swift_primitive
        array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
        if array_match is not None:
            split = array_match.group(2)
            return (
                "["
                + self.convert_objc_class_to_swift(split, unpack_nsnumber=False)
                + "]"
            )
        dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type)
        if dict_match is not None:
            split1 = dict_match.group(2).strip()
            split2 = dict_match.group(3).strip()
            return (
                "["
                + self.convert_objc_class_to_swift(split1, unpack_nsnumber=False)
                + ": "
                + self.convert_objc_class_to_swift(split2, unpack_nsnumber=False)
                + "]"
            )
        ordered_set_match = re.search(r"^NSOrderedSet<(.+)> \*$", objc_type)
        if ordered_set_match is not None:
            # Ordered sets keep their Obj-C type in Swift.
            return "NSOrderedSet"
        swift_type = objc_type[: -len(" *")]
        if "<" in swift_type or "{" in swift_type or "*" in swift_type:
            fail("Unexpected type:", objc_type)
        return swift_type

    def try_to_convert_objc_type_to_type_info(self):
        """Build a TypeInfo for this property from its Obj-C type, honoring
        per-field overrides and enum/struct special cases.  Fails hard on
        unknown types."""
        objc_type = self.objc_type
        if objc_type is None:
            fail("Missing type")
        elif self.field_override_swift_type():
            return TypeInfo(
                self.field_override_swift_type(),
                objc_type,
                should_use_blob=self.field_override_should_use_blob(),
                is_enum=self.field_override_is_enum(),
                field_override_column_type=self.field_override_column_type(),
                field_override_record_swift_type=self.field_override_record_swift_type(),
            )
        elif objc_type in enum_type_map:
            enum_type = objc_type
            return TypeInfo(enum_type, objc_type, is_enum=True)
        elif objc_type.startswith("enum "):
            enum_type = objc_type[len("enum ") :]
            return TypeInfo(enum_type, objc_type, is_enum=True)
        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type)
        if swift_primitive is not None:
            return TypeInfo(swift_primitive, objc_type)
        if objc_type in (
            "struct CGSize",
            "struct CGRect",
            "struct CGPoint",
        ):
            objc_type = objc_type[len("struct ") :]
            swift_type = objc_type
            return TypeInfo(
                swift_type,
                objc_type,
                should_use_blob=True,
                is_codable=USE_CODABLE_FOR_PRIMITIVES,
            )
        swift_type = self.convert_objc_class_to_swift(self.objc_type)
        if swift_type is not None:
            # NOTE: is_codable is deliberately False here even when
            # is_objc_type_codable() is True; the previous code had two
            # identical branches, collapsed here with behavior preserved.
            return TypeInfo(
                swift_type, objc_type, should_use_blob=True, is_codable=False
            )
        fail("Unknown type(3):", self.class_name, self.objc_type, self.name)

    def is_objc_type_codable(self, objc_type):
        """True when `objc_type` could be serialized via Codable; gated by
        the USE_CODABLE_* feature flags (so currently always False)."""
        if not USE_CODABLE_FOR_PRIMITIVES:
            return False
        if objc_type in ("NSString *",):
            return True
        elif objc_type in (
            "struct CGSize",
            "struct CGRect",
            "struct CGPoint",
        ):
            return True
        elif self.field_override_is_objc_codable() is not None:
            return self.field_override_is_objc_codable()
        elif objc_type in enum_type_map:
            return True
        elif objc_type.startswith("enum "):
            return True
        if not USE_CODABLE_FOR_NONPRIMITIVES:
            return False
        # Containers are codable iff all their element types are.
        array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
        if array_match is not None:
            split = array_match.group(2)
            return self.is_objc_type_codable(split)
        dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+),(.+)> \*$", objc_type)
        if dict_match is not None:
            split1 = dict_match.group(2).strip()
            split2 = dict_match.group(3).strip()
            return self.is_objc_type_codable(split1) and self.is_objc_type_codable(
                split2
            )
        return False

    # --- Per-field overrides from configuration JSON ("manually_typed_fields").

    def field_override_swift_type(self):
        return self._field_override("swift_type")

    def field_override_is_objc_codable(self):
        return self._field_override("is_objc_codable")

    def field_override_is_enum(self):
        return self._field_override("is_enum")

    def field_override_column_type(self):
        return self._field_override("column_type")

    def field_override_record_swift_type(self):
        return self._field_override("record_swift_type")

    def field_override_serialize_record_invocation(self):
        return self._field_override("serialize_record_invocation")

    def field_override_should_use_blob(self):
        return self._field_override("should_use_blob")

    def field_override_objc_initializer_type(self):
        return self._field_override("objc_initializer_type")

    def _field_override(self, override_field):
        """Look up `override_field` for "ClassName.propertyName" in the
        configuration JSON; None when this field has no overrides."""
        manually_typed_fields = configuration_json.get("manually_typed_fields")
        if manually_typed_fields is None:
            fail("Configuration JSON is missing manually_typed_fields")
        key = self.class_name + "." + self.name
        if key in manually_typed_fields:
            return manually_typed_fields[key][override_field]
        else:
            return None

    def type_info(self):
        """Return the TypeInfo for this property, preferring an explicitly
        assigned self.swift_type over Obj-C type conversion."""
        if self.swift_type is not None:
            should_use_blob = (
                self.swift_type.startswith("[")
                or self.swift_type.startswith("{")
                or is_swift_class_name(self.swift_type)
            )
            return TypeInfo(
                self.swift_type,
                # Fixed: was the undefined bare name `objc_type`.
                self.objc_type,
                should_use_blob=should_use_blob,
                is_codable=USE_CODABLE_FOR_PRIMITIVES,
                # Fixed: was the bound method, not its result.
                field_override_column_type=self.field_override_column_type(),
            )
        return self.try_to_convert_objc_type_to_type_info()

    def swift_type_safe(self):
        """Swift type name (never None; fails on unknown types)."""
        return self.type_info().swift_type()

    def objc_type_safe(self):
        """Obj-C type for initializers, honoring overrides and stripping a
        leading "enum "."""
        if self.field_override_objc_initializer_type() is not None:
            return self.field_override_objc_initializer_type()
        result = self.type_info().objc_type()
        if result.startswith("enum "):
            result = result[len("enum ") :]
        return result

    def database_column_type(self):
        """GRDB column type literal for this property."""
        return self.type_info().database_column_type(self.name)

    def should_ignore_property(self):
        # Delegates to the module-level predicate of the same name.
        return should_ignore_property(self)

    def has_aliased_column_name(self):
        return aliased_column_name_for_property(self) is not None

    def deserialize_record_invocation(self, value_name, did_force_optional):
        """Swift lines that read this property out of a record (see
        TypeInfo.deserialize_record_invocation)."""
        return self.type_info().deserialize_record_invocation(
            self, value_name, self.is_optional, did_force_optional
        )

    def deep_copy_record_invocation(self, value_name, did_force_optional):
        """Swift lines that deep-copy this property from `modelToCopy` into
        a local named `value_name`."""
        swift_type = self.swift_type_safe()
        objc_type = self.objc_type_safe()
        is_optional = self.is_optional
        model_accessor = accessor_name_for_property(self)
        initializer_param_type = swift_type
        if is_optional:
            initializer_param_type = initializer_param_type + "?"
        # Bridged value types can be copied by simple assignment.
        simple_type_map = {
            "NSString *": "String",
            "NSNumber *": "NSNumber",
            "NSDate *": "Date",
            "NSData *": "Data",
            "CGSize": "CGSize",
            "CGRect": "CGRect",
            "CGPoint": "CGPoint",
        }
        if objc_type in simple_type_map:
            initializer_param_type = simple_type_map[objc_type]
            if is_optional:
                initializer_param_type += "?"
            return [
                "let %s: %s = modelToCopy.%s"
                % (
                    value_name,
                    initializer_param_type,
                    model_accessor,
                ),
            ]
        can_shallow_copy = False
        if self.type_info().is_numeric():
            can_shallow_copy = True
        elif self.is_enum():
            can_shallow_copy = True
        if can_shallow_copy:
            return [
                "let %s: %s = modelToCopy.%s"
                % (
                    value_name,
                    initializer_param_type,
                    model_accessor,
                ),
            ]
        initializer_param_type = initializer_param_type.replace("AnyObject", "Any")
        if is_optional:
            return [
                "let %s: %s"
                % (
                    value_name,
                    initializer_param_type,
                ),
                "if let %sForCopy = modelToCopy.%s {"
                % (
                    value_name,
                    model_accessor,
                ),
                "   %s = try DeepCopies.deepCopy(%sForCopy)"
                % (
                    value_name,
                    value_name,
                ),
                "} else {",
                "   %s = nil" % (value_name,),
                "}",
            ]
        else:
            return [
                "let %s: %s = try DeepCopies.deepCopy(modelToCopy.%s)"
                % (
                    value_name,
                    initializer_param_type,
                    model_accessor,
                ),
            ]
        # NOTE: unreachable (both branches above return); kept as a guard.
        fail(
            "I don't know how to deep copy this type: %s / %s" % (objc_type, swift_type)
        )

    def possible_class_type_for_property(self):
        """Return the ParsedClass this property's type refers to, if any."""
        swift_type = self.swift_type_safe()
        if swift_type in global_class_map:
            return global_class_map[swift_type]
        objc_type = self.objc_type_safe()
        if objc_type.endswith(" *"):
            objc_type = objc_type[:-2]
        if objc_type in global_class_map:
            return global_class_map[objc_type]
        return None

    def serialize_record_invocation(self, value_name, did_force_optional):
        """Swift expression that serializes this property (see
        TypeInfo.serialize_record_invocation)."""
        return self.type_info().serialize_record_invocation(
            self, value_name, self.is_optional, did_force_optional
        )

    def record_field_type(self):
        return self.type_info().record_field_type(self.name)

    def is_enum(self):
        return self.type_info().is_enum

    def swift_identifier(self):
        """Property name in Swift identifier style (leading lower-case)."""
        return to_swift_identifier_name(self.name)

    def column_name(self):
        """Database column name: alias wins, then custom name, then the
        Swift identifier."""
        aliased_column_name = aliased_column_name_for_property(self)
        if aliased_column_name is not None:
            return aliased_column_name
        custom_column_name = custom_column_name_for_property(self)
        if custom_column_name is not None:
            return custom_column_name
        else:
            return self.swift_identifier()
def ows_getoutput(cmd):
    """Run `cmd` (an argv list) and wait for it to finish.

    Returns a (returncode, stdout, stderr) tuple; the streams are bytes.
    """
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return completed.returncode, completed.stdout, completed.stderr
def properties_and_inherited_properties(clazz):
    """Return clazz's properties preceded by all inherited ones, with the
    root-most superclass's properties first."""
    super_class = global_class_map.get(clazz.super_class_name)
    inherited = (
        properties_and_inherited_properties(super_class)
        if super_class is not None
        else []
    )
    return inherited + list(clazz.properties())
def generate_swift_extensions_for_model(clazz):
if not clazz.should_generate_extensions():
return
has_sds_superclass = clazz.has_sds_superclass()
has_remove_methods = clazz.name not in ("TSInteraction")
has_grdb_serializer = clazz.name in ("TSInteraction")
swift_filename = os.path.basename(clazz.filepath)
swift_filename = swift_filename[: swift_filename.find(".")] + "+SDS.swift"
swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename)
record_type = get_record_type(clazz)
swift_body = """//
// Copyright 2022 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
%simport GRDB
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
""" % (
"" if has_sds_superclass else "public ",
sds_common.pretty_module_path(__file__),
)
if not has_sds_superclass:
base_properties = [
property
for property in clazz.properties()
if not property.has_aliased_column_name()
]
subclass_properties = [
property
for property in clazz.database_subclass_properties()
if not property.has_aliased_column_name()
]
swift_body += """
// MARK: - Record
"""
record_name = clazz.record_name()
swift_body += """
public struct %s: SDSRecord {
public weak var delegate: SDSRecordDelegate?
public var tableMetadata: SDSTableMetadata {
%sSerializer.table
}
public static var databaseTableName: String {
%sSerializer.table.tableName
}
public var id: Int64?
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
public let recordType: SDSRecordType?
public let uniqueId: String
""" % (
record_name,
str(clazz.name),
str(clazz.name),
)
def write_record_property(property, force_optional=False):
column_name = property.swift_identifier()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = "?" if is_optional else ""
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
column_name = custom_column_name
return """ public let %s: %s%s
""" % (
str(column_name),
record_field_type,
optional_split,
)
record_properties = clazz.sorted_record_properties()
if len(record_properties) > 0:
swift_body += "\n // Properties \n"
for property in record_properties:
swift_body += write_record_property(
property, force_optional=property.force_optional
)
sds_properties = [
ParsedProperty(
{
"name": "id",
"is_optional": False,
"objc_type": "NSInteger",
"class_name": clazz.name,
}
),
ParsedProperty(
{
"name": "recordType",
"is_optional": False,
"objc_type": "NSUInteger",
"class_name": clazz.name,
}
),
ParsedProperty(
{
"name": "uniqueId",
"is_optional": False,
"objc_type": "NSString *",
"class_name": clazz.name,
}
),
]
persisted_properties = sds_properties + record_properties
swift_body += """
public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable {
"""
for property in persisted_properties:
custom_column_name = custom_column_name_for_property(property)
was_property_renamed = was_property_renamed_for_property(property)
if custom_column_name is not None:
if was_property_renamed:
swift_body += """ case %s
""" % (
custom_column_name,
)
else:
swift_body += """ case %s = "%s"
""" % (
custom_column_name,
property.swift_identifier(),
)
else:
swift_body += """ case %s
""" % (
property.swift_identifier(),
)
swift_body += """ }
"""
swift_body += """
public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String {
fullyQualified ? "\\(databaseTableName).\\(column.rawValue)" : column.rawValue
}
public func didInsert(with rowID: Int64, for column: String?) {
guard let delegate = delegate else {
owsFailDebug("Missing delegate.")
return
}
delegate.updateRowId(rowID)
}
}
""" % (
record_name,
)
swift_body += """
// MARK: - Row Initializer
public extension %s {
static var databaseSelection: [SQLSelectable] {
CodingKeys.allCases
}
init(row: Row) {""" % (
record_name
)
for index, property in enumerate(persisted_properties):
type_info = property.type_info()
property_name = property.column_name()
swift_type = type_info.swift_type()
did_force_optional = type_info.is_enum
if property_name == "recordType":
swift_body += """
%s = row[%s].flatMap { SDSRecordType(rawValue: $0) }""" % (
property_name,
index,
)
elif did_force_optional:
swift_body += """
%s = row[%s].flatMap { %s(rawValue: $0) }""" % (
property_name,
index,
swift_type,
)
else:
swift_body += """
%s = row[%s]""" % (
property_name,
index,
)
swift_body += """
}
}
"""
swift_body += """
// MARK: - StringInterpolation
public extension String.StringInterpolation {
mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column))
}
mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column, fullyQualified: true))
}
}
""" % {
"record_identifier": record_identifier(clazz.name),
"record_name": record_name,
}
swift_body += """
// MARK: - Deserialization
extension %s {
// This method defines how to deserialize a model, given a
// database row. The recordType column is used to determine
// the corresponding model class.
class func fromRecord(_ record: %s) throws -> %s {
""" % (
str(clazz.name),
record_name,
str(clazz.name),
)
swift_body += """
guard let recordId = record.id else { throw SDSError.missingRequiredField(fieldName: "id") }
guard let recordType = record.recordType else { throw SDSError.missingRequiredField(fieldName: "recordType") }
switch recordType {
"""
deserialize_classes = all_descendents_of_class(clazz) + [clazz]
deserialize_classes.sort(key=lambda value: value.name)
for deserialize_class in deserialize_classes:
if should_ignore_class(deserialize_class):
continue
initializer_params = []
objc_initializer_params = []
objc_super_initializer_args = []
objc_initializer_assigns = []
deserialize_record_type = get_record_type_enum_name(deserialize_class.name)
swift_body += """ case .%s:
""" % (
str(deserialize_record_type),
)
swift_body += """
let uniqueId: String = record.uniqueId
"""
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(
deserialize_class
)
has_local_properties = False
for property in deserialize_properties:
value_name = "%s" % property.name
if property.name not in ("uniqueId",):
did_force_optional = property.name not in base_property_names
did_force_optional = did_force_optional and not property.is_optional
did_force_optional = did_force_optional or property.type_info().is_enum
for statement in property.deserialize_record_invocation(
value_name, did_force_optional
):
swift_body += " %s\n" % (str(statement),)
initializer_params.append(
"%s: %s"
% (
str(property.name),
value_name,
)
)
objc_initializer_type = str(property.objc_type_safe())
if objc_initializer_type.startswith("NSMutable"):
objc_initializer_type = (
"NS" + objc_initializer_type[len("NSMutable") :]
)
if property.is_optional:
objc_initializer_type = "nullable " + objc_initializer_type
objc_initializer_params.append(
"%s:(%s)%s"
% (
str(property.name),
objc_initializer_type,
str(property.name),
)
)
is_superclass_property = property.class_name != deserialize_class.name
if is_superclass_property:
objc_super_initializer_args.append(
"%s:%s"
% (
str(property.name),
str(property.name),
)
)
else:
has_local_properties = True
if str(property.objc_type_safe()).startswith("NSMutableArray"):
objc_initializer_assigns.append(
"_%s = %s ? [%s mutableCopy] : [NSMutableArray new];"
% (
str(property.name),
str(property.name),
str(property.name),
)
)
elif str(property.objc_type_safe()).startswith(
"NSMutableDictionary"
):
objc_initializer_assigns.append(
"_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];"
% (
str(property.name),
str(property.name),
str(property.name),
)
)
elif (
deserialize_class.name == "TSIncomingMessage"
and property.name in ("authorUUID", "authorPhoneNumber")
):
pass
else:
objc_initializer_assigns.append(
"_%s = %s;"
% (
str(property.name),
str(property.name),
)
)
h_snippet = ""
h_snippet += """
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
"""
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(
0,
len("- (instancetype)initWithUniqueId")
- objc_initializer_param.index(":"),
)
h_snippet += (" " * alignment) + objc_initializer_param + "\n"
h_snippet += (
"NS_DESIGNATED_INITIALIZER NS_SWIFT_NAME(init(grdbId:%s:));\n"
% ":".join([str(property.name) for property in deserialize_properties])
)
h_snippet += """
// clang-format on
"""
m_snippet = ""
m_snippet += """
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
"""
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(
0,
len("- (instancetype)initWithUniqueId")
- objc_initializer_param.index(":"),
)
m_snippet += (" " * alignment) + objc_initializer_param + "\n"
if len(objc_super_initializer_args) == 1:
suffix = "];"
else:
suffix = ""
m_snippet += """{
self = [super initWithGrdbId:grdbId
uniqueId:uniqueId%s
""" % (
suffix
)
for index, objc_super_initializer_arg in enumerate(
objc_super_initializer_args[1:]
):
alignment = max(
0,
len(" self = [super initWithUniqueId")
- objc_super_initializer_arg.index(":"),
)
if index == len(objc_super_initializer_args) - 2:
suffix = "];"
else:
suffix = ""
m_snippet += (
(" " * alignment) + objc_super_initializer_arg + suffix + "\n"
)
m_snippet += """
if (!self) {
return self;
}
"""
if deserialize_class.name == "TSIncomingMessage":
m_snippet += """
if (authorUUID != nil) {
_authorUUID = authorUUID;
} else if (authorPhoneNumber != nil) {
_authorPhoneNumber = authorPhoneNumber;
}
"""
for objc_initializer_assign in objc_initializer_assigns:
m_snippet += (" " * 4) + objc_initializer_assign + "\n"
if deserialize_class.finalize_method_name is not None:
m_snippet += """
[self %s];
""" % (
str(deserialize_class.finalize_method_name),
)
m_snippet += """
return self;
}
// clang-format on
"""
if not has_local_properties:
h_snippet = ""
m_snippet = ""
if deserialize_class.filepath.endswith(".m"):
m_filepath = deserialize_class.filepath
h_filepath = m_filepath[:-2] + ".h"
update_objc_snippet(h_filepath, h_snippet)
update_objc_snippet(m_filepath, m_snippet)
swift_body += """
"""
initializer_invocation = " return %s(" % str(
deserialize_class.name
)
swift_body += initializer_invocation
initializer_params = [
"grdbId: recordId",
] + initializer_params
swift_body += (",\n" + " " * len(initializer_invocation)).join(
initializer_params
)
swift_body += ")"
swift_body += """
"""
swift_body += """ default:
owsFailDebug("Unexpected record type: \\(recordType)")
throw SDSError.invalidValue()
"""
swift_body += """ }
"""
swift_body += """ }
"""
swift_body += """}
"""
if not has_sds_superclass:
swift_body += """
// MARK: - SDSModel
extension %s: SDSModel {
public var serializer: SDSSerializer {
// Any subclass can be cast to it's superclass,
// so the order of this switch statement matters.
// We need to do a "depth first" search by type.
switch self {""" % str(
clazz.name
)
for subclass in reversed(all_descendents_of_class(clazz)):
if should_ignore_class(subclass):
continue
swift_body += """
case let model as %s:
assert(type(of: model) == %s.self)
return %sSerializer(model: model)""" % (
str(subclass.name),
str(subclass.name),
str(subclass.name),
)
swift_body += """
default:
return %sSerializer(model: self)
}
}
public func asRecord() -> SDSRecord {
serializer.asRecord()
}
public var sdsTableName: String {
%s.databaseTableName
}
public static var table: SDSTableMetadata {
%sSerializer.table
}
}
""" % (
str(clazz.name),
record_name,
str(clazz.name),
)
if not has_sds_superclass:
swift_body += """
// MARK: - DeepCopyable
extension %(class_name)s: DeepCopyable {
public func deepCopy() throws -> AnyObject {
guard let id = self.grdbId?.int64Value else {
throw OWSAssertionError("Model missing grdbId.")
}
// Any subclass can be cast to its superclass, so the order of these if
// statements matters. We need to do a "depth first" search by type.
""" % {
"class_name": str(clazz.name)
}
classes_to_copy = list(reversed(all_descendents_of_class(clazz))) + [
clazz,
]
for class_to_copy in classes_to_copy:
if should_ignore_class(class_to_copy):
continue
if class_to_copy == clazz:
swift_body += """
do {
let modelToCopy = self
assert(type(of: modelToCopy) == %(class_name)s.self)
""" % {
"class_name": str(class_to_copy.name)
}
else:
swift_body += """
if let modelToCopy = self as? %(class_name)s {
assert(type(of: modelToCopy) == %(class_name)s.self)
""" % {
"class_name": str(class_to_copy.name)
}
initializer_params = []
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(class_to_copy)
for property in deserialize_properties:
value_name = "%s" % property.name
did_force_optional = property.name not in base_property_names
did_force_optional = did_force_optional and not property.is_optional
did_force_optional = did_force_optional or property.type_info().is_enum
for statement in property.deep_copy_record_invocation(
value_name, did_force_optional
):
swift_body += " %s\n" % (str(statement),)
initializer_params.append(
"%s: %s"
% (
str(property.name),
value_name,
)
)
swift_body += """
"""
initializer_invocation = " return %s(" % str(class_to_copy.name)
swift_body += initializer_invocation
initializer_params = [
"grdbId: id",
] + initializer_params
swift_body += (",\n" + " " * len(initializer_invocation)).join(
initializer_params
)
swift_body += ")"
swift_body += """
}
"""
swift_body += """
}
}
"""
if has_grdb_serializer:
swift_body += """
// MARK: - Table Metadata
extension %sRecord {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
internal func asValues() -> [DatabaseValueConvertible?] {
return [
""" % str(
remove_prefix_from_class_name(clazz.name)
)
def write_grdb_column_metadata(metadata):
    # Render one value expression as a single comma-terminated line of the
    # generated asValues() array literal.
    return """ %s,
""" % (
        str(metadata)
    )
for property in sds_properties:
column_name = property.column_name()
if column_name == "recordType" or property.type_info().is_enum:
swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name))
elif property.name != "id":
swift_body += write_grdb_column_metadata(column_name)
for property in record_properties:
column_name = property.column_name()
if property.type_info().is_enum:
swift_body += write_grdb_column_metadata("%s?.rawValue" % (column_name))
else:
swift_body += write_grdb_column_metadata(column_name)
swift_body += """
]
}
internal func asArguments() -> StatementArguments {
return StatementArguments(asValues())
}
}
"""
if not has_sds_superclass:
swift_body += """
// MARK: - Table Metadata
extension %sSerializer {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
""" % str(
clazz.name
)
column_property_names = []
def write_column_metadata(property, force_optional=False):
    # Emit one SDSColumnMetadata declaration for the serializer extension.
    # Side effect: records the column name in the enclosing
    # column_property_names list so the table definition can list it later.
    column_name = property.swift_identifier()
    column_property_names.append(column_name)
    is_optional = property.is_optional or force_optional
    optional_split = ", isOptional: true" if is_optional else ""
    # Only the uniqueId column carries a uniqueness constraint.
    is_unique = column_name == str("uniqueId")
    is_unique_split = ", isUnique: true" if is_unique else ""
    database_column_type = property.database_column_type()
    # "id" is the grdb row id and is always the table's primary key.
    if property.name == "id":
        database_column_type = ".primaryKey"
    return """ static var %sColumn: SDSColumnMetadata { SDSColumnMetadata(columnName: "%s", columnType: %s%s%s) }
""" % (
        str(column_name),
        str(column_name),
        database_column_type,
        optional_split,
        is_unique_split,
    )
for property in sds_properties:
swift_body += write_column_metadata(property)
if len(record_properties) > 0:
swift_body += " // Properties \n"
for property in record_properties:
swift_body += write_column_metadata(
property, force_optional=property.force_optional
)
database_table_name = "model_%s" % str(clazz.name)
swift_body += """
public static var table: SDSTableMetadata {
SDSTableMetadata(
tableName: "%s",
columns: [
""" % (
database_table_name,
)
swift_body += "\n".join(
[
" %sColumn," % str(column_property_name)
for column_property_name in column_property_names
]
)
swift_body += """
]
)
}
}
"""
cached_method = "anyFetch"
uncached_method = "anyFetch"
if cache_get_code_for_class(clazz) is not None:
cached_method = "fetchViaCache"
swift_body += """
// MARK: - Save/Remove/Update
@objc
public extension %(class_name)s {
func anyInsert(transaction: DBWriteTransaction) {
sdsSave(saveMode: .insert, transaction: transaction)
}
// Avoid this method whenever feasible.
//
// If the record has previously been saved, this method does an overwriting
// update of the corresponding row, otherwise if it's a new record, this
// method inserts a new row.
//
// For performance, when possible, you should explicitly specify whether
// you are inserting or updating rather than calling this method.
func anyUpsert(transaction: DBWriteTransaction) {
let isInserting: Bool
if %(class_name)s.%(cached_method)s(uniqueId: uniqueId, transaction: transaction) != nil {
isInserting = false
} else {
isInserting = true
}
sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction)
}
// This method is used by "updateWith..." methods.
//
// This model may be updated from many threads. We don't want to save
// our local copy (this instance) since it may be out of date. We also
// want to avoid re-saving a model that has been deleted. Therefore, we
// use "updateWith..." methods to:
//
// a) Update a property of this instance.
// b) If a copy of this model exists in the database, load an up-to-date copy,
// and update and save that copy.
// b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save
// this local instance.
//
// After "updateWith...":
//
// a) Any copy of this model in the database will have been updated.
// b) The local property on this instance will always have been updated.
// c) Other properties on this instance may be out of date.
//
// All mutable properties of this class have been made read-only to
// prevent accidentally modifying them directly.
//
// This isn't a perfect arrangement, but in practice this will prevent
// data loss and will resolve all known issues.
func anyUpdate(transaction: DBWriteTransaction, block: (%(class_name)s) -> Void) {
block(self)
// If it's not saved, we don't expect to find it in the database, and we
// won't save any changes we make back into the database.
guard shouldBeSaved else {
return
}
guard let dbCopy = type(of: self).%(uncached_method)s(uniqueId: uniqueId, transaction: transaction) else {
return
}
// Don't apply the block twice to the same instance.
// It's at least unnecessary and actually wrong for some blocks.
// e.g. `block: { $0 in $0.someField++ }`
if dbCopy !== self {
block(dbCopy)
}
dbCopy.sdsSave(saveMode: .update, transaction: transaction)
}
// This method is an alternative to `anyUpdate(transaction:block:)` methods.
//
// We should generally use `anyUpdate` to ensure we're not unintentionally
// clobbering other columns in the database when another concurrent update
// has occurred.
//
// There are cases when this doesn't make sense, e.g. when we know we've
// just loaded the model in the same transaction. In those cases it is
// safe and faster to do a "overwriting" update
func anyOverwritingUpdate(transaction: DBWriteTransaction) {
sdsSave(saveMode: .update, transaction: transaction)
}
""" % {
"class_name": str(clazz.name),
"cached_method": cached_method,
"uncached_method": uncached_method,
}
if has_remove_methods:
swift_body += """
func anyRemove(transaction: DBWriteTransaction) {
sdsRemove(transaction: transaction)
}
"""
swift_body += """}
"""
swift_body += """
// MARK: - %sCursor
@objc
public class %sCursor: NSObject, SDSCursor {
private let transaction: DBReadTransaction
private let cursor: RecordCursor<%s>
init(transaction: DBReadTransaction, cursor: RecordCursor<%s>) {
self.transaction = transaction
self.cursor = cursor
}
public func next() throws -> %s? {
guard let record = try cursor.next() else {
return nil
}""" % (
str(clazz.name),
str(clazz.name),
record_name,
record_name,
str(clazz.name),
)
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += """
let value = try %s.fromRecord(record)
%s(value, transaction: transaction)
return value""" % (
str(clazz.name),
cache_code,
)
else:
swift_body += """
return try %s.fromRecord(record)""" % (
str(clazz.name),
)
swift_body += """
}
public func all() throws -> [%s] {
var result = [%s]()
while true {
guard let model = try next() else {
break
}
result.append(model)
}
return result
}
}
""" % (
str(clazz.name),
str(clazz.name),
)
swift_body += """
// MARK: - Obj-C Fetch
@objc
public extension %(class_name)s {
@nonobjc
class func grdbFetchCursor(transaction: DBReadTransaction) -> %(class_name)sCursor {
let database = transaction.database
return failIfThrows {
let cursor = try %(record_name)s.fetchCursor(database)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
}
}
""" % {
"class_name": str(clazz.name),
"record_name": record_name,
}
cache_code = cache_get_code_for_class(clazz)
assert cache_code is not None
swift_body += """
// Fetches a single model by "unique id".
class func fetchViaCache(uniqueId: String, transaction: DBReadTransaction) -> %(class_name)s? {
assert(!uniqueId.isEmpty)
if let cachedCopy = %(cache_code)s {
return cachedCopy
}
return anyFetch(uniqueId: uniqueId, transaction: transaction)
}
""" % {
"class_name": str(clazz.name),
"cache_code": str(cache_code),
}
swift_body += """
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String, transaction: DBReadTransaction) -> %(class_name)s? {
assert(!uniqueId.isEmpty)
""" % {
"class_name": str(clazz.name),
}
swift_body += """
let sql = "SELECT * FROM \\(%(record_name)s.databaseTableName) WHERE \\(%(record_identifier)sColumn: .uniqueId) = ?"
return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: transaction)
}
""" % {
"record_name": record_name,
"record_identifier": record_identifier(clazz.name),
}
swift_body += """
// Traverses all records.
// Records are not visited in any particular order.
class func anyEnumerate(
transaction: DBReadTransaction,
block: (%s) -> Void,
) {
let cursor = %s.grdbFetchCursor(transaction: transaction)
do {
while let value = try cursor.next() {
block(value)
}
} catch let error {
owsFailDebug("Couldn't fetch model: \\(error)")
}
}
""" % (
(str(clazz.name),) * 2
)
swift_body += """
// Does not order the results.
class func anyFetchAll(transaction: DBReadTransaction) -> [%s] {
var result = [%s]()
anyEnumerate(transaction: transaction) { model in
result.append(model)
}
return result
}
""" % (
(str(clazz.name),) * 2
)
swift_body += """
class func anyCount(transaction: DBReadTransaction) -> UInt {
return %s.ows_fetchCount(transaction.database)
}
}
""" % (
record_name,
)
swift_body += """
// MARK: - Swift Fetch
public extension %(class_name)s {
class func grdbFetchCursor(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: DBReadTransaction) -> %(class_name)sCursor {
return failIfThrows {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
let cursor = try %(record_name)s.fetchCursor(transaction.database, sqlRequest)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
}
}
""" % {
"class_name": str(clazz.name),
"record_name": record_name,
}
string_interpolation_name = remove_prefix_from_class_name(clazz.name)
swift_body += """
class func grdbFetchOne(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: DBReadTransaction) -> %s? {
assert(!sql.isEmpty)
do {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
guard let record = try %s.fetchOne(transaction.database, sqlRequest) else {
return nil
}
""" % (
str(clazz.name),
record_name,
)
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += """
let value = try %s.fromRecord(record)
%s(value, transaction: transaction)
return value""" % (
str(clazz.name),
cache_code,
)
else:
swift_body += """
return try %s.fromRecord(record)""" % (
str(clazz.name),
)
swift_body += """
} catch {
owsFailDebug("error: \\(error)")
return nil
}
}
}
"""
if has_sds_superclass:
swift_body += """
// MARK: - Typed Convenience Methods
@objc
public extension %s {
// NOTE: This method will fail if the object has unexpected type.
class func fetch%sViaCache(
uniqueId: String,
transaction: DBReadTransaction
) -> %s? {
assert(!uniqueId.isEmpty)
guard let object = fetchViaCache(uniqueId: uniqueId, transaction: transaction) else {
return nil
}
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \\(type(of: object))")
return nil
}
return instance
}
// NOTE: This method will fail if the object has unexpected type.
func anyUpdate%s(transaction: DBWriteTransaction, block: (%s) -> Void) {
anyUpdate(transaction: transaction) { (object) in
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \\(type(of: object))")
return
}
block(instance)
}
}
}
""" % (
str(clazz.name),
str(remove_prefix_from_class_name(clazz.name)),
str(clazz.name),
str(clazz.name),
str(remove_prefix_from_class_name(clazz.name)),
str(clazz.name),
str(clazz.name),
)
table_superclass = clazz.table_superclass()
table_class_name = str(table_superclass.name)
has_serializable_superclass = table_superclass.name != clazz.name
override_keyword = ""
swift_body += """
// MARK: - SDSSerializer
// The SDSSerializer protocol specifies how to insert and update the
// row that corresponds to this model.
class %sSerializer: SDSSerializer {
private let model: %s
public init(model: %s) {
self.model = model
}
""" % (
str(clazz.name),
str(clazz.name),
str(clazz.name),
)
root_class = clazz.table_superclass()
root_record_name = remove_prefix_from_class_name(root_class.name) + "Record"
record_id_source = "model.grdbId?.int64Value"
if root_class.record_id_source() is not None:
record_id_source = (
"model.%(source)s > 0 ? Int64(model.%(source)s) : %(default_source)s"
% {
"source": root_class.record_id_source(),
"default_source": record_id_source,
}
)
swift_body += """
// MARK: - Record
func asRecord() -> SDSRecord {
let id: Int64? = %(record_id_source)s
let recordType: SDSRecordType = .%(record_type)s
let uniqueId: String = model.uniqueId
""" % {
"record_type": get_record_type_enum_name(clazz.name),
"record_id_source": record_id_source,
}
initializer_args = [
"id",
"recordType",
"uniqueId",
]
inherited_property_map = {}
for property in properties_and_inherited_properties(clazz):
inherited_property_map[property.column_name()] = property
def write_record_property(property, force_optional=False):
    # Emit one `let <column>` declaration for the generated asRecord() body.
    # Columns the model actually has get a serialized value expression;
    # columns that exist only on sibling subclasses default to nil.
    optional_value = ""
    if property.column_name() in inherited_property_map:
        inherited_property = inherited_property_map[property.column_name()]
        did_force_optional = property.force_optional
        model_accessor = accessor_name_for_property(inherited_property)
        value_expr = inherited_property.serialize_record_invocation(
            "model.%s" % (model_accessor,), did_force_optional
        )
        optional_value = " = %s" % (value_expr,)
    else:
        # This model class doesn't carry the property; the record column is nil.
        optional_value = " = nil"
    record_field_type = property.record_field_type()
    is_optional = property.is_optional or force_optional
    optional_split = "?" if is_optional else ""
    # Side effect: accumulate the column for the record initializer call.
    initializer_args.append(property.column_name())
    return """ let %s: %s%s%s
""" % (
        str(property.column_name()),
        record_field_type,
        optional_split,
        optional_value,
    )
root_record_properties = root_class.sorted_record_properties()
if len(root_record_properties) > 0:
swift_body += "\n // Properties \n"
for property in root_record_properties:
swift_body += write_record_property(
property, force_optional=property.force_optional
)
initializer_args = [
"%s: %s"
% (
arg,
arg,
)
for arg in initializer_args
]
swift_body += """
return %s(delegate: model, %s)
}
""" % (
root_record_name,
", ".join(initializer_args),
)
swift_body += """}
"""
print(f"Writing {swift_filename}")
swift_body = sds_common.clean_up_generated_swift(swift_body)
sds_common.write_text_file_if_changed(swift_filepath, swift_body)
def process_class_map(class_map):
    """Generate Swift extensions for every parsed model class in *class_map*."""
    for parsed_class in class_map.values():
        generate_swift_extensions_for_model(parsed_class)
# Maps model class name -> stable integer record-type id.  Loaded from and
# persisted to the record-type JSON file by update_record_type_map(); keys
# starting with "#" are metadata, not class names.
record_type_map = {}
def update_record_type_map(record_type_swift_path, record_type_json_path):
    """Assign a stable integer "record type" id to every model class.

    Merges ids already persisted in the JSON map (so ids never change once
    assigned), allocates fresh ids for classes not seen before, then writes
    back both the JSON map and the generated SDSRecordType Swift enum.
    """
    record_type_map_filepath = record_type_json_path
    # Merge previously-assigned ids so existing classes keep their values
    # across runs.
    if os.path.exists(record_type_map_filepath):
        with open(record_type_map_filepath, "rt") as f:
            json_string = f.read()
        json_data = json.loads(json_string)
        record_type_map.update(json_data)
    # Find the highest id already in use; new ids are allocated above it.
    max_record_type = 0
    for class_name in record_type_map:
        # "#"-prefixed keys (e.g. "#comment") are metadata, not class names.
        if class_name.startswith("#"):
            continue
        record_type = record_type_map[class_name]
        max_record_type = max(max_record_type, record_type)
    # Allocate fresh ids for classes that have never been assigned one.
    for clazz in global_class_map.values():
        if clazz.name not in record_type_map:
            if not clazz.should_generate_extensions():
                continue
            max_record_type = int(max_record_type) + 1
            record_type = max_record_type
            record_type_map[clazz.name] = record_type
    record_type_map["#comment"] = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    # Persist the (sorted, pretty-printed) id map.
    json_string = json.dumps(record_type_map, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(record_type_map_filepath, json_string)
    # Generate the SDSRecordType Swift enum mirroring the id map.
    swift_body = """//
// Copyright 2022 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import GRDB
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
@objc
public enum SDSRecordType: UInt, CaseIterable {
""" % (
        sds_common.pretty_module_path(__file__),
    )
    record_type_pairs = []
    for key in record_type_map.keys():
        if key.startswith("#"):
            continue
        enum_name = get_record_type_enum_name(key)
        record_type_pairs.append((str(enum_name), record_type_map[key]))
    # Emit cases in ascending id order for a stable, readable enum.
    record_type_pairs.sort(key=lambda value: value[1])
    for enum_name, record_type_id in record_type_pairs:
        swift_body += """ case %s = %s
""" % (
            enum_name,
            str(record_type_id),
        )
    swift_body += """}
"""
    swift_body = sds_common.clean_up_generated_swift(swift_body)
    sds_common.write_text_file_if_changed(record_type_swift_path, swift_body)
def get_record_type(clazz):
    """Return the stable record-type id previously assigned to *clazz*."""
    class_name = clazz.name
    return record_type_map[class_name]
def remove_prefix_from_class_name(class_name):
    """Strip a known project prefix (TS/OWS/SSK) from *class_name*, if any.

    At most one prefix is removed; names with no known prefix pass through
    unchanged.
    """
    for prefix in ("TS", "OWS", "SSK"):
        if class_name.startswith(prefix):
            return class_name[len(prefix):]
    return class_name
def get_record_type_enum_name(class_name):
    """Swift enum-case name for *class_name*: project prefix stripped,
    digit-guarded, converted to lowerCamelCase."""
    name = remove_prefix_from_class_name(class_name)
    # Swift identifiers cannot begin with a digit.
    guard = "_" if name[0].isnumeric() else ""
    return to_swift_identifier_name(guard + name)
def record_identifier(class_name):
    """lowerCamelCase identifier used for a class's record in generated Swift."""
    return to_swift_identifier_name(remove_prefix_from_class_name(class_name))
# NOTE(review): column_ordering_map and has_loaded_column_ordering_map are not
# referenced anywhere in this portion of the file -- they may be consumed by
# code defined earlier; confirm before removing.
column_ordering_map = {}
has_loaded_column_ordering_map = False
# Maps enum name -> raw Obj-C storage type string (e.g. "NSUInteger");
# populated from the "enums" section of each parsed .sds JSON file
# (see parse_sds_json).
enum_type_map = {}
def objc_type_for_enum(enum_name):
    """Return the raw Obj-C storage type recorded for *enum_name*.

    Fails hard (after dumping the known map for debugging) when the enum was
    never seen during parsing.  Map values are storage-type strings, never
    None, so the .get() sentinel check matches the original containment test.
    """
    enum_type = enum_type_map.get(enum_name)
    if enum_type is None:
        print("enum_type_map", enum_type_map)
        fail("Enum has unknown type:", enum_name)
    return enum_type
def swift_type_for_enum(enum_name):
    """Return the Swift type used to represent the given Obj-C enum's storage.

    Fails hard on storage types we don't know how to map.

    Fix: the original had two branches for "unsigned long long" -- the first
    returned the C type name "uint64_t" (inconsistent with every other branch,
    which returns a Swift type) and made the second ("UInt64") unreachable.
    The Swift-native equivalent "UInt64" is used.
    """
    objc_type = objc_type_for_enum(enum_name)
    swift_type_by_objc_type = {
        "NSInteger": "Int",
        "NSUInteger": "UInt",
        "int32_t": "Int32",
        "unsigned long long": "UInt64",
        # NOTE: 64-bit mapping assumes Apple's LP64 platforms.
        "unsigned long": "UInt64",
        "unsigned int": "UInt",
    }
    swift_type = swift_type_by_objc_type.get(objc_type)
    if swift_type is None:
        fail("Unknown objc type:", objc_type)
    return swift_type
def parse_sds_json(file_path):
    """Parse one .sds intermediate JSON file into {class_name: ParsedClass}.

    Side effect: merges the file's enum declarations into the module-level
    enum_type_map.
    """
    with open(file_path, "rt") as f:
        json_data = json.load(f)
    class_map = {}
    for class_dict in json_data["classes"]:
        parsed = ParsedClass(class_dict)
        class_map[parsed.name] = parsed
    enum_type_map.update(json_data["enums"])
    return class_map
def try_to_parse_file(file_path):
    """Parse *file_path* if it is an .sds intermediate JSON file.

    Returns a {class_name: ParsedClass} map, or {} for any other file type.

    Fix: removed the unused local from `os.path.splitext` -- only the
    extension-bearing suffix check on the filename was ever used.
    """
    filename = os.path.basename(file_path)
    if filename.endswith(sds_common.SDS_JSON_FILE_EXTENSION):
        return parse_sds_json(file_path)
    return {}
def find_sds_intermediary_files_in_path(path):
    """Collect parsed classes from every .sds intermediate file under *path*.

    *path* may be a single file or a directory tree; files that are not .sds
    intermediates contribute nothing (try_to_parse_file returns {} for them).
    """
    class_map = {}
    if not os.path.isfile(path):
        for rootdir, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                full_path = os.path.abspath(os.path.join(rootdir, filename))
                class_map.update(try_to_parse_file(full_path))
    else:
        class_map.update(try_to_parse_file(path))
    return class_map
def update_subclass_map():
    """Populate global_subclass_map (super class name -> list of subclasses)
    from every parsed class in global_class_map.

    Improvement: dict.setdefault replaces the get/append/store triple --
    one lookup instead of two, identical behavior.
    """
    for clazz in global_class_map.values():
        if clazz.super_class_name is not None:
            global_subclass_map.setdefault(clazz.super_class_name, []).append(clazz)
def all_descendents_of_class(clazz):
    """Return all transitive subclasses of *clazz*, depth-first, with each
    group of siblings sorted by name.

    Fix: the original called .sort() on the list fetched from
    global_subclass_map, silently re-ordering the shared list stored in the
    global map as a side effect.  Sorting a copy preserves the returned order
    while leaving the global state untouched.
    """
    result = []
    subclasses = global_subclass_map.get(clazz.name, [])
    for subclass in sorted(subclasses, key=lambda value: value.name):
        result.append(subclass)
        result.extend(all_descendents_of_class(subclass))
    return result
def is_swift_class_name(swift_type):
    """True when *swift_type* names a parsed model class.

    Values in global_class_map are ParsedClass instances (never None), so a
    containment test matches the original `.get(...) is not None` check.
    """
    return swift_type in global_class_map
# Code-generation configuration loaded from the --config-json-path file by
# parse_config_json(); consulted by the should_ignore_* / cache_* / *_for_property
# helpers below.
configuration_json = {}
def parse_config_json(config_json_path):
    """Load the code-generation config JSON into module-level configuration_json."""
    global configuration_json
    with open(config_json_path, "rt") as f:
        configuration_json = json.load(f)
def swift_type_for_nsnumber(property):
    """Return the Swift type configured for an NSNumber-typed property.

    NSNumber can box many numeric types, so the concrete Swift type must be
    declared per-property in the config JSON under "nsnumber_types"; a missing
    entry is a hard failure with a hint pointing at the config file.
    """
    nsnumber_types = configuration_json.get("nsnumber_types")
    if nsnumber_types is None:
        print("Suggestion: update: %s" % (str(global_args.config_json_path),))
        fail("Configuration JSON is missing mapping for properties of type NSNumber.")
    lookup_key = f"{property.class_name}.{property.name}"
    swift_type = nsnumber_types.get(lookup_key)
    if swift_type is None:
        print("Suggestion: update: %s" % (str(global_args.config_json_path),))
        fail(
            "Configuration JSON is missing mapping for properties of type NSNumber:",
            lookup_key,
        )
    return swift_type
def should_ignore_property(property):
    """True when the config's properties_to_ignore lists 'ClassName.propertyName'."""
    properties_to_ignore = configuration_json.get("properties_to_ignore")
    if properties_to_ignore is None:
        fail(
            "Configuration JSON is missing list of properties to ignore during serialization."
        )
    return f"{property.class_name}.{property.name}" in properties_to_ignore
def cache_get_code_for_class(clazz):
    """Return the configured cache-read Swift snippet for *clazz*, or None
    when the class is not model-cached."""
    code_map = configuration_json.get("class_cache_get_code")
    if code_map is None:
        fail("Configuration JSON is missing dict of class_cache_get_code.")
    return code_map.get(clazz.name)
def cache_set_code_for_class(clazz):
    """Return the configured cache-write Swift snippet for *clazz*, or None
    when the class is not model-cached."""
    code_map = configuration_json.get("class_cache_set_code")
    if code_map is None:
        fail("Configuration JSON is missing dict of class_cache_set_code.")
    return code_map.get(clazz.name)
def should_ignore_class(clazz):
    """True when *clazz* (or any ancestor) is listed in class_to_skip_serialization.

    Skipping is inherited: if any superclass in the parsed class map is
    skipped, every subclass is skipped too.

    Idiom fix: `if not x in y` -> `if x not in y`.
    """
    class_to_skip_serialization = configuration_json.get("class_to_skip_serialization")
    if class_to_skip_serialization is None:
        fail(
            "Configuration JSON is missing list of classes to ignore during serialization."
        )
    if clazz.name in class_to_skip_serialization:
        return True
    if clazz.super_class_name is None:
        return False
    # Superclasses outside the parsed class map (e.g. NSObject) can't be skipped.
    if clazz.super_class_name not in global_class_map:
        return False
    super_clazz = global_class_map[clazz.super_class_name]
    return should_ignore_class(super_clazz)
def accessor_name_for_property(property):
    """Accessor used in generated Swift for *property*: the configured custom
    accessor when one exists, otherwise the property's own name."""
    custom_accessors = configuration_json.get("custom_accessors")
    if custom_accessors is None:
        fail("Configuration JSON is missing list of custom property accessors.")
    return custom_accessors.get(f"{property.class_name}.{property.name}", property.name)
def custom_column_name_for_property(property):
    """Configured custom column name for *property*, or None when not overridden."""
    custom_column_names = configuration_json.get("custom_column_names")
    if custom_column_names is None:
        fail("Configuration JSON is missing list of custom column names.")
    return custom_column_names.get(f"{property.class_name}.{property.name}")
def aliased_column_name_for_property(property):
    """Configured aliased column name for *property*, or None when not aliased.

    Fix: the local was named ``custom_column_names`` (copy-pasted from
    custom_column_name_for_property) while actually reading the
    ``aliased_column_names`` config section; renamed for clarity.
    """
    aliased_column_names = configuration_json.get("aliased_column_names")
    if aliased_column_names is None:
        fail("Configuration JSON is missing dict of aliased_column_names.")
    key = property.class_name + "." + property.name
    return aliased_column_names.get(key)
def was_property_renamed_for_property(property):
    """True when *property* has an entry in the renamed_column_names config."""
    renamed_column_names = configuration_json.get("renamed_column_names")
    if renamed_column_names is None:
        fail("Configuration JSON is missing list of renamed column names.")
    lookup_key = f"{property.class_name}.{property.name}"
    return renamed_column_names.get(lookup_key) is not None
# Cache of property ordering ("RecordName.propertyName" -> index), loaded from
# and written back to the --property-order-json-path file so generated record
# field order stays stable across runs.
property_order_json = {}
def parse_property_order_json(property_order_json_path):
    """Load the property-ordering cache JSON into module-level property_order_json."""
    global property_order_json
    with open(property_order_json_path, "rt") as f:
        property_order_json = json.load(f)
def update_property_order_json(property_order_json_path):
    """Persist the property-ordering cache, stamped with a generated-file notice."""
    property_order_json["#comment"] = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    serialized = json.dumps(property_order_json, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(property_order_json_path, serialized)
def property_order_key(property, record_name):
    """Cache key for a property's serialization order: '<RecordName>.<propertyName>'."""
    return f"{record_name}.{property.name}"
def property_order_for_property(property, record_name):
    """Cached ordering index for *property* within *record_name*, or None."""
    return property_order_json.get(property_order_key(property, record_name))
def set_property_order_for_property(property, record_name, value):
    """Record *value* as the ordering index for *property* within *record_name*."""
    property_order_json[property_order_key(property, record_name)] = value
if __name__ == "__main__":
    # Entry point: parse CLI args, load config/caches, build the class map,
    # then generate Swift extensions and regenerate the record-type artifacts.
    parser = argparse.ArgumentParser(description="Generate Swift extensions.")
    # NOTE(review): --src-path and --search-path share the same help text;
    # search-path appears to build the full class map while src-path selects
    # which classes get extensions generated -- consider clarifying.
    parser.add_argument(
        "--src-path", required=True, help="used to specify a path to process."
    )
    parser.add_argument(
        "--search-path", required=True, help="used to specify a path to process."
    )
    parser.add_argument(
        "--record-type-swift-path",
        required=True,
        help="path of the record type enum swift file.",
    )
    parser.add_argument(
        "--record-type-json-path",
        required=True,
        help="path of the record type map json file.",
    )
    parser.add_argument(
        "--config-json-path",
        required=True,
        help="path of the json file with code generation config info.",
    )
    parser.add_argument(
        "--property-order-json-path",
        required=True,
        help="path of the json file with property ordering cache.",
    )
    args = parser.parse_args()
    # Executed at module scope, so this rebinds the module-level global_args
    # (no `global` statement needed); helpers read it for error suggestions.
    global_args = args
    src_path = os.path.abspath(args.src_path)
    search_path = os.path.abspath(args.search_path)
    record_type_swift_path = os.path.abspath(args.record_type_swift_path)
    record_type_json_path = os.path.abspath(args.record_type_json_path)
    config_json_path = os.path.abspath(args.config_json_path)
    property_order_json_path = os.path.abspath(args.property_order_json_path)
    # Order matters: config and the property-order cache must be loaded before
    # any class processing, and the full class map / subclass map / record-type
    # ids must exist before extensions are generated for the src-path classes.
    parse_config_json(config_json_path)
    parse_property_order_json(property_order_json_path)
    global_class_map.update(find_sds_intermediary_files_in_path(search_path))
    update_subclass_map()
    update_record_type_map(record_type_swift_path, record_type_json_path)
    process_class_map(find_sds_intermediary_files_in_path(src_path))
    # Persist any ordering entries added while generating, so the next run
    # emits record fields in the same order.
    update_property_order_json(property_order_json_path)