"""
Module contains tools for processing Stata files into DataFrames

The StataReader below was originally written by Joe Presbrey as part of PyDTA.
It has been extended and improved by Skipper Seabold from the Statsmodels
project who also developed the StataWriter and was finally added to pandas in
a once again improved version.

You can find more information at http://presbrey.mit.edu/PyDTA and
https://www.statsmodels.org/devel/
"""
from __future__ import annotations

from collections import abc
import datetime
from io import BytesIO
import os
import struct
import sys
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    AnyStr,
    Hashable,
    Sequence,
    cast,
)
import warnings

from dateutil.relativedelta import relativedelta
import numpy as np

from pandas._libs.lib import infer_dtype
from pandas._libs.writers import max_len_string_array
from pandas._typing import (
    CompressionOptions,
    FilePath,
    ReadBuffer,
    StorageOptions,
    WriteBuffer,
)
from pandas.util._decorators import (
    Appender,
    doc,
)

from pandas.core.dtypes.common import (
    ensure_object,
    is_categorical_dtype,
    is_datetime64_dtype,
    is_numeric_dtype,
)

from pandas import (
    Categorical,
    DatetimeIndex,
    NaT,
    Timestamp,
    isna,
    to_datetime,
    to_timedelta,
)
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import _IntegerDtype
from pandas.core.frame import DataFrame
from pandas.core.indexes.base import Index
from pandas.core.series import Series
from pandas.core.shared_docs import _shared_docs

from pandas.io.common import get_handle

if TYPE_CHECKING:
    from typing import Literal

_version_error = (
    "Version of given Stata file is {version}. pandas supports importing "
    "versions 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
    "114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), 118 (Stata 14/15/16), "
    "and 119 (Stata 15/16, over 32,767 variables)."
)

_statafile_processing_params1 = """\
convert_dates : bool, default True
    Convert date variables to DataFrame time values.
convert_categoricals : bool, default True
    Read value labels and convert columns to Categorical/Factor variables."""

_statafile_processing_params2 = """\
index_col : str, optional
    Column to set as index.
convert_missing : bool, default False
    Flag indicating whether to convert missing values to their Stata
    representations. If False, missing values are replaced with nan.
    If True, columns containing missing values are returned with
    object data types and missing values are represented by
    StataMissingValue objects.
preserve_dtypes : bool, default True
    Preserve Stata datatypes. If False, numeric data are upcast to pandas
    default types for foreign data (float64 or int64).
columns : list or None
    Columns to retain. Columns will be returned in the given order. None
    returns all columns.
order_categoricals : bool, default True
    Flag indicating whether converted categorical data are ordered."""

_chunksize_params = """\
chunksize : int, default None
    Return StataReader object for iteration, returns chunks with
    given number of lines."""

_iterator_params = """\
iterator : bool, default False
    Return StataReader object."""

_reader_notes = """\
Notes
-----
Categorical variables read through an iterator may not have the same
categories and dtype. This occurs when a variable stored in a DTA
file is associated with an incomplete set of value labels that only
label a strict subset of the values."""

_read_stata_doc = f"""
Read Stata file into DataFrame.

Parameters
----------
filepath_or_buffer : str, path object or file-like object
    Any valid string path is acceptable. The string could be a URL. Valid
    URL schemes include http, ftp, s3, and file. For file URLs, a host is
    expected. A local file could be: ``file://localhost/path/to/table.dta``.

    If you want to pass in a path object, pandas accepts any ``os.PathLike``.

    By file-like object, we refer to objects with a ``read()`` method,
    such as a file handle (e.g. via builtin ``open`` function)
    or ``StringIO``.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_iterator_params}
{_shared_docs["decompression_options"]}
{_shared_docs["storage_options"]}

Returns
-------
DataFrame or StataReader

See Also
--------
io.stata.StataReader : Low-level reader for Stata data files.
DataFrame.to_stata : Export Stata data files.

{_reader_notes}

Examples
--------

Creating a dummy Stata file for this example:
>>> df = pd.DataFrame({{'animal': ['falcon', 'parrot', 'falcon',
...                               'parrot'],
...                    'speed': [350, 18, 361, 15]}})  # doctest: +SKIP
>>> df.to_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file:

>>> df = pd.read_stata('animals.dta')  # doctest: +SKIP

Read a Stata dta file in 10,000 line chunks:
>>> values = np.random.randint(0, 10, size=(20_000, 1), dtype="uint8")  # doctest: +SKIP
>>> df = pd.DataFrame(values, columns=["i"])  # doctest: +SKIP
>>> df.to_stata('filename.dta')  # doctest: +SKIP

>>> itr = pd.read_stata('filename.dta', chunksize=10000)  # doctest: +SKIP
>>> for chunk in itr:
...    # Operate on a single chunk, e.g., chunk.mean()
...    pass  # doctest: +SKIP
"""

_read_method_doc = f"""\
Reads observations from a Stata file, converting them into a DataFrame

Parameters
----------
nrows : int
    Number of lines to read from the data file; if None, read the whole file.
{_statafile_processing_params1}
{_statafile_processing_params2}

Returns
-------
DataFrame
"""

_stata_reader_doc = f"""\
Class for reading Stata dta files.

Parameters
----------
path_or_buf : path (string), buffer or path object
    string, path object (pathlib.Path or py._path.local.LocalPath) or object
    implementing a binary read() function.
{_statafile_processing_params1}
{_statafile_processing_params2}
{_chunksize_params}
{_shared_docs["decompression_options"]}
{_shared_docs["storage_options"]}

{_reader_notes}
"""


_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]


stata_epoch = datetime.datetime(1960, 1, 1)
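# All Stata Internal Format (SIF) date values are offsets from this epoch,
# 01jan1960, in units set by the format (milliseconds, days, weeks, ...)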


# TODO: Add typing. As of January 2020 it is not possible to type this function since
#  mypy doesn't understand that a Series and an int can be combined using mathematical
#  operations. (+, -).
def _stata_elapsed_date_to_datetime_vec(dates, fmt) -> Series:
    """
    Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        The Stata Internal Format date to convert to datetime according to fmt
    fmt : str
        The format to convert to. Can be tc, td, tw, tm, tq, th, ty

    Returns
    -------
    converted : Series
        The converted dates

    Examples
    --------
    >>> dates = pd.Series([52])
    >>> _stata_elapsed_date_to_datetime_vec(dates, "%tw")
    0   1961-01-01
    dtype: datetime64[ns]

    Notes
    -----
    datetime/c - tc
        milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day
    datetime/C - tC - NOT IMPLEMENTED
        milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds
    date - td
        days since 01jan1960 (01jan1960 = 0)
    weekly date - tw
        weeks since 1960w1
        This assumes 52 weeks in a year, then adds 7 * remainder of the weeks.
        The datetime value is the start of the week in terms of days in the
        year, not ISO calendar weeks.
    monthly date - tm
        months since 1960m1
    quarterly date - tq
        quarters since 1960q1
    half-yearly date - th
        half-years since 1960h1
    yearly date - ty
        years since 0000
    """
    MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year
    MAX_DAY_DELTA = (Timestamp.max - datetime.datetime(1960, 1, 1)).days
    MIN_DAY_DELTA = (Timestamp.min - datetime.datetime(1960, 1, 1)).days
    MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000
    MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000

    def convert_year_month_safe(year, month) -> Series:
        """
        Convert year and month to datetimes, using pandas vectorized versions
        when the date range falls within the range supported by pandas.
        Otherwise it falls back to a slower but more robust method
        using datetime.
        """
        if year.max() < MAX_YEAR and year.min() > MIN_YEAR:
            return to_datetime(100 * year + month, format="%Y%m")
        else:
            index = getattr(year, "index", None)
            return Series(
                [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index
            )

    def convert_year_days_safe(year, days) -> Series:
        """
        Converts year (e.g. 1999) and days since the start of the year to a
        datetime or datetime64 Series
        """
        if year.max() < (MAX_YEAR - 1) and year.min() > MIN_YEAR:
            return to_datetime(year, format="%Y") + to_timedelta(days, unit="d")
        else:
            index = getattr(year, "index", None)
            value = [
                datetime.datetime(y, 1, 1) + relativedelta(days=int(d))
                for y, d in zip(year, days)
            ]
            return Series(value, index=index)

    def convert_delta_safe(base, deltas, unit) -> Series:
        """
        Convert base dates and deltas to datetimes, using pandas vectorized
        versions if the deltas satisfy restrictions required to be expressed
        as dates in pandas.
        """
        index = getattr(deltas, "index", None)
        if unit == "d":
            if deltas.max() > MAX_DAY_DELTA or deltas.min() < MIN_DAY_DELTA:
                values = [base + relativedelta(days=int(d)) for d in deltas]
                return Series(values, index=index)
        elif unit == "ms":
            if deltas.max() > MAX_MS_DELTA or deltas.min() < MIN_MS_DELTA:
                values = [
                    base + relativedelta(microseconds=(int(d) * 1000)) for d in deltas
                ]
                return Series(values, index=index)
        else:
            raise ValueError("format not understood")
        base = to_datetime(base)
        deltas = to_timedelta(deltas, unit=unit)
        return base + deltas

    # TODO(non-nano): If/when pandas supports more than datetime64[ns], this
    #  should be improved to use correct range, e.g. datetime[Y] for yearly
    bad_locs = np.isnan(dates)
    has_bad_values = False
    if bad_locs.any():
        has_bad_values = True
        data_col = Series(dates)
        data_col[bad_locs] = 1.0  # Replace with NaT
    dates = dates.astype(np.int64)

    if fmt.startswith(("%tc", "tc")):  # Delta ms relative to base
        base = stata_epoch
        ms = dates
        conv_dates = convert_delta_safe(base, ms, "ms")
    elif fmt.startswith(("%tC", "tC")):

        warnings.warn("Encountered %tC format. Leaving in Stata Internal Format.")
        conv_dates = Series(dates, dtype=object)
        if has_bad_values:
            conv_dates[bad_locs] = NaT
        return conv_dates
    # Delta days relative to base
    elif fmt.startswith(("%td", "td", "%d", "d")):
        base = stata_epoch
        days = dates
        conv_dates = convert_delta_safe(base, days, "d")
    # does not count leap days - 7 days is a week.
    # 52nd week may have more than 7 days
    elif fmt.startswith(("%tw", "tw")):
        year = stata_epoch.year + dates // 52
        days = (dates % 52) * 7
        conv_dates = convert_year_days_safe(year, days)
    elif fmt.startswith(("%tm", "tm")):  # Delta months relative to base
        year = stata_epoch.year + dates // 12
        month = (dates % 12) + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
        year = stata_epoch.year + dates // 4
        quarter_month = (dates % 4) * 3 + 1
        conv_dates = convert_year_month_safe(year, quarter_month)
    elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
        year = stata_epoch.year + dates // 2
        month = (dates % 2) * 6 + 1
        conv_dates = convert_year_month_safe(year, month)
    elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
        year = dates
        first_month = np.ones_like(dates)
        conv_dates = convert_year_month_safe(year, first_month)
    else:
        raise ValueError(f"Date fmt {fmt} not understood")

    if has_bad_values:  # Restore NaT for bad values
        conv_dates[bad_locs] = NaT

    return conv_dates


def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series:
    """
    Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime

    Parameters
    ----------
    dates : Series
        Series or array containing datetime.datetime or datetime64[ns] to
        convert to the Stata Internal Format given by fmt
    fmt : str
        The format to convert to. Can be tc, td, tw, tm, tq, th, ty
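
    Examples
    --------
    A minimal sketch of the inverse of the reader example above; the
    expected output assumes the float64 values this function returns:

    >>> dates = pd.Series([pd.Timestamp("1961-01-01")])
    >>> _datetime_to_stata_elapsed_vec(dates, "%tw")  # doctest: +SKIP
    0    52.0
    dtype: float64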
    """
    index = dates.index
    NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000
    US_PER_DAY = NS_PER_DAY / 1000

    def parse_dates_safe(dates, delta=False, year=False, days=False):
        d = {}
        if is_datetime64_dtype(dates.dtype):
            if delta:
                time_delta = dates - stata_epoch
                d["delta"] = time_delta._values.view(np.int64) // 1000  # microseconds
            if days or year:
                date_index = DatetimeIndex(dates)
                d["year"] = date_index._data.year
                d["month"] = date_index._data.month
            if days:
                days_in_ns = dates.view(np.int64) - to_datetime(
                    d["year"], format="%Y"
                ).view(np.int64)
                d["days"] = days_in_ns // NS_PER_DAY

        elif infer_dtype(dates, skipna=False) == "datetime":
            if delta:
                delta = dates._values - stata_epoch

                def f(x: datetime.timedelta) -> float:
                    return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds

                v = np.vectorize(f)
                d["delta"] = v(delta)
            if year:
                year_month = dates.apply(lambda x: 100 * x.year + x.month)
                d["year"] = year_month._values // 100
                d["month"] = year_month._values - d["year"] * 100
            if days:

                def g(x: datetime.datetime) -> int:
                    return (x - datetime.datetime(x.year, 1, 1)).days

                v = np.vectorize(g)
                d["days"] = v(dates)
        else:
            raise ValueError(
                "Columns containing dates must contain either "
                "datetime64, datetime.datetime or null values."
            )

        return DataFrame(d, index=index)

    bad_loc = isna(dates)
    index = dates.index
    if bad_loc.any():
        dates = Series(dates)
        if is_datetime64_dtype(dates):
            dates[bad_loc] = to_datetime(stata_epoch)
        else:
            dates[bad_loc] = stata_epoch

    if fmt in ["%tc", "tc"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta / 1000
    elif fmt in ["%tC", "tC"]:
        warnings.warn("Stata Internal Format tC not supported.")
        conv_dates = dates
    elif fmt in ["%td", "td"]:
        d = parse_dates_safe(dates, delta=True)
        conv_dates = d.delta // US_PER_DAY
    elif fmt in ["%tw", "tw"]:
        d = parse_dates_safe(dates, year=True, days=True)
        conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7
    elif fmt in ["%tm", "tm"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1
    elif fmt in ["%tq", "tq"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3
    elif fmt in ["%th", "th"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int)
    elif fmt in ["%ty", "ty"]:
        d = parse_dates_safe(dates, year=True)
        conv_dates = d.year
    else:
        raise ValueError(f"Format {fmt} is not a known Stata date format")

    conv_dates = Series(conv_dates, dtype=np.float64)
    missing_value = struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
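    # these bytes are Stata's generic '.' missing value for doubles, the same
    # sentinel used for "d" in StataParser.MISSING_VALUES below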
    conv_dates[bad_loc] = missing_value

    return Series(conv_dates, index=index)


excessive_string_length_error = """
Fixed width strings in Stata .dta files are limited to 244 (or fewer)
characters. Column '{0}' does not satisfy this restriction. Use the
'version=117' parameter to write the newer (Stata 13 and later) format.
"""


class PossiblePrecisionLoss(Warning):
    pass


precision_loss_doc = """
Column converted from {0} to {1}, and some data are outside of the lossless
conversion range. This may result in a loss of precision in the saved data.
"""


class ValueLabelTypeMismatch(Warning):
    pass


value_label_mismatch_doc = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
"""


class InvalidColumnName(Warning):
    pass


invalid_name_doc = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:

    {0}

If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only
alphanumerics and underscores, no Stata reserved words)
"""


class CategoricalConversionWarning(Warning):
    pass


categorical_conversion_warning = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in a categorical variable with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read the dataset without an iterator, or manually convert the categorical data
by setting ``convert_categoricals`` to False and then accessing the value
labels through the value_labels method of the reader.
"""


def _cast_to_stata_types(data: DataFrame) -> DataFrame:
    """
    Checks the dtypes of the columns of a pandas DataFrame for
    compatibility with the data types and ranges supported by Stata, and
    converts if necessary.

    Parameters
    ----------
    data : DataFrame
        The DataFrame to check and convert

    Notes
    -----
    Numeric columns in Stata must be one of int8, int16, int32, float32 or
    float64, with some additional value restrictions. int8 and int16 columns
    are checked for violations of the value restrictions and upcast if needed.
    int64 data is not usable in Stata, and so it is downcast to int32 whenever
    the values are in the int32 range, and cast to float64 when larger than
    this range. If the int64 values are outside of the range of those
    perfectly representable as float64 values, a warning is raised.

    bool columns are cast to int8. uint columns are converted to int of the
    same size if there is no loss in precision, otherwise are upcast to a
    larger type. uint64 is currently not supported since it is converted to
    object in a DataFrame.
    """
    ws = ""
    # original, if small, if large
    conversion_data = (
        (np.bool_, np.int8, np.int8),
        (np.uint8, np.int8, np.int16),
        (np.uint16, np.int16, np.int32),
        (np.uint32, np.int32, np.int64),
        (np.uint64, np.int64, np.float64),
    )
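    # e.g. a uint8 column is stored as int8 when its max fits in int8's range,
    # otherwise as int16; the Stata range checks below may widen it further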

    float32_max = struct.unpack("<f", b"\xff\xff\xff\x7e")[0]
    float64_max = struct.unpack("<d", b"\xff\xff\xff\xff\xff\xff\xdf\x7f")[0]

    for col in data:
        # Cast from unsupported types to supported types
        is_nullable_int = isinstance(data[col].dtype, (_IntegerDtype, BooleanDtype))
        orig = data[col]
        # We need to find orig_missing before altering data below
        orig_missing = orig.isna()
        if is_nullable_int:
            missing_loc = data[col].isna()
            if missing_loc.any():
                # Replace with always safe value
                data.loc[missing_loc, col] = 0
            # Replace with NumPy-compatible column
            data[col] = data[col].astype(data[col].dtype.numpy_dtype)
        dtype = data[col].dtype
        for c_data in conversion_data:
            if dtype == c_data[0]:
                # Value of type variable "_IntType" of "iinfo" cannot be "object"
                if data[col].max() <= np.iinfo(c_data[1]).max:  # type: ignore[type-var]
                    dtype = c_data[1]
                else:
                    dtype = c_data[2]
                if c_data[2] == np.int64:  # Warn if necessary
                    if data[col].max() >= 2**53:
                        ws = precision_loss_doc.format("uint64", "float64")

        data[col] = data[col].astype(dtype)

        # Check values and upcast if necessary
        if dtype == np.int8:
            if data[col].max() > 100 or data[col].min() < -127:
                data[col] = data[col].astype(np.int16)
        elif dtype == np.int16:
            if data[col].max() > 32740 or data[col].min() < -32767:
                data[col] = data[col].astype(np.int32)
        elif dtype == np.int64:
            if data[col].max() <= 2147483620 and data[col].min() >= -2147483647:
                data[col] = data[col].astype(np.int32)
            else:
                data[col] = data[col].astype(np.float64)
                if data[col].max() >= 2**53 or data[col].min() <= -(2**53):
                    ws = precision_loss_doc.format("int64", "float64")
        elif dtype in (np.float32, np.float64):
            value = data[col].max()
            if np.isinf(value):
                raise ValueError(
                    f"Column {col} has a maximum value of infinity which is outside "
                    "the range supported by Stata."
                )
            if dtype == np.float32 and value > float32_max:
                data[col] = data[col].astype(np.float64)
            elif dtype == np.float64:
                if value > float64_max:
                    raise ValueError(
                        f"Column {col} has a maximum value ({value}) outside the range "
                        f"supported by Stata ({float64_max})"
                    )
        if is_nullable_int:
            if orig_missing.any():
                # Replace missing by Stata sentinel value
                sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name]
                data.loc[orig_missing, col] = sentinel
    if ws:
        warnings.warn(ws, PossiblePrecisionLoss)

    return data


class StataValueLabel:
    """
    Parse a categorical column and prepare formatted output

    Parameters
    ----------
    catarray : Series
        Categorical Series to encode
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
    """

    def __init__(self, catarray: Series, encoding: str = "latin-1"):

        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")
        self.labname = catarray.name
        self._encoding = encoding
        categories = catarray.cat.categories
        self.value_labels: list[tuple[int | float, str]] = list(
            zip(np.arange(len(categories)), categories)
        )
        self.value_labels.sort(key=lambda x: x[0])

        self._prepare_value_labels()

    def _prepare_value_labels(self):
        """Encode value labels."""

        self.text_len = 0
        self.txt: list[bytes] = []
        self.n = 0
        # Offsets (length of categories), converted to int32
        self.off = np.array([], dtype=np.int32)
        # Values, converted to int32
        self.val = np.array([], dtype=np.int32)
        self.len = 0

        # Compute lengths and setup lists of offsets and labels
        offsets: list[int] = []
        values: list[int | float] = []
        for vl in self.value_labels:
            category: str | bytes = vl[1]
            if not isinstance(category, str):
                category = str(category)
                warnings.warn(
                    value_label_mismatch_doc.format(self.labname),
                    ValueLabelTypeMismatch,
                )
            category = category.encode(self._encoding)
            offsets.append(self.text_len)
            self.text_len += len(category) + 1  # +1 for the padding
            values.append(vl[0])
            self.txt.append(category)
            self.n += 1

        if self.text_len > 32000:
            raise ValueError(
                "Stata value labels for a single variable must "
                "have a combined length less than 32,000 characters."
            )

        # Ensure int32
        self.off = np.array(offsets, dtype=np.int32)
        self.val = np.array(values, dtype=np.int32)

        # Total length
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
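        # i.e. n and textlen (one int32 each), n int32 offsets, n int32
        # values, and the null-terminated label text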

    def generate_value_label(self, byteorder: str) -> bytes:
        """
        Generate the binary representation of the value labels.

        Parameters
        ----------
        byteorder : str
            Byte order of the output

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """
        encoding = self._encoding
        bio = BytesIO()
        null_byte = b"\x00"

        # len
        bio.write(struct.pack(byteorder + "i", self.len))

        # labname
        labname = str(self.labname)[:32].encode(encoding)
        lab_len = 32 if encoding not in ("utf-8", "utf8") else 128
        labname = _pad_bytes(labname, lab_len + 1)
        bio.write(labname)

        # padding - 3 bytes
        for i in range(3):
            bio.write(struct.pack("c", null_byte))

        # value_label_table
        # n - int32
        bio.write(struct.pack(byteorder + "i", self.n))

        # textlen - int32
        bio.write(struct.pack(byteorder + "i", self.text_len))

        # off - int32 array (n elements)
        for offset in self.off:
            bio.write(struct.pack(byteorder + "i", offset))

        # val - int32 array (n elements)
        for value in self.val:
            bio.write(struct.pack(byteorder + "i", value))

        # txt - Text labels, null terminated
        for text in self.txt:
            bio.write(text + null_byte)

        return bio.getvalue()


class StataNonCatValueLabel(StataValueLabel):
    """
    Prepare formatted version of value labels

    Parameters
    ----------
    labname : str
        Value label name
    value_labels : dict
        Mapping of values to labels
    encoding : {"latin-1", "utf-8"}
        Encoding to use for value labels.
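
    Examples
    --------
    A hypothetical mapping, for illustration only:

    >>> svl = StataNonCatValueLabel("vl", {1: "low", 2: "high"})  # doctest: +SKIP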
    """

    def __init__(
        self,
        labname: str,
        value_labels: dict[float | int, str],
        encoding: Literal["latin-1", "utf-8"] = "latin-1",
    ):

        if encoding not in ("latin-1", "utf-8"):
            raise ValueError("Only latin-1 and utf-8 are supported.")

        self.labname = labname
        self._encoding = encoding
        self.value_labels: list[tuple[int | float, str]] = sorted(
            value_labels.items(), key=lambda x: x[0]
        )
        self._prepare_value_labels()


class StataMissingValue:
    """
    An observation's missing value.

    Parameters
    ----------
    value : {int, float}
        The Stata missing value code

    Notes
    -----
    More information: <https://www.stata.com/help.cgi?missing>

    Integer missing values map the codes '.', '.a', ..., '.z' to the ranges
    101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
    2147483647 (for int32). Missing values for floating point data types are
    more complex but the pattern is simple to discern from the following table.

    np.float32 missing values (float in Stata)
    0000007f    .
    0008007f    .a
    0010007f    .b
    ...
    00c0007f    .x
    00c8007f    .y
    00d0007f    .z

    np.float64 missing values (double in Stata)
    000000000000e07f    .
    000000000001e07f    .a
    000000000002e07f    .b
    ...
    000000000018e07f    .x
    000000000019e07f    .y
    00000000001ae07f    .z
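
    Examples
    --------
    For int8 data, 101 is the generic '.' and 102 is '.a':

    >>> StataMissingValue(101).string  # doctest: +SKIP
    '.'
    >>> StataMissingValue(102).string  # doctest: +SKIP
    '.a'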
    """

    # Construct a dictionary of missing values
    MISSING_VALUES: dict[float, str] = {}
    bases = (101, 32741, 2147483621)
    for b in bases:
        # Conversion to long to avoid hash issues on 32 bit platforms #8968
        MISSING_VALUES[b] = "."
        for i in range(1, 27):
            MISSING_VALUES[i + b] = "." + chr(96 + i)

    float32_base = b"\x00\x00\x00\x7f"
    increment = struct.unpack("<i", b"\x00\x08\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<f", float32_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment
        float32_base = struct.pack("<i", int_value)

    float64_base = b"\x00\x00\x00\x00\x00\x00\xe0\x7f"
    increment = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0]
    for i in range(27):
        key = struct.unpack("<d", float64_base)[0]
        MISSING_VALUES[key] = "."
        if i > 0:
            MISSING_VALUES[key] += chr(96 + i)
        int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment
        float64_base = struct.pack("q", int_value)

    BASE_MISSING_VALUES = {
        "int8": 101,
        "int16": 32741,
        "int32": 2147483621,
        "float32": struct.unpack("<f", float32_base)[0],
        "float64": struct.unpack("<d", float64_base)[0],
    }

    def __init__(self, value: int | float):
        self._value = value
        # Conversion to int to avoid hash issues on 32 bit platforms #8968
        value = int(value) if value < 2147483648 else float(value)
        self._str = self.MISSING_VALUES[value]

    @property
    def string(self) -> str:
        """
        The Stata representation of the missing value: '.', '.a'..'.z'

        Returns
        -------
        str
            The representation of the missing value.
        """
        return self._str

    @property
    def value(self) -> int | float:
        """
        The binary representation of the missing value.

        Returns
        -------
        {int, float}
            The binary representation of the missing value.
        """
        return self._value

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return f"{type(self)}({self})"

    def __eq__(self, other: Any) -> bool:
        return (
            isinstance(other, type(self))
            and self.string == other.string
            and self.value == other.value
        )

    @classmethod
    def get_base_missing_value(cls, dtype: np.dtype) -> int | float:
        if dtype.type is np.int8:
            value = cls.BASE_MISSING_VALUES["int8"]
        elif dtype.type is np.int16:
            value = cls.BASE_MISSING_VALUES["int16"]
        elif dtype.type is np.int32:
            value = cls.BASE_MISSING_VALUES["int32"]
        elif dtype.type is np.float32:
            value = cls.BASE_MISSING_VALUES["float32"]
        elif dtype.type is np.float64:
            value = cls.BASE_MISSING_VALUES["float64"]
        else:
            raise ValueError("Unsupported dtype")
        return value


class StataParser:
    def __init__(self):

        # type          code.
        # --------------------
        # str1        1 = 0x01
        # str2        2 = 0x02
        # ...
        # str244    244 = 0xf4
        # byte      251 = 0xfb  (sic)
        # int       252 = 0xfc
        # long      253 = 0xfd
        # float     254 = 0xfe
        # double    255 = 0xff
        # --------------------
        # NOTE: the byte type seems to be reserved for categorical variables
        # with a label, but the underlying variable is -127 to 100
        # we're going to drop the label and cast to int
        self.DTYPE_MAP = dict(
            list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
            + [
                (251, np.dtype(np.int8)),
                (252, np.dtype(np.int16)),
                (253, np.dtype(np.int32)),
                (254, np.dtype(np.float32)),
                (255, np.dtype(np.float64)),
            ]
        )
        self.DTYPE_MAP_XML = {
            32768: np.dtype(np.uint8),  # Keys to GSO
            65526: np.dtype(np.float64),
            65527: np.dtype(np.float32),
            65528: np.dtype(np.int32),
            65529: np.dtype(np.int16),
            65530: np.dtype(np.int8),
        }
        # error: Argument 1 to "list" has incompatible type "str";
        # expected "Iterable[int]"  [arg-type]
        self.TYPE_MAP = list(range(251)) + list("bhlfd")  # type: ignore[arg-type]
        self.TYPE_MAP_XML = {
            # Not really a Q, unclear how to handle byteswap
            32768: "Q",
            65526: "d",
            65527: "f",
            65528: "l",
            65529: "h",
            65530: "b",
        }
        # NOTE: technically, some of these are wrong. there are more numbers
        # that can be represented. it's the 27 ABOVE and BELOW the max listed
        # numeric data type in [U] 12.2.2 of the 11.2 manual
        float32_min = b"\xff\xff\xff\xfe"
        float32_max = b"\xff\xff\xff\x7e"
        float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff"
        float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f"
        self.VALID_RANGE = {
            "b": (-127, 100),
            "h": (-32767, 32740),
            "l": (-2147483647, 2147483620),
            "f": (
                np.float32(struct.unpack("<f", float32_min)[0]),
                np.float32(struct.unpack("<f", float32_max)[0]),
            ),
            "d": (
                np.float64(struct.unpack("<d", float64_min)[0]),
                np.float64(struct.unpack("<d", float64_max)[0]),
            ),
        }

        self.OLD_TYPE_MAPPING = {
            98: 251,  # byte
            105: 252,  # int
            108: 253,  # long
            102: 254,  # float
            100: 255,  # double
        }

        # These missing values are the generic '.' in Stata, and are used
        # to replace nans
        self.MISSING_VALUES = {
            "b": 101,
            "h": 32741,
            "l": 2147483621,
            "f": np.float32(struct.unpack("<f", b"\x00\x00\x00\x7f")[0]),
            "d": np.float64(
                struct.unpack("<d", b"\x00\x00\x00\x00\x00\x00\xe0\x7f")[0]
            ),
        }
        self.NUMPY_TYPE_MAP = {
            "b": "i1",
            "h": "i2",
            "l": "i4",
            "f": "f4",
            "d": "f8",
            "Q": "u8",
        }

        # Reserved words cannot be used as variable names
        self.RESERVED_WORDS = (
            "aggregate",
            "array",
            "boolean",
            "break",
            "byte",
            "case",
            "catch",
            "class",
            "colvector",
            "complex",
            "const",
            "continue",
            "default",
            "delegate",
            "delete",
            "do",
            "double",
            "else",
            "eltypedef",
            "end",
            "enum",
            "explicit",
            "export",
            "external",
            "float",
            "for",
            "friend",
            "function",
            "global",
            "goto",
            "if",
            "inline",
            "int",
            "local",
            "long",
            "NULL",
            "pragma",
            "protected",
            "quad",
            "rowvector",
            "short",
            "typedef",
            "typename",
            "virtual",
            "_all",
            "_N",
            "_skip",
            "_b",
            "_pi",
            "str#",
            "in",
            "_pred",
            "strL",
            "_coef",
            "_rc",
            "using",
            "_cons",
            "_se",
            "with",
            "_n",
        )


class StataReader(StataParser, abc.Iterator):
    __doc__ = _stata_reader_doc

    def __init__(
        self,
        path_or_buf: FilePath | ReadBuffer[bytes],
        convert_dates: bool = True,
        convert_categoricals: bool = True,
        index_col: str | None = None,
        convert_missing: bool = False,
        preserve_dtypes: bool = True,
        columns: Sequence[str] | None = None,
        order_categoricals: bool = True,
        chunksize: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
    ):
        super().__init__()
        self.col_sizes: list[int] = []

        # Arguments to the reader (can be temporarily overridden in
        # calls to read).
        self._convert_dates = convert_dates
        self._convert_categoricals = convert_categoricals
        self._index_col = index_col
        self._convert_missing = convert_missing
        self._preserve_dtypes = preserve_dtypes
        self._columns = columns
        self._order_categoricals = order_categoricals
        self._encoding = ""
        self._chunksize = chunksize
        self._using_iterator = False
        if self._chunksize is None:
            self._chunksize = 1
        elif not isinstance(chunksize, int) or chunksize <= 0:
            raise ValueError("chunksize must be a positive integer when set.")

        # State variables for the file
        self._has_string_data = False
        self._missing_values = False
        self._can_read_value_labels = False
        self._column_selector_set = False
        self._value_labels_read = False
        self._data_read = False
        self._dtype: np.dtype | None = None
        self._lines_read = 0

        self._native_byteorder = _set_endianness(sys.byteorder)
        with get_handle(
            path_or_buf,
            "rb",
            storage_options=storage_options,
            is_text=False,
            compression=compression,
        ) as handles:
            # Copy to BytesIO, and ensure no encoding
            self.path_or_buf = BytesIO(handles.handle.read())

        self._read_header()
        self._setup_dtype()

    def __enter__(self) -> StataReader:
        """enter context manager"""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """exit context manager"""
        self.close()

    def close(self) -> None:
        """close the handle if it's open"""
        self.path_or_buf.close()

    def _set_encoding(self) -> None:
        """
        Set string encoding which depends on file version
        """
        if self.format_version < 118:
            self._encoding = "latin-1"
        else:
            self._encoding = "utf-8"

    def _read_header(self) -> None:
        first_char = self.path_or_buf.read(1)
        if struct.unpack("c", first_char)[0] == b"<":
            self._read_new_header()
        else:
            self._read_old_header(first_char)

        self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0

        # calculate size of a data record
        self.col_sizes = [self._calcsize(typ) for typ in self.typlist]

    def _read_new_header(self) -> None:
        # The first part of the header is common to 117 - 119.
        self.path_or_buf.read(27)  # stata_dta><header><release>
        self.format_version = int(self.path_or_buf.read(3))
        if self.format_version not in [117, 118, 119]:
            raise ValueError(_version_error.format(version=self.format_version))
        self._set_encoding()
        self.path_or_buf.read(21)  # </release><byteorder>
        self.byteorder = self.path_or_buf.read(3) == b"MSF" and ">" or "<"
        self.path_or_buf.read(15)  # </byteorder><K>
        nvar_type = "H" if self.format_version <= 118 else "I"
        nvar_size = 2 if self.format_version <= 118 else 4
        self.nvar = struct.unpack(
            self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
        )[0]
        self.path_or_buf.read(7)  # </K><N>

        self.nobs = self._get_nobs()
        self.path_or_buf.read(11)  # </N><label>
        self._data_label = self._get_data_label()
        self.path_or_buf.read(19)  # </label><timestamp>
        self.time_stamp = self._get_time_stamp()
        self.path_or_buf.read(26)  # </timestamp></header><map>
        self.path_or_buf.read(8)  # 0x0000000000000000
        self.path_or_buf.read(8)  # position of <map>
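
        # The +N corrections below skip each section's opening tag, e.g.
        # len(b"<variable_types>") == 16 and len(b"<varnames>") == 10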
        self._seek_vartypes = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
        )
        self._seek_varnames = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
        )
        self._seek_sortlist = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
        )
        self._seek_formats = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
        )
        self._seek_value_label_names = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
        )

        # Requires version-specific treatment
        self._seek_variable_labels = self._get_seek_variable_labels()

        self.path_or_buf.read(8)  # <characteristics>
        self.data_location = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
        )
        self.seek_strls = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
        )
        self.seek_value_labels = (
            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
        )

        self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)

        self.path_or_buf.seek(self._seek_varnames)
        self.varlist = self._get_varlist()

        self.path_or_buf.seek(self._seek_sortlist)
        self.srtlist = struct.unpack(
            self.byteorder + ("h" * (self.nvar + 1)),
            self.path_or_buf.read(2 * (self.nvar + 1)),
        )[:-1]

        self.path_or_buf.seek(self._seek_formats)
        self.fmtlist = self._get_fmtlist()

        self.path_or_buf.seek(self._seek_value_label_names)
        self.lbllist = self._get_lbllist()

        self.path_or_buf.seek(self._seek_variable_labels)
        self._variable_labels = self._get_variable_labels()

    # Get data type information, works for versions 117-119.
    def _get_dtypes(
        self, seek_vartypes: int
    ) -> tuple[list[int | str], list[str | np.dtype]]:

        self.path_or_buf.seek(seek_vartypes)
        raw_typlist = [
            struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
            for _ in range(self.nvar)
        ]

        def f(typ: int) -> int | str:
            if typ <= 2045:
                return typ
            try:
                return self.TYPE_MAP_XML[typ]
            except KeyError as err:
                raise ValueError(f"cannot convert stata types [{typ}]") from err

        typlist = [f(x) for x in raw_typlist]

        def g(typ: int) -> str | np.dtype:
            if typ <= 2045:
                return str(typ)
            try:
                # error: Incompatible return value type (got "Type[number]", expected
                # "Union[str, dtype]")
                return self.DTYPE_MAP_XML[typ]  # type: ignore[return-value]
            except KeyError as err:
                raise ValueError(f"cannot convert stata dtype [{typ}]") from err

        dtyplist = [g(x) for x in raw_typlist]

        return typlist, dtyplist

    def _get_varlist(self) -> list[str]:
        # 33 in order formats, 129 in formats 118 and 119
        b = 33 if self.format_version < 118 else 129
        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]

    # Returns the format list
    def _get_fmtlist(self) -> list[str]:
        if self.format_version >= 118:
            b = 57
        elif self.format_version > 113:
            b = 49
        elif self.format_version > 104:
            b = 12
        else:
            b = 7

        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]

    # Returns the label list
    def _get_lbllist(self) -> list[str]:
        if self.format_version >= 118:
            b = 129
        elif self.format_version > 108:
            b = 33
        else:
            b = 9
        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]

    def _get_variable_labels(self) -> list[str]:
        if self.format_version >= 118:
            vlblist = [
                self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar)
            ]
        elif self.format_version > 105:
            vlblist = [
                self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar)
            ]
        else:
            vlblist = [
                self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar)
            ]
        return vlblist

    def _get_nobs(self) -> int:
        if self.format_version >= 118:
            return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
        else:
            return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]

    def _get_data_label(self) -> str:
        if self.format_version >= 118:
            strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
            return self._decode(self.path_or_buf.read(strlen))
        elif self.format_version == 117:
            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
            return self._decode(self.path_or_buf.read(strlen))
        elif self.format_version > 105:
            return self._decode(self.path_or_buf.read(81))
        else:
            return self._decode(self.path_or_buf.read(32))

    def _get_time_stamp(self) -> str:
        if self.format_version >= 118:
            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
            return self.path_or_buf.read(strlen).decode("utf-8")
        elif self.format_version == 117:
            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
            return self._decode(self.path_or_buf.read(strlen))
        elif self.format_version > 104:
            return self._decode(self.path_or_buf.read(18))
        else:
            raise ValueError()

    def _get_seek_variable_labels(self) -> int:
        if self.format_version == 117:
            self.path_or_buf.read(8)  # <variable_labels>, throw away
            # Stata 117 data files do not follow the described format. This is
            # a work around that uses the previous label, 33 bytes for each
            # variable, 20 for the closing tag and 17 for the opening tag
            return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
        elif self.format_version >= 118:
            return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
        else:
            raise ValueError()

    def _read_old_header(self, first_char: bytes) -> None:
        self.format_version = struct.unpack("b", first_char)[0]
        if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
            raise ValueError(_version_error.format(version=self.format_version))
        self._set_encoding()
        self.byteorder = (
            struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" or "<"
        )
        self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
        self.path_or_buf.read(1)  # unused

        self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
        self.nobs = self._get_nobs()

        self._data_label = self._get_data_label()

        self.time_stamp = self._get_time_stamp()

        # descriptors
        if self.format_version > 108:
            typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
        else:
            buf = self.path_or_buf.read(self.nvar)
            typlistb = np.frombuffer(buf, dtype=np.uint8)
            typlist = []
            for tp in typlistb:
                if tp in self.OLD_TYPE_MAPPING:
                    typlist.append(self.OLD_TYPE_MAPPING[tp])
                else:
                    typlist.append(tp - 127)  # bytes

        try:
            self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_types = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
        try:
            self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
        except ValueError as err:
            invalid_dtypes = ",".join([str(x) for x in typlist])
            raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err

        if self.format_version > 108:
            self.varlist = [
                self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar)
            ]
        else:
            self.varlist = [
                self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
            ]
        self.srtlist = struct.unpack(
            self.byteorder + ("h" * (self.nvar + 1)),
            self.path_or_buf.read(2 * (self.nvar + 1)),
        )[:-1]

        self.fmtlist = self._get_fmtlist()

        self.lbllist = self._get_lbllist()

        self._variable_labels = self._get_variable_labels()

        # ignore expansion fields (Format 105 and later)
        # When reading, read five bytes; the last four bytes now tell you
        # the size of the next read, which you discard. You then continue
        # like this until you read 5 bytes of zeros.

        if self.format_version > 104:
            while True:
                data_type = struct.unpack(
                    self.byteorder + "b", self.path_or_buf.read(1)
                )[0]
                if self.format_version > 108:
                    data_len = struct.unpack(
                        self.byteorder + "i", self.path_or_buf.read(4)
                    )[0]
                else:
                    data_len = struct.unpack(
                        self.byteorder + "h", self.path_or_buf.read(2)
                    )[0]
                if data_type == 0:
                    break
                self.path_or_buf.read(data_len)

        # necessary data to continue parsing
        self.data_location = self.path_or_buf.tell()

    def _setup_dtype(self) -> np.dtype:
        """Map between numpy and Stata dtypes"""
        if self._dtype is not None:
            return self._dtype

        dtypes = []  # Convert struct data types to numpy data type
        for i, typ in enumerate(self.typlist):
            if typ in self.NUMPY_TYPE_MAP:
                typ = cast(str, typ)  # only strs in NUMPY_TYPE_MAP
                dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ]))
            else:
                dtypes.append(("s" + str(i), "S" + str(typ)))
        self._dtype = np.dtype(dtypes)

        return self._dtype

    def _calcsize(self, fmt: int | str) -> int:
        if isinstance(fmt, int):
            return fmt
        return struct.calcsize(self.byteorder + fmt)

    def _decode(self, s: bytes) -> str:
        # have bytes not strings, so must decode
        s = s.partition(b"\0")[0]
        try:
            return s.decode(self._encoding)
        except UnicodeDecodeError:
            # GH 25960, fallback to handle incorrect format produced when 117
            # files are converted to 118 files in Stata
            encoding = self._encoding
            msg = f"""
One or more strings in the dta file could not be decoded using {encoding}, and
so the fallback encoding of latin-1 is being used. This can happen when a file
has been incorrectly encoded by Stata or some other software. You should verify
the string values returned are correct."""
            warnings.warn(msg, UnicodeWarning)
            return s.decode("latin-1")

    def _read_value_labels(self) -> None:
        if self._value_labels_read:
            # Don't read twice
            return
        if self.format_version <= 108:
            # Value labels are not supported in version 108 and earlier.
            self._value_labels_read = True
            self.value_label_dict: dict[str, dict[float | int, str]] = {}
            return

        if self.format_version >= 117:
            self.path_or_buf.seek(self.seek_value_labels)
        else:
            assert self._dtype is not None
            offset = self.nobs * self._dtype.itemsize
            self.path_or_buf.seek(self.data_location + offset)

        self._value_labels_read = True
        self.value_label_dict = {}

        while True:
            if self.format_version >= 117:
                if self.path_or_buf.read(5) == b"</val":  # <lbl>
                    break  # end of value label table

            slength = self.path_or_buf.read(4)
            if not slength:
                break  # end of value label table (format < 117)
            if self.format_version <= 117:
                labname = self._decode(self.path_or_buf.read(33))
            else:
                labname = self._decode(self.path_or_buf.read(129))
            self.path_or_buf.read(3)  # padding

            n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            off = np.frombuffer(
                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
            )
            val = np.frombuffer(
                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
            )
            ii = np.argsort(off)
            off = off[ii]
            val = val[ii]
            txt = self.path_or_buf.read(txtlen)
            self.value_label_dict[labname] = {}
            for i in range(n):
                end = off[i + 1] if i < n - 1 else txtlen
                self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end])
            if self.format_version >= 117:
                self.path_or_buf.read(6)  # </lbl>
        self._value_labels_read = True

    def _read_strls(self) -> None:
        self.path_or_buf.seek(self.seek_strls)
        # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
        self.GSO = {"0": ""}
        while True:
            if self.path_or_buf.read(3) != b"GSO":
                break

            if self.format_version == 117:
                v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
            else:
                buf = self.path_or_buf.read(12)
                # Only tested on little endian file on little endian machine.
                v_size = 2 if self.format_version == 118 else 3
                if self.byteorder == "<":
                    buf = buf[0:v_size] + buf[4 : (12 - v_size)]
                else:
                    # This path may not be correct, impossible to test
                    buf = buf[0:v_size] + buf[(4 + v_size) :]
                v_o = struct.unpack("Q", buf)[0]
            typ = struct.unpack("B", self.path_or_buf.read(1))[0]
            length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
            va = self.path_or_buf.read(length)
            if typ == 130:
                decoded_va = va[0:-1].decode(self._encoding)
            else:
                # Stata says typ 129 can be binary, so use str
                decoded_va = str(va)
            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
            self.GSO[str(v_o)] = decoded_va
def __next__(self) -> DataFrame:
1603
self._using_iterator = True
1604
return self.read(nrows=self._chunksize)
1605
1606
def get_chunk(self, size: int | None = None) -> DataFrame:
1607
"""
1608
Reads lines from Stata file and returns as dataframe
1609
1610
Parameters
1611
----------
1612
size : int, defaults to None
1613
Number of lines to read. If None, reads whole file.
1614
1615
Returns
1616
-------
1617
DataFrame
1618
"""
1619
if size is None:
1620
size = self._chunksize
1621
return self.read(nrows=size)
1622
1623
@Appender(_read_method_doc)
1624
def read(
1625
self,
1626
nrows: int | None = None,
1627
convert_dates: bool | None = None,
1628
convert_categoricals: bool | None = None,
1629
index_col: str | None = None,
1630
convert_missing: bool | None = None,
1631
preserve_dtypes: bool | None = None,
1632
columns: Sequence[str] | None = None,
1633
order_categoricals: bool | None = None,
1634
) -> DataFrame:
1635
# Handle empty file or chunk. If reading incrementally raise
1636
# StopIteration. If reading the whole thing return an empty
1637
# data frame.
1638
if (self.nobs == 0) and (nrows is None):
1639
self._can_read_value_labels = True
1640
self._data_read = True
1641
self.close()
1642
return DataFrame(columns=self.varlist)
1643
1644
# Handle options
1645
if convert_dates is None:
1646
convert_dates = self._convert_dates
1647
if convert_categoricals is None:
1648
convert_categoricals = self._convert_categoricals
1649
if convert_missing is None:
1650
convert_missing = self._convert_missing
1651
if preserve_dtypes is None:
1652
preserve_dtypes = self._preserve_dtypes
1653
if columns is None:
1654
columns = self._columns
1655
if order_categoricals is None:
1656
        order_categoricals = self._order_categoricals
        if index_col is None:
            index_col = self._index_col
        if nrows is None:
            nrows = self.nobs

        if (self.format_version >= 117) and (not self._value_labels_read):
            self._can_read_value_labels = True
            self._read_strls()

        # Read data
        assert self._dtype is not None
        dtype = self._dtype
        max_read_len = (self.nobs - self._lines_read) * dtype.itemsize
        read_len = nrows * dtype.itemsize
        read_len = min(read_len, max_read_len)
        if read_len <= 0:
            # Iterator has finished, should never be here unless
            # we are reading the file incrementally
            if convert_categoricals:
                self._read_value_labels()
            self.close()
            raise StopIteration
        offset = self._lines_read * dtype.itemsize
        self.path_or_buf.seek(self.data_location + offset)
        read_lines = min(nrows, self.nobs - self._lines_read)
        raw_data = np.frombuffer(
            self.path_or_buf.read(read_len), dtype=dtype, count=read_lines
        )

        self._lines_read += read_lines
        if self._lines_read == self.nobs:
            self._can_read_value_labels = True
            self._data_read = True
        # if necessary, swap the byte order to native here
        if self.byteorder != self._native_byteorder:
            raw_data = raw_data.byteswap().newbyteorder()

        if convert_categoricals:
            self._read_value_labels()

        if len(raw_data) == 0:
            data = DataFrame(columns=self.varlist)
        else:
            data = DataFrame.from_records(raw_data)
            data.columns = Index(self.varlist)

        # If index is not specified, use actual row number rather than
        # restarting at 0 for each chunk.
        if index_col is None:
            rng = np.arange(self._lines_read - read_lines, self._lines_read)
            data.index = Index(rng)  # set attr instead of set_index to avoid copy

        if columns is not None:
            try:
                data = self._do_select_columns(data, columns)
            except ValueError:
                self.close()
                raise

        # Decode strings
        for col, typ in zip(data, self.typlist):
            if type(typ) is int:
                data[col] = data[col].apply(self._decode, convert_dtype=True)

        data = self._insert_strls(data)

        cols_ = np.where([dtyp is not None for dtyp in self.dtyplist])[0]
        # Convert columns (if needed) to match input type
        ix = data.index
        requires_type_conversion = False
        data_formatted = []
        for i in cols_:
            if self.dtyplist[i] is not None:
                col = data.columns[i]
                dtype = data[col].dtype
                if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
                    requires_type_conversion = True
                    data_formatted.append(
                        (col, Series(data[col], ix, self.dtyplist[i]))
                    )
                else:
                    data_formatted.append((col, data[col]))
        if requires_type_conversion:
            data = DataFrame.from_dict(dict(data_formatted))
        del data_formatted

        data = self._do_convert_missing(data, convert_missing)

        if convert_dates:

            def any_startswith(x: str) -> bool:
                return any(x.startswith(fmt) for fmt in _date_formats)

            cols = np.where([any_startswith(x) for x in self.fmtlist])[0]
            for i in cols:
                col = data.columns[i]
                try:
                    data[col] = _stata_elapsed_date_to_datetime_vec(
                        data[col], self.fmtlist[i]
                    )
                except ValueError:
                    self.close()
                    raise

        if convert_categoricals and self.format_version > 108:
            data = self._do_convert_categoricals(
                data, self.value_label_dict, self.lbllist, order_categoricals
            )

        if not preserve_dtypes:
            retyped_data = []
            convert = False
            for col in data:
                dtype = data[col].dtype
                if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
                    dtype = np.dtype(np.float64)
                    convert = True
                elif dtype in (
                    np.dtype(np.int8),
                    np.dtype(np.int16),
                    np.dtype(np.int32),
                ):
                    dtype = np.dtype(np.int64)
                    convert = True
                retyped_data.append((col, data[col].astype(dtype)))
            if convert:
                data = DataFrame.from_dict(dict(retyped_data))

        if index_col is not None:
            data = data.set_index(data.pop(index_col))

        return data

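    # Editor's illustrative sketch (not part of pandas): the incremental
    # branch above is what powers chunked reading. Assuming a hypothetical
    # local file "example.dta", the same path is exercised through the
    # public API:
    #
    #   import pandas as pd
    #   with pd.read_stata("example.dta", chunksize=1000) as reader:
    #       for chunk in reader:
    #           print(chunk.shape)  # chunks keep the global row numbering
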
    def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
        # Check for missing values, and replace if found
        replacements = {}
        for i, colname in enumerate(data):
            fmt = self.typlist[i]
            if fmt not in self.VALID_RANGE:
                continue

            fmt = cast(str, fmt)  # only strs in VALID_RANGE
            nmin, nmax = self.VALID_RANGE[fmt]
            series = data[colname]

            # appreciably faster to do this with ndarray instead of Series
            svals = series._values
            missing = (svals < nmin) | (svals > nmax)

            if not missing.any():
                continue

            if convert_missing:  # Replacement follows Stata notation
                missing_loc = np.nonzero(np.asarray(missing))[0]
                umissing, umissing_loc = np.unique(series[missing], return_inverse=True)
                replacement = Series(series, dtype=object)
                for j, um in enumerate(umissing):
                    missing_value = StataMissingValue(um)

                    loc = missing_loc[umissing_loc == j]
                    replacement.iloc[loc] = missing_value
            else:  # All replacements are identical
                dtype = series.dtype
                if dtype not in (np.float32, np.float64):
                    dtype = np.float64
                replacement = Series(series, dtype=dtype)
                if not replacement._values.flags["WRITEABLE"]:
                    # only relevant for ArrayManager; construction
                    # path for BlockManager ensures writeability
                    replacement = replacement.copy()
                # Note: operating on ._values is much faster than directly
                # TODO: can we fix that?
                replacement._values[missing] = np.nan
            replacements[colname] = replacement

        if replacements:
            for col in replacements:
                data[col] = replacements[col]
        return data

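    # Editor's note with an illustrative sketch (not part of pandas): with
    # convert_missing=False (the default) out-of-range values become NaN;
    # with convert_missing=True the column is upcast to object and each
    # missing entry becomes a StataMissingValue, preserving Stata's ".a"-".z"
    # distinctions. Assuming a hypothetical "missing.dta":
    #
    #   import pandas as pd
    #   df = pd.read_stata("missing.dta", convert_missing=True)
    #   print(df.dtypes)  # columns with missing values are object dtype
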
    def _insert_strls(self, data: DataFrame) -> DataFrame:
        if not hasattr(self, "GSO") or len(self.GSO) == 0:
            return data
        for i, typ in enumerate(self.typlist):
            if typ != "Q":
                continue
            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
            data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]]
        return data

    def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame:

        if not self._column_selector_set:
            column_set = set(columns)
            if len(column_set) != len(columns):
                raise ValueError("columns contains duplicate entries")
            unmatched = column_set.difference(data.columns)
            if unmatched:
                joined = ", ".join(list(unmatched))
                raise ValueError(
                    "The following columns were not "
                    f"found in the Stata data set: {joined}"
                )
            # Copy information for retained columns for later processing
            dtyplist = []
            typlist = []
            fmtlist = []
            lbllist = []
            for col in columns:
                i = data.columns.get_loc(col)
                dtyplist.append(self.dtyplist[i])
                typlist.append(self.typlist[i])
                fmtlist.append(self.fmtlist[i])
                lbllist.append(self.lbllist[i])

            self.dtyplist = dtyplist
            self.typlist = typlist
            self.fmtlist = fmtlist
            self.lbllist = lbllist
            self._column_selector_set = True

        return data[columns]

    def _do_convert_categoricals(
        self,
        data: DataFrame,
        value_label_dict: dict[str, dict[float | int, str]],
        lbllist: Sequence[str],
        order_categoricals: bool,
    ) -> DataFrame:
        """
        Converts categorical columns to Categorical type.
        """
        value_labels = list(value_label_dict.keys())
        cat_converted_data = []
        for col, label in zip(data, lbllist):
            if label in value_labels:
                # Explicit call with ordered=True
                vl = value_label_dict[label]
                keys = np.array(list(vl.keys()))
                column = data[col]
                key_matches = column.isin(keys)
                if self._using_iterator and key_matches.all():
                    initial_categories: np.ndarray | None = keys
                    # If all categories are in the keys and we are iterating,
                    # use the same keys for all chunks. If some are missing
                    # value labels, then we will fall back to the categories
                    # varying across chunks.
                else:
                    if self._using_iterator:
                        # warn if using an iterator
                        warnings.warn(
                            categorical_conversion_warning, CategoricalConversionWarning
                        )
                    initial_categories = None
                cat_data = Categorical(
                    column, categories=initial_categories, ordered=order_categoricals
                )
                if initial_categories is None:
                    # If None here, then we need to match the cats in the Categorical
                    categories = []
                    for category in cat_data.categories:
                        if category in vl:
                            categories.append(vl[category])
                        else:
                            categories.append(category)
                else:
                    # If all cats are matched, we can use the values
                    categories = list(vl.values())
                try:
                    # Try to catch duplicate categories
                    cat_data.categories = categories
                except ValueError as err:
                    vc = Series(categories).value_counts()
                    repeated_cats = list(vc.index[vc > 1])
                    repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
                    # GH 25772
                    msg = f"""
Value labels for column {col} are not unique. These cannot be converted to
pandas categoricals.

Either read the file with `convert_categoricals` set to False or use the
low level interface in `StataReader` to separately read the values and the
value_labels.

The repeated labels are:
{repeats}
"""
                    raise ValueError(msg) from err
                # TODO: is the next line needed above in the data(...) method?
                cat_series = Series(cat_data, index=data.index)
                cat_converted_data.append((col, cat_series))
            else:
                cat_converted_data.append((col, data[col]))
        data = DataFrame(dict(cat_converted_data), copy=False)
        return data

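    # Editor's illustrative sketch (not part of pandas): value-labelled
    # columns come back as pandas Categorical when convert_categoricals=True.
    # Assuming a hypothetical "labelled.dta" with a labelled column "sex":
    #
    #   import pandas as pd
    #   df = pd.read_stata("labelled.dta", order_categoricals=True)
    #   print(df["sex"].cat.categories)  # the Stata value labels, in order
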
    @property
    def data_label(self) -> str:
        """
        Return the data label of the Stata file.
        """
        return self._data_label

    def variable_labels(self) -> dict[str, str]:
        """
        Return a dict associating each variable name with its corresponding
        label.

        Returns
        -------
        dict
        """
        return dict(zip(self.varlist, self._variable_labels))

    def value_labels(self) -> dict[str, dict[float | int, str]]:
        """
        Return a dict associating each variable name with a dict that maps
        each value to its corresponding label.

        Returns
        -------
        dict
        """
        if not self._value_labels_read:
            self._read_value_labels()

        return self.value_label_dict


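# Editor's illustrative sketch (not part of pandas): the reader exposes file
# metadata without converting the data. Assuming a hypothetical "example.dta":
#
#   from pandas.io.stata import StataReader
#   with StataReader("example.dta") as rdr:
#       print(rdr.data_label)         # dataset label
#       print(rdr.variable_labels())  # {variable name: label}
#       print(rdr.value_labels())     # {label name: {value: label}}
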
@Appender(_read_stata_doc)
def read_stata(
    filepath_or_buffer: FilePath | ReadBuffer[bytes],
    convert_dates: bool = True,
    convert_categoricals: bool = True,
    index_col: str | None = None,
    convert_missing: bool = False,
    preserve_dtypes: bool = True,
    columns: Sequence[str] | None = None,
    order_categoricals: bool = True,
    chunksize: int | None = None,
    iterator: bool = False,
    compression: CompressionOptions = "infer",
    storage_options: StorageOptions = None,
) -> DataFrame | StataReader:

    reader = StataReader(
        filepath_or_buffer,
        convert_dates=convert_dates,
        convert_categoricals=convert_categoricals,
        index_col=index_col,
        convert_missing=convert_missing,
        preserve_dtypes=preserve_dtypes,
        columns=columns,
        order_categoricals=order_categoricals,
        chunksize=chunksize,
        storage_options=storage_options,
        compression=compression,
    )

    if iterator or chunksize:
        return reader

    with reader:
        return reader.read()


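# Editor's illustrative sketch (not part of pandas): chunked reading through
# read_stata; the path and chunk size are hypothetical.
def _example_read_in_chunks(path: str = "big_file.dta") -> int:  # pragma: no cover
    import pandas as pd

    total = 0
    # chunksize (or iterator=True) makes read_stata return a StataReader
    with pd.read_stata(path, chunksize=10_000) as reader:
        for chunk in reader:
            total += len(chunk)
    return total

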
def _set_endianness(endianness: str) -> str:
    if endianness.lower() in ["<", "little"]:
        return "<"
    elif endianness.lower() in [">", "big"]:
        return ">"
    else:  # pragma : no cover
        raise ValueError(f"Endianness {endianness} not understood")


def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
    """
    Take a char string and pad it with null bytes until it is `length`
    characters long.
    """
    if isinstance(name, bytes):
        return name + b"\x00" * (length - len(name))
    return name + "\x00" * (length - len(name))


def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
    """
    Convert from one of the stata date formats to a type in TYPE_MAP.
    """
    if fmt in [
        "tc",
        "%tc",
        "td",
        "%td",
        "tw",
        "%tw",
        "tm",
        "%tm",
        "tq",
        "%tq",
        "th",
        "%th",
        "ty",
        "%ty",
    ]:
        return np.dtype(np.float64)  # Stata expects doubles for SIFs
    else:
        raise NotImplementedError(f"Format {fmt} not implemented")


def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict:
    new_dict = {}
    for key in convert_dates:
        if not convert_dates[key].startswith("%"):  # make sure proper fmts
            convert_dates[key] = "%" + convert_dates[key]
        if key in varlist:
            new_dict.update({varlist.index(key): convert_dates[key]})
        else:
            if not isinstance(key, int):
                raise ValueError("convert_dates key must be a column or an integer")
            new_dict.update({key: convert_dates[key]})
    return new_dict


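# Editor's note (not part of pandas): _maybe_convert_to_int_keys normalizes a
# user-supplied convert_dates mapping to positional keys and %-prefixed
# formats. For example, with varlist ["a", "date"]:
#
#   _maybe_convert_to_int_keys({"date": "td"}, ["a", "date"])  # -> {1: "%td"}

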
def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int:
    """
    Convert dtype types to stata types. Returns the byte of the given ordinal.
    See TYPE_MAP and comments for an explanation. This is also explained in
    the dta spec.
    1 - 244 are strings of this length
                Pandas    Stata
    251 - for int8        byte
    252 - for int16       int
    253 - for int32       long
    254 - for float32     float
    255 - for double      double

    If there are dates to convert, then dtype will already have the correct
    type inserted.
    """
    # TODO: expand to handle datetime to integer conversion
    if dtype.type is np.object_:  # try to coerce it to the biggest string
        # not memory efficient, what else could we
        # do?
        itemsize = max_len_string_array(ensure_object(column._values))
        return max(itemsize, 1)
    elif dtype.type is np.float64:
        return 255
    elif dtype.type is np.float32:
        return 254
    elif dtype.type is np.int32:
        return 253
    elif dtype.type is np.int16:
        return 252
    elif dtype.type is np.int8:
        return 251
    else:  # pragma : no cover
        raise NotImplementedError(f"Data type {dtype} not supported.")


def _dtype_to_default_stata_fmt(
    dtype, column: Series, dta_version: int = 114, force_strl: bool = False
) -> str:
    """
    Map numpy dtype to stata's default format for this type. Not terribly
    important since users can change this in Stata. Semantics are

    object  -> "%DDs" where DD is the length of the string. If not a string,
                raise ValueError
    float64 -> "%10.0g"
    float32 -> "%9.0g"
    int64   -> "%9.0g"
    int32   -> "%12.0g"
    int16   -> "%8.0g"
    int8    -> "%8.0g"
    strl    -> "%9s"
    """
    # TODO: Refactor to combine type with format
    # TODO: expand this to handle a default datetime format?
    if dta_version < 117:
        max_str_len = 244
    else:
        max_str_len = 2045
        if force_strl:
            return "%9s"
    if dtype.type is np.object_:
        itemsize = max_len_string_array(ensure_object(column._values))
        if itemsize > max_str_len:
            if dta_version >= 117:
                return "%9s"
            else:
                raise ValueError(excessive_string_length_error.format(column.name))
        return "%" + str(max(itemsize, 1)) + "s"
    elif dtype == np.float64:
        return "%10.0g"
    elif dtype == np.float32:
        return "%9.0g"
    elif dtype == np.int32:
        return "%12.0g"
    elif dtype == np.int8 or dtype == np.int16:
        return "%8.0g"
    else:  # pragma : no cover
        raise NotImplementedError(f"Data type {dtype} not supported.")


@doc(
    storage_options=_shared_docs["storage_options"],
    compression_options=_shared_docs["compression_options"] % "fname",
)
class StataWriter(StataParser):
    """
    A class for writing Stata binary dta files

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() function. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information
    write_index : bool
        Write the index to Stata dataset.
    byteorder : str
        Can be ">", "<", "little", or "big". default is `sys.byteorder`
    time_stamp : datetime
        A datetime to use as file creation date. Default is the current time
    data_label : str
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    {storage_options}

        .. versionadded:: 1.2.0

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    writer : StataWriter instance
        The StataWriter instance has a write_file method, which will
        write the file to the given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          nor datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b'])
    >>> writer = StataWriter('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file
    >>> compression = {{"method": "zip", "archive_name": "data_file.dta"}}
    >>> writer = StataWriter('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Save a DataFrame with dates
    >>> from datetime import datetime
    >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
    >>> writer = StataWriter('./date_data_file.dta', data, {{'date' : 'tw'}})
    >>> writer.write_file()
    """

    _max_string_length = 244
    _encoding = "latin-1"

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
    ):
        super().__init__()
        self.data = data
        self._convert_dates = {} if convert_dates is None else convert_dates
        self._write_index = write_index
        self._time_stamp = time_stamp
        self._data_label = data_label
        self._variable_labels = variable_labels
        self._non_cat_value_labels = value_labels
        self._value_labels: list[StataValueLabel] = []
        self._has_value_labels = np.array([], dtype=bool)
        self._compression = compression
        self._output_file: IO[bytes] | None = None
        self._converted_names: dict[Hashable, str] = {}
        # attach nobs, nvars, data, varlist, typlist
        self._prepare_pandas(data)
        self.storage_options = storage_options

        if byteorder is None:
            byteorder = sys.byteorder
        self._byteorder = _set_endianness(byteorder)
        self._fname = fname
        self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}

    def _write(self, to_write: str) -> None:
        """
        Helper to call encode before writing to file for Python 3 compat.
        """
        self.handles.handle.write(to_write.encode(self._encoding))

    def _write_bytes(self, value: bytes) -> None:
        """
        Helper to assert file is open before writing.
        """
        self.handles.handle.write(value)

    def _prepare_non_cat_value_labels(
        self, data: DataFrame
    ) -> list[StataNonCatValueLabel]:
        """
        Check for value labels provided for non-categorical columns. Value
        labels can only be applied to numeric columns.
        """
        non_cat_value_labels: list[StataNonCatValueLabel] = []
        if self._non_cat_value_labels is None:
            return non_cat_value_labels

        for labname, labels in self._non_cat_value_labels.items():
            if labname in self._converted_names:
                colname = self._converted_names[labname]
            elif labname in data.columns:
                colname = str(labname)
            else:
                raise KeyError(
                    f"Can't create value labels for {labname}, it wasn't "
                    "found in the dataset."
                )

            if not is_numeric_dtype(data[colname].dtype):
                # Labels should not be passed explicitly for categorical
                # columns that will be converted to int
                raise ValueError(
                    f"Can't create value labels for {labname}, value labels "
                    "can only be applied to numeric columns."
                )
            svl = StataNonCatValueLabel(colname, labels)
            non_cat_value_labels.append(svl)
        return non_cat_value_labels

    def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
        """
        Check for categorical columns, retain categorical information for
        Stata file and convert categorical data to int
        """
        is_cat = [is_categorical_dtype(data[col].dtype) for col in data]
        if not any(is_cat):
            return data

        self._has_value_labels |= np.array(is_cat)

        get_base_missing_value = StataMissingValue.get_base_missing_value
        data_formatted = []
        for col, col_is_cat in zip(data, is_cat):
            if col_is_cat:
                svl = StataValueLabel(data[col], encoding=self._encoding)
                self._value_labels.append(svl)
                dtype = data[col].cat.codes.dtype
                if dtype == np.int64:
                    raise ValueError(
                        "It is not possible to export "
                        "int64-based categorical data to Stata."
                    )
                values = data[col].cat.codes._values.copy()

                # Upcast if needed so that correct missing values can be set
                if values.max() >= get_base_missing_value(dtype):
                    if dtype == np.int8:
                        dtype = np.dtype(np.int16)
                    elif dtype == np.int16:
                        dtype = np.dtype(np.int32)
                    else:
                        dtype = np.dtype(np.float64)
                    values = np.array(values, dtype=dtype)

                # Replace missing values with Stata missing value for type
                values[values == -1] = get_base_missing_value(dtype)
                data_formatted.append((col, values))
            else:
                data_formatted.append((col, data[col]))
        return DataFrame.from_dict(dict(data_formatted))

    def _replace_nans(self, data: DataFrame) -> DataFrame:
        """
        Checks floating point data columns for nans, and replaces these with
        the generic Stata for missing value (.)
        """
        for c in data:
            dtype = data[c].dtype
            if dtype in (np.float32, np.float64):
                if dtype == np.float32:
                    replacement = self.MISSING_VALUES["f"]
                else:
                    replacement = self.MISSING_VALUES["d"]
                data[c] = data[c].fillna(replacement)

        return data

    def _update_strl_names(self) -> None:
        """No-op, forward compatibility"""
        pass

    def _validate_variable_name(self, name: str) -> str:
        """
        Validate variable names for Stata export.

        Parameters
        ----------
        name : str
            Variable name

        Returns
        -------
        str
            The validated name with invalid characters replaced with
            underscores.

        Notes
        -----
        Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9
        and _.
        """
        for c in name:
            if (
                (c < "A" or c > "Z")
                and (c < "a" or c > "z")
                and (c < "0" or c > "9")
                and c != "_"
            ):
                name = name.replace(c, "_")
        return name

    def _check_column_names(self, data: DataFrame) -> DataFrame:
        """
        Checks column names to ensure that they are valid Stata column names.
        This includes checks for:
            * Non-string names
            * Stata keywords
            * Variables that start with numbers
            * Variables with names that are too long

        When an illegal variable name is detected, it is converted, and if
        dates are exported, the variable name is propagated to the date
        conversion dictionary
        """
        converted_names: dict[Hashable, str] = {}
        columns = list(data.columns)
        original_columns = columns[:]

        duplicate_var_id = 0
        for j, name in enumerate(columns):
            orig_name = name
            if not isinstance(name, str):
                name = str(name)

            name = self._validate_variable_name(name)

            # Variable name must not be a reserved word
            if name in self.RESERVED_WORDS:
                name = "_" + name

            # Variable name may not start with a number
            if "0" <= name[0] <= "9":
                name = "_" + name

            name = name[: min(len(name), 32)]

            if not name == orig_name:
                # check for duplicates
                while columns.count(name) > 0:
                    # prepend ascending number to avoid duplicates
                    name = "_" + str(duplicate_var_id) + name
                    name = name[: min(len(name), 32)]
                    duplicate_var_id += 1
                converted_names[orig_name] = name

            columns[j] = name

        data.columns = Index(columns)

        # Check date conversion, and fix key if needed
        if self._convert_dates:
            for c, o in zip(columns, original_columns):
                if c != o:
                    self._convert_dates[c] = self._convert_dates[o]
                    del self._convert_dates[o]

        if converted_names:
            conversion_warning = []
            for orig_name, name in converted_names.items():
                msg = f"{orig_name}   ->   {name}"
                conversion_warning.append(msg)

            ws = invalid_name_doc.format("\n    ".join(conversion_warning))
            warnings.warn(ws, InvalidColumnName)

        self._converted_names = converted_names
        self._update_strl_names()

        return data

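    # Editor's note with an illustrative sketch (not part of pandas): the
    # renaming above turns an integer column name 1 into "_1" and a reserved
    # word such as "byte" into "_byte", and an InvalidColumnName warning
    # lists every rename. Hypothetical example:
    #
    #   import pandas as pd
    #   df = pd.DataFrame({1: [1.0], "byte": [2.0]})
    #   df.to_stata("renamed.dta", write_index=False)  # warns: 1 -> _1, ...
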
    def _set_formats_and_types(self, dtypes: Series) -> None:
        self.fmtlist: list[str] = []
        self.typlist: list[int] = []
        for col, dtype in dtypes.items():
            self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col]))
            self.typlist.append(_dtype_to_stata_type(dtype, self.data[col]))

    def _prepare_pandas(self, data: DataFrame) -> None:
        # NOTE: we might need a different API / class for pandas objects so
        # we can set different semantics - handle this with a PR to pandas.io

        data = data.copy()

        if self._write_index:
            temp = data.reset_index()
            if isinstance(temp, DataFrame):
                data = temp

        # Ensure column names are strings
        data = self._check_column_names(data)

        # Check columns for compatibility with stata, upcast if necessary
        # Raise if outside the supported range
        data = _cast_to_stata_types(data)

        # Replace NaNs with Stata missing values
        data = self._replace_nans(data)

        # Set all columns to initially unlabelled
        self._has_value_labels = np.repeat(False, data.shape[1])

        # Create value labels for non-categorical data
        non_cat_value_labels = self._prepare_non_cat_value_labels(data)

        non_cat_columns = [svl.labname for svl in non_cat_value_labels]
        has_non_cat_val_labels = data.columns.isin(non_cat_columns)
        self._has_value_labels |= has_non_cat_val_labels
        self._value_labels.extend(non_cat_value_labels)

        # Convert categoricals to int data, and strip labels
        data = self._prepare_categoricals(data)

        self.nobs, self.nvar = data.shape
        self.data = data
        self.varlist = data.columns.tolist()

        dtypes = data.dtypes

        # Ensure all date columns are converted
        for col in data:
            if col in self._convert_dates:
                continue
            if is_datetime64_dtype(data[col]):
                self._convert_dates[col] = "tc"

        self._convert_dates = _maybe_convert_to_int_keys(
            self._convert_dates, self.varlist
        )
        for key in self._convert_dates:
            new_type = _convert_datetime_to_stata_type(self._convert_dates[key])
            dtypes[key] = np.dtype(new_type)

        # Verify object arrays are strings and encode to bytes
        self._encode_strings()

        self._set_formats_and_types(dtypes)

        # set the given format for the datetime cols
        if self._convert_dates is not None:
            for key in self._convert_dates:
                if isinstance(key, int):
                    self.fmtlist[key] = self._convert_dates[key]

    def _encode_strings(self) -> None:
        """
        Encode strings in dta-specific encoding

        Do not encode columns marked for date conversion or for strL
        conversion. The strL converter independently handles conversion and
        also accepts empty string arrays.
        """
        convert_dates = self._convert_dates
        # _convert_strl is not available in dta 114
        convert_strl = getattr(self, "_convert_strl", [])
        for i, col in enumerate(self.data):
            # Skip columns marked for date conversion or strl conversion
            if i in convert_dates or col in convert_strl:
                continue
            column = self.data[col]
            dtype = column.dtype
            if dtype.type is np.object_:
                inferred_dtype = infer_dtype(column, skipna=True)
                if not ((inferred_dtype == "string") or len(column) == 0):
                    col = column.name
                    raise ValueError(
                        f"""\
Column `{col}` cannot be exported.\n\nOnly string-like object arrays
containing all strings or a mix of strings and None can be exported.
Object arrays containing only null values are prohibited. Other object
types cannot be exported and must first be converted to one of the
supported types."""
                    )
                encoded = self.data[col].str.encode(self._encoding)
                # If larger than _max_string_length do nothing
                if (
                    max_len_string_array(ensure_object(encoded._values))
                    <= self._max_string_length
                ):
                    self.data[col] = encoded

    def write_file(self) -> None:
        """
        Export DataFrame object to Stata dta format.
        """
        with get_handle(
            self._fname,
            "wb",
            compression=self._compression,
            is_text=False,
            storage_options=self.storage_options,
        ) as self.handles:

            if self.handles.compression["method"] is not None:
                # ZipFile creates a file (with the same name) for each write call.
                # Write it first into a buffer and then write the buffer to the ZipFile.
                self._output_file, self.handles.handle = self.handles.handle, BytesIO()
                self.handles.created_handles.append(self.handles.handle)

            try:
                self._write_header(
                    data_label=self._data_label, time_stamp=self._time_stamp
                )
                self._write_map()
                self._write_variable_types()
                self._write_varnames()
                self._write_sortlist()
                self._write_formats()
                self._write_value_label_names()
                self._write_variable_labels()
                self._write_expansion_fields()
                self._write_characteristics()
                records = self._prepare_data()
                self._write_data(records)
                self._write_strls()
                self._write_value_labels()
                self._write_file_close_tag()
                self._write_map()
                self._close()
            except Exception as exc:
                self.handles.close()
                if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile(
                    self._fname
                ):
                    try:
                        os.unlink(self._fname)
                    except OSError:
                        warnings.warn(
                            f"This save was not successful but {self._fname} could not "
                            "be deleted. This file is not valid.",
                            ResourceWarning,
                        )
                raise exc

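    # Editor's illustrative sketch (not part of pandas): write_file is what
    # DataFrame.to_stata ultimately calls. Equivalent direct use, with a
    # hypothetical output path:
    #
    #   import pandas as pd
    #   df = pd.DataFrame({"x": [1.0, 2.0]})
    #   StataWriter("out.dta", df, write_index=False).write_file()
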
    def _close(self) -> None:
        """
        Close the file if it was created by the writer.

        If a buffer or file-like object was passed in, for example a GzipFile,
        then leave this file open for the caller to close.
        """
        # write compression
        if self._output_file is not None:
            assert isinstance(self.handles.handle, BytesIO)
            bio, self.handles.handle = self.handles.handle, self._output_file
            self.handles.handle.write(bio.getvalue())

    def _write_map(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_file_close_tag(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_characteristics(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_strls(self) -> None:
        """No-op, future compatibility"""
        pass

    def _write_expansion_fields(self) -> None:
        """Write 5 zeros for expansion fields"""
        self._write(_pad_bytes("", 5))

    def _write_value_labels(self) -> None:
        for vl in self._value_labels:
            self._write_bytes(vl.generate_value_label(self._byteorder))

    def _write_header(
        self,
        data_label: str | None = None,
        time_stamp: datetime.datetime | None = None,
    ) -> None:
        byteorder = self._byteorder
        # ds_format - just use 114
        self._write_bytes(struct.pack("b", 114))
        # byteorder
        self._write("\x01" if byteorder == ">" else "\x02")
        # filetype
        self._write("\x01")
        # unused
        self._write("\x00")
        # number of vars, 2 bytes
        self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2])
        # number of obs, 4 bytes
        self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4])
        # data label 81 bytes, char, null terminated
        if data_label is None:
            self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80)))
        else:
            self._write_bytes(
                self._null_terminate_bytes(_pad_bytes(data_label[:80], 80))
            )
        # time stamp, 18 bytes, char, null terminated
        # format dd Mon yyyy hh:mm
        if time_stamp is None:
            time_stamp = datetime.datetime.now()
        elif not isinstance(time_stamp, datetime.datetime):
            raise ValueError("time_stamp should be datetime type")
        # GH #13856
        # Avoid locale-specific month conversion
        months = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ]
        month_lookup = {i + 1: month for i, month in enumerate(months)}
        ts = (
            time_stamp.strftime("%d ")
            + month_lookup[time_stamp.month]
            + time_stamp.strftime(" %Y %H:%M")
        )
        self._write_bytes(self._null_terminate_bytes(ts))

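    # Editor's note (not part of pandas): the fixed dta-114 header written
    # above starts with one signed byte for the format (114), a byteorder
    # byte (0x01 big-endian, 0x02 little-endian), a filetype byte, a pad
    # byte, then nvar as int16 and nobs as int32. For a little-endian file
    # with 3 variables and 10 observations the first bytes would be:
    #
    #   import struct
    #   struct.pack("b", 114) + b"\x02\x01\x00" + struct.pack("<h", 3) + struct.pack("<i", 10)
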
    def _write_variable_types(self) -> None:
        for typ in self.typlist:
            self._write_bytes(struct.pack("B", typ))

    def _write_varnames(self) -> None:
        # varlist names are checked by _check_column_names
        # varlist, requires null terminated
        for name in self.varlist:
            name = self._null_terminate_str(name)
            name = _pad_bytes(name[:32], 33)
            self._write(name)

    def _write_sortlist(self) -> None:
        # srtlist, 2*(nvar+1), int array, encoded by byteorder
        srtlist = _pad_bytes("", 2 * (self.nvar + 1))
        self._write(srtlist)

    def _write_formats(self) -> None:
        # fmtlist, 49*nvar, char array
        for fmt in self.fmtlist:
            self._write(_pad_bytes(fmt, 49))

    def _write_value_label_names(self) -> None:
        # lbllist, 33*nvar, char array
        for i in range(self.nvar):
            # Use variable name when categorical
            if self._has_value_labels[i]:
                name = self.varlist[i]
                name = self._null_terminate_str(name)
                name = _pad_bytes(name[:32], 33)
                self._write(name)
            else:  # Default is empty label
                self._write(_pad_bytes("", 33))

    def _write_variable_labels(self) -> None:
        # Missing labels are 80 blank characters plus null termination
        blank = _pad_bytes("", 81)

        if self._variable_labels is None:
            for i in range(self.nvar):
                self._write(blank)
            return

        for col in self.data:
            if col in self._variable_labels:
                label = self._variable_labels[col]
                if len(label) > 80:
                    raise ValueError("Variable labels must be 80 characters or fewer")
                is_latin1 = all(ord(c) < 256 for c in label)
                if not is_latin1:
                    raise ValueError(
                        "Variable labels must contain only characters that "
                        "can be encoded in Latin-1"
                    )
                self._write(_pad_bytes(label, 81))
            else:
                self._write(blank)

    def _convert_strls(self, data: DataFrame) -> DataFrame:
        """No-op, future compatibility"""
        return data

    def _prepare_data(self) -> np.recarray:
        data = self.data
        typlist = self.typlist
        convert_dates = self._convert_dates
        # 1. Convert dates
        if self._convert_dates is not None:
            for i, col in enumerate(data):
                if i in convert_dates:
                    data[col] = _datetime_to_stata_elapsed_vec(
                        data[col], self.fmtlist[i]
                    )
        # 2. Convert strls
        data = self._convert_strls(data)

        # 3. Convert bad string data to '' and pad to correct length
        dtypes = {}
        native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
        for i, col in enumerate(data):
            typ = typlist[i]
            if typ <= self._max_string_length:
                data[col] = data[col].fillna("").apply(_pad_bytes, args=(typ,))
                stype = f"S{typ}"
                dtypes[col] = stype
                data[col] = data[col].astype(stype)
            else:
                dtype = data[col].dtype
                if not native_byteorder:
                    dtype = dtype.newbyteorder(self._byteorder)
                dtypes[col] = dtype

        return data.to_records(index=False, column_dtypes=dtypes)

    def _write_data(self, records: np.recarray) -> None:
        self._write_bytes(records.tobytes())

    @staticmethod
    def _null_terminate_str(s: str) -> str:
        s += "\x00"
        return s

    def _null_terminate_bytes(self, s: str) -> bytes:
        return self._null_terminate_str(s).encode(self._encoding)


def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
    """
    Converts dtype types to stata types. Returns the byte of the given ordinal.
    See TYPE_MAP and comments for an explanation. This is also explained in
    the dta spec.
    1 - 2045 are strings of this length
                Pandas    Stata
    32768 - for object    strL
    65526 - for float64   double
    65527 - for float32   float
    65528 - for int32     long
    65529 - for int16     int
    65530 - for int8      byte

    If there are dates to convert, then dtype will already have the correct
    type inserted.
    """
    # TODO: expand to handle datetime to integer conversion
    if force_strl:
        return 32768
    if dtype.type is np.object_:  # try to coerce it to the biggest string
        # not memory efficient, what else could we
        # do?
        itemsize = max_len_string_array(ensure_object(column._values))
        itemsize = max(itemsize, 1)
        if itemsize <= 2045:
            return itemsize
        return 32768
    elif dtype.type is np.float64:
        return 65526
    elif dtype.type is np.float32:
        return 65527
    elif dtype.type is np.int32:
        return 65528
    elif dtype.type is np.int16:
        return 65529
    elif dtype.type is np.int8:
        return 65530
    else:  # pragma : no cover
        raise NotImplementedError(f"Data type {dtype} not supported.")


def _pad_bytes_new(name: str | bytes, length: int) -> bytes:
    """
    Take a bytes instance and pad it with null bytes until it is `length`
    bytes long.
    """
    if isinstance(name, str):
        name = bytes(name, "utf-8")
    return name + b"\x00" * (length - len(name))


class StataStrLWriter:
    """
    Converter for Stata StrLs

    Stata StrLs map 8 byte values to strings which are stored using a
    dictionary-like format where strings are keyed to two values.

    Parameters
    ----------
    df : DataFrame
        DataFrame to convert
    columns : Sequence[str]
        List of columns names to convert to StrL
    version : int, optional
        dta version. Currently supports 117, 118 and 119
    byteorder : str, optional
        Can be ">", "<", "little", or "big". default is `sys.byteorder`

    Notes
    -----
    Supports creation of the StrL block of a dta file for dta versions
    117, 118 and 119. These differ in how the GSO is stored. 118 and
    119 store the GSO lookup value as a uint32 and a uint64, while 117
    uses two uint32s. 118 and 119 also encode all strings as unicode,
    which is required by the format. 117 uses 'latin-1', a fixed width
    encoding that extends the 7-bit ascii table with an additional 128
    characters.
    """

    def __init__(
        self,
        df: DataFrame,
        columns: Sequence[str],
        version: int = 117,
        byteorder: str | None = None,
    ):
        if version not in (117, 118, 119):
            raise ValueError("Only dta versions 117, 118 and 119 supported")
        self._dta_ver = version

        self.df = df
        self.columns = columns
        self._gso_table = {"": (0, 0)}
        if byteorder is None:
            byteorder = sys.byteorder
        self._byteorder = _set_endianness(byteorder)

        gso_v_type = "I"  # uint32
        gso_o_type = "Q"  # uint64
        self._encoding = "utf-8"
        if version == 117:
            o_size = 4
            gso_o_type = "I"  # 117 used uint32
            self._encoding = "latin-1"
        elif version == 118:
            o_size = 6
        else:  # version == 119
            o_size = 5
        self._o_offset = 2 ** (8 * (8 - o_size))
        self._gso_o_type = gso_o_type
        self._gso_v_type = gso_v_type

    def _convert_key(self, key: tuple[int, int]) -> int:
        v, o = key
        return v + self._o_offset * o

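    # Editor's note (not part of pandas): for dta 117, o_size is 4, so
    # _convert_key packs v into the low 32 bits and o into the high 32 bits:
    #
    #   v, o = 2, 3
    #   enc = v + (2 ** 32) * o  # == StataStrLWriter(df, cols)._convert_key((2, 3))
    #   assert enc == 12884901890
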
    def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]:
        """
        Generates the GSO lookup table for the DataFrame

        Returns
        -------
        gso_table : dict
            Ordered dictionary using the string found as keys
            and their lookup position (v, o) as values
        gso_df : DataFrame
            DataFrame where strl columns have been converted to
            (v, o) values

        Notes
        -----
        Modifies the DataFrame in-place.

        The DataFrame returned encodes the (v, o) values as uint64s. The
        encoding depends on the dta version, and can be expressed as

        enc = v + o * 2 ** (o_size * 8)

        so that v is stored in the lower bits and o is in the upper
        bits. o_size is

          * 117: 4
          * 118: 6
          * 119: 5
        """
        gso_table = self._gso_table
        gso_df = self.df
        columns = list(gso_df.columns)
        selected = gso_df[self.columns]
        col_index = [(col, columns.index(col)) for col in self.columns]
        keys = np.empty(selected.shape, dtype=np.uint64)
        for o, (idx, row) in enumerate(selected.iterrows()):
            for j, (col, v) in enumerate(col_index):
                val = row[col]
                # Allow columns with mixed str and None (GH 23633)
                val = "" if val is None else val
                key = gso_table.get(val, None)
                if key is None:
                    # Stata prefers human numbers
                    key = (v + 1, o + 1)
                    gso_table[val] = key
                keys[o, j] = self._convert_key(key)
        for i, col in enumerate(self.columns):
            gso_df[col] = keys[:, i]

        return gso_table, gso_df

    def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes:
        """
        Generates the binary blob of GSOs that is written to the dta file.

        Parameters
        ----------
        gso_table : dict
            Ordered dictionary (str, vo)

        Returns
        -------
        gso : bytes
            Binary content of dta file to be placed between strl tags

        Notes
        -----
        Output format depends on dta version. 117 uses two uint32s to
        express v and o while 118+ uses a uint32 for v and a uint64 for o.
        """
        # Format information
        # Length includes null term
        # 117
        # GSOvvvvooootllllxxxxxxxxxxxxxxx...x
        #  3  u4  u4 u1 u4  string + null term
        #
        # 118, 119
        # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x
        #  3  u4   u8   u1 u4  string + null term

        bio = BytesIO()
        gso = bytes("GSO", "ascii")
        gso_type = struct.pack(self._byteorder + "B", 130)
        null = struct.pack(self._byteorder + "B", 0)
        v_type = self._byteorder + self._gso_v_type
        o_type = self._byteorder + self._gso_o_type
        len_type = self._byteorder + "I"
        for strl, vo in gso_table.items():
            if vo == (0, 0):
                continue
            v, o = vo

            # GSO
            bio.write(gso)

            # vvvv
            bio.write(struct.pack(v_type, v))

            # oooo / oooooooo
            bio.write(struct.pack(o_type, o))

            # t
            bio.write(gso_type)

            # llll
            utf8_string = bytes(strl, "utf-8")
            bio.write(struct.pack(len_type, len(utf8_string) + 1))

            # xxx...xxx
            bio.write(utf8_string)
            bio.write(null)

        return bio.getvalue()


class StataWriter117(StataWriter):
    """
    A class for writing Stata binary dta files in Stata 13 format (117)

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() function. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information
    write_index : bool
        Write the index to Stata dataset.
    byteorder : str
        Can be ">", "<", "little", or "big". default is `sys.byteorder`
    time_stamp : datetime
        A datetime to use as file creation date. Default is the current time
    data_label : str
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    convert_strl : list
        List of columns names to convert to Stata StrL format. Columns with
        more than 2045 characters are automatically written as StrL.
        Smaller columns can be converted by including the column name. Using
        StrLs can reduce output file size when strings are longer than 8
        characters, and either frequently repeated or sparse.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    writer : StataWriter117 instance
        The StataWriter117 instance has a write_file method, which will
        write the file to the given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          nor datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    >>> from pandas.io.stata import StataWriter117
    >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
    >>> writer = StataWriter117('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file
    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
    >>> writer = StataWriter117('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Or with long strings stored in strl format
    >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
    ...                     columns=['strls'])
    >>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
    ...                         convert_strl=['strls'])
    >>> writer.write_file()
    """

    _max_string_length = 2045
    _dta_version = 117

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        convert_strl: Sequence[Hashable] | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
    ):
        # Copy to new list since convert_strl might be modified later
        self._convert_strl: list[Hashable] = []
        if convert_strl is not None:
            self._convert_strl.extend(convert_strl)

        super().__init__(
            fname,
            data,
            convert_dates,
            write_index,
            byteorder=byteorder,
            time_stamp=time_stamp,
            data_label=data_label,
            variable_labels=variable_labels,
            value_labels=value_labels,
            compression=compression,
            storage_options=storage_options,
        )
        self._map: dict[str, int] = {}
        self._strl_blob = b""

    @staticmethod
    def _tag(val: str | bytes, tag: str) -> bytes:
        """Surround val with <tag></tag>"""
        if isinstance(val, str):
            val = bytes(val, "utf-8")
        return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8")

    def _update_map(self, tag: str) -> None:
        """Update map location for tag with file position"""
        assert self.handles.handle is not None
        self._map[tag] = self.handles.handle.tell()

    def _write_header(
        self,
        data_label: str | None = None,
        time_stamp: datetime.datetime | None = None,
    ) -> None:
        """Write the file header"""
        byteorder = self._byteorder
        self._write_bytes(bytes("<stata_dta>", "utf-8"))
        bio = BytesIO()
        # ds_format - 117
        bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release"))
        # byteorder
        bio.write(self._tag("MSF" if byteorder == ">" else "LSF", "byteorder"))
        # number of vars, 2 bytes in 117 and 118, 4 bytes in 119
        nvar_type = "H" if self._dta_version <= 118 else "I"
        bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K"))
        # 117 uses 4 bytes, 118 uses 8
        nobs_size = "I" if self._dta_version == 117 else "Q"
        bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
        # data label 81 bytes, char, null terminated
        label = data_label[:80] if data_label is not None else ""
        encoded_label = label.encode(self._encoding)
        label_size = "B" if self._dta_version == 117 else "H"
        label_len = struct.pack(byteorder + label_size, len(encoded_label))
        encoded_label = label_len + encoded_label
        bio.write(self._tag(encoded_label, "label"))
        # time stamp, 18 bytes, char, null terminated
        # format dd Mon yyyy hh:mm
        if time_stamp is None:
            time_stamp = datetime.datetime.now()
        elif not isinstance(time_stamp, datetime.datetime):
            raise ValueError("time_stamp should be datetime type")
        # Avoid locale-specific month conversion
        months = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ]
        month_lookup = {i + 1: month for i, month in enumerate(months)}
        ts = (
            time_stamp.strftime("%d ")
            + month_lookup[time_stamp.month]
            + time_stamp.strftime(" %Y %H:%M")
        )
        # '\x11' added due to inspection of Stata file
        stata_ts = b"\x11" + bytes(ts, "utf-8")
        bio.write(self._tag(stata_ts, "timestamp"))
        self._write_bytes(self._tag(bio.getvalue(), "header"))

    def _write_map(self) -> None:
        """
        Called twice during file write. The first call populates the values in
        the map with 0s. The second call writes the final map locations when
        all blocks have been written.
        """
        if not self._map:
            self._map = {
                "stata_data": 0,
                "map": self.handles.handle.tell(),
                "variable_types": 0,
                "varnames": 0,
                "sortlist": 0,
                "formats": 0,
                "value_label_names": 0,
                "variable_labels": 0,
                "characteristics": 0,
                "data": 0,
                "strls": 0,
                "value_labels": 0,
                "stata_data_close": 0,
                "end-of-file": 0,
            }
        # Move to start of map
        self.handles.handle.seek(self._map["map"])
        bio = BytesIO()
        for val in self._map.values():
            bio.write(struct.pack(self._byteorder + "Q", val))
        self._write_bytes(self._tag(bio.getvalue(), "map"))

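    # Editor's note (not part of pandas): the map is written twice because
    # dta 117+ stores absolute file offsets of every section. The first pass
    # reserves space with zeros; after all sections are written (and their
    # tell() positions recorded via _update_map), write_file calls
    # _write_map again, which seeks back and rewrites the 14 offsets, each
    # packed as struct.pack(self._byteorder + "Q", offset).
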
def _write_variable_types(self) -> None:
3319
self._update_map("variable_types")
3320
bio = BytesIO()
3321
for typ in self.typlist:
3322
bio.write(struct.pack(self._byteorder + "H", typ))
3323
self._write_bytes(self._tag(bio.getvalue(), "variable_types"))
3324
3325
def _write_varnames(self) -> None:
3326
self._update_map("varnames")
3327
bio = BytesIO()
3328
# 118 scales by 4 to accommodate utf-8 data worst case encoding
3329
vn_len = 32 if self._dta_version == 117 else 128
3330
for name in self.varlist:
3331
name = self._null_terminate_str(name)
3332
name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
3333
bio.write(name)
3334
self._write_bytes(self._tag(bio.getvalue(), "varnames"))
3335
3336
def _write_sortlist(self) -> None:
3337
self._update_map("sortlist")
3338
sort_size = 2 if self._dta_version < 119 else 4
3339
self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist"))
3340
3341
def _write_formats(self) -> None:
3342
self._update_map("formats")
3343
bio = BytesIO()
3344
fmt_len = 49 if self._dta_version == 117 else 57
3345
for fmt in self.fmtlist:
3346
bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len))
3347
self._write_bytes(self._tag(bio.getvalue(), "formats"))
3348
3349
def _write_value_label_names(self) -> None:
3350
self._update_map("value_label_names")
3351
bio = BytesIO()
3352
# 118 scales by 4 to accommodate utf-8 data worst case encoding
3353
vl_len = 32 if self._dta_version == 117 else 128
3354
for i in range(self.nvar):
3355
# Use variable name when categorical
3356
name = "" # default name
3357
if self._has_value_labels[i]:
3358
name = self.varlist[i]
3359
name = self._null_terminate_str(name)
3360
encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
3361
bio.write(encoded_name)
3362
self._write_bytes(self._tag(bio.getvalue(), "value_label_names"))
3363
3364
def _write_variable_labels(self) -> None:
3365
# Missing labels are 80 blank characters plus null termination
3366
self._update_map("variable_labels")
3367
bio = BytesIO()
3368
# 118 scales by 4 to accommodate utf-8 data worst case encoding
3369
vl_len = 80 if self._dta_version == 117 else 320
3370
blank = _pad_bytes_new("", vl_len + 1)
3371
3372
if self._variable_labels is None:
3373
for _ in range(self.nvar):
3374
bio.write(blank)
3375
self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3376
return
3377
3378
for col in self.data:
3379
if col in self._variable_labels:
3380
label = self._variable_labels[col]
3381
if len(label) > 80:
3382
raise ValueError("Variable labels must be 80 characters or fewer")
3383
try:
3384
encoded = label.encode(self._encoding)
3385
except UnicodeEncodeError as err:
3386
raise ValueError(
3387
"Variable labels must contain only characters that "
3388
f"can be encoded in {self._encoding}"
3389
) from err
3390
3391
bio.write(_pad_bytes_new(encoded, vl_len + 1))
3392
else:
3393
bio.write(blank)
3394
self._write_bytes(self._tag(bio.getvalue(), "variable_labels"))
3395
3396
def _write_characteristics(self) -> None:
3397
self._update_map("characteristics")
3398
self._write_bytes(self._tag(b"", "characteristics"))
3399
3400
def _write_data(self, records) -> None:
3401
self._update_map("data")
3402
self._write_bytes(b"<data>")
3403
self._write_bytes(records.tobytes())
3404
self._write_bytes(b"</data>")
3405
3406
def _write_strls(self) -> None:
3407
self._update_map("strls")
3408
self._write_bytes(self._tag(self._strl_blob, "strls"))
3409
3410
def _write_expansion_fields(self) -> None:
3411
"""No-op in dta 117+"""
3412
pass
3413
3414
def _write_value_labels(self) -> None:
3415
self._update_map("value_labels")
3416
bio = BytesIO()
3417
for vl in self._value_labels:
3418
lab = vl.generate_value_label(self._byteorder)
3419
lab = self._tag(lab, "lbl")
3420
bio.write(lab)
3421
self._write_bytes(self._tag(bio.getvalue(), "value_labels"))
3422
3423
def _write_file_close_tag(self) -> None:
3424
self._update_map("stata_data_close")
3425
self._write_bytes(bytes("</stata_dta>", "utf-8"))
3426
self._update_map("end-of-file")
3427
3428
def _update_strl_names(self) -> None:
3429
"""
3430
Update column names for conversion to strl if they might have been
3431
changed to comply with Stata naming rules
3432
"""
3433
# Update convert_strl if names changed
3434
for orig, new in self._converted_names.items():
3435
if orig in self._convert_strl:
3436
idx = self._convert_strl.index(orig)
3437
self._convert_strl[idx] = new
3438
3439
def _convert_strls(self, data: DataFrame) -> DataFrame:
3440
"""
3441
Convert columns to StrLs if either very large or in the
3442
convert_strl variable
3443
"""
3444
convert_cols = [
3445
col
3446
for i, col in enumerate(data)
3447
if self.typlist[i] == 32768 or col in self._convert_strl
3448
]
3449
3450
if convert_cols:
3451
ssw = StataStrLWriter(data, convert_cols, version=self._dta_version)
3452
tab, new_data = ssw.generate_table()
3453
data = new_data
3454
self._strl_blob = ssw.generate_blob(tab)
3455
return data
3456
3457
def _set_formats_and_types(self, dtypes: Series) -> None:
3458
self.typlist = []
3459
self.fmtlist = []
3460
for col, dtype in dtypes.items():
3461
force_strl = col in self._convert_strl
3462
fmt = _dtype_to_default_stata_fmt(
3463
dtype,
3464
self.data[col],
3465
dta_version=self._dta_version,
3466
force_strl=force_strl,
3467
)
3468
self.fmtlist.append(fmt)
3469
self.typlist.append(
3470
_dtype_to_stata_type_117(dtype, self.data[col], force_strl)
3471
)
3472
3473
3474


class StataWriterUTF8(StataWriter117):
    """
    Stata binary dta file writing in Stata 15 (118) and 16 (119) formats

    DTA 118 and 119 format files support unicode string data (both fixed
    width and strL). Unicode is also supported in value labels, variable
    labels and the dataset label. Format 119 is automatically used if the
    file contains more than 32,767 variables.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    fname : path (string), buffer or path object
        string, path object (pathlib.Path or py._path.local.LocalPath) or
        object implementing a binary write() function. If using a buffer
        then the buffer will not be automatically closed after the file
        is written.
    data : DataFrame
        Input to save
    convert_dates : dict, default None
        Dictionary mapping columns containing datetime types to stata internal
        format to use when writing the dates. Options are 'tc', 'td', 'tm',
        'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
        Datetime columns that do not have a conversion type specified will be
        converted to 'tc'. Raises NotImplementedError if a datetime column has
        timezone information.
    write_index : bool, default True
        Write the index to Stata dataset.
    byteorder : str, default None
        Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
    time_stamp : datetime, default None
        A datetime to use as file creation date. Default is the current time.
    data_label : str, default None
        A label for the data set. Must be 80 characters or smaller.
    variable_labels : dict, default None
        Dictionary containing columns as keys and variable labels as values.
        Each label must be 80 characters or smaller.
    convert_strl : list, default None
        List of column names to convert to Stata StrL format. Columns with
        more than 2045 characters are automatically written as StrL.
        Smaller columns can be converted by including the column name. Using
        StrLs can reduce output file size when strings are longer than 8
        characters, and either frequently repeated or sparse.
    version : int, default None
        The dta version to use. By default, uses the size of data to determine
        the version. 118 is used if data.shape[1] <= 32767, and 119 is used
        for storing larger DataFrames.
    {compression_options}

        .. versionadded:: 1.1.0

        .. versionchanged:: 1.4.0 Zstandard support.

    value_labels : dict of dicts
        Dictionary containing columns as keys and dictionaries of column value
        to labels as values. The combined length of all labels for a single
        variable must be 32,000 characters or smaller.

        .. versionadded:: 1.4.0

    Returns
    -------
    StataWriterUTF8
        The instance has a write_file method, which will write the file to the
        given `fname`.

    Raises
    ------
    NotImplementedError
        * If datetimes contain timezone information
    ValueError
        * Columns listed in convert_dates are neither datetime64[ns]
          nor datetime.datetime
        * Column dtype is not representable in Stata
        * Column listed in convert_dates is not in DataFrame
        * Categorical label contains more than 32,000 characters

    Examples
    --------
    Using Unicode data and column names

    >>> from pandas.io.stata import StataWriterUTF8
    >>> data = pd.DataFrame([[1.0, 1, 'ᴬ']], columns=['a', 'β', 'ĉ'])
    >>> writer = StataWriterUTF8('./data_file.dta', data)
    >>> writer.write_file()

    Directly write a zip file

    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
    >>> writer.write_file()

    Or with long strings stored in strl format

    >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
    ...                     columns=['strls'])
    >>> writer = StataWriterUTF8('./data_file_with_long_strings.dta', data,
    ...                          convert_strl=['strls'])
    >>> writer.write_file()
    """

    _encoding = "utf-8"

    def __init__(
        self,
        fname: FilePath | WriteBuffer[bytes],
        data: DataFrame,
        convert_dates: dict[Hashable, str] | None = None,
        write_index: bool = True,
        byteorder: str | None = None,
        time_stamp: datetime.datetime | None = None,
        data_label: str | None = None,
        variable_labels: dict[Hashable, str] | None = None,
        convert_strl: Sequence[Hashable] | None = None,
        version: int | None = None,
        compression: CompressionOptions = "infer",
        storage_options: StorageOptions = None,
        *,
        value_labels: dict[Hashable, dict[float | int, str]] | None = None,
    ):
        if version is None:
            version = 118 if data.shape[1] <= 32767 else 119
        elif version not in (118, 119):
            raise ValueError("version must be either 118 or 119.")
        elif version == 118 and data.shape[1] > 32767:
            raise ValueError(
                "You must use version 119 for data sets containing more than "
                "32,767 variables"
            )

        super().__init__(
            fname,
            data,
            convert_dates=convert_dates,
            write_index=write_index,
            byteorder=byteorder,
            time_stamp=time_stamp,
            data_label=data_label,
            variable_labels=variable_labels,
            value_labels=value_labels,
            convert_strl=convert_strl,
            compression=compression,
            storage_options=storage_options,
        )
        # Override version set in StataWriter117 init
        self._dta_version = version
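
    # Illustrative example (hypothetical frame; behavior as implemented
    # above): the dta version is inferred from the width of the frame when
    # not given.
    #
    #   narrow = pd.DataFrame({"x": [1.0]})
    #   StataWriterUTF8("narrow.dta", narrow)._dta_version  # -> 118
    #
    # A frame wider than 32,767 columns infers 119; requesting version=118
    # for such a frame raises ValueError, as does any version other than
    # 118 or 119.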

    def _validate_variable_name(self, name: str) -> str:
        """
        Validate variable names for Stata export.

        Parameters
        ----------
        name : str
            Variable name

        Returns
        -------
        str
            The validated name with invalid characters replaced with
            underscores.

        Notes
        -----
        Stata 118+ supports most unicode characters. Within the ascii range
        only a-z, A-Z, 0-9 and _ are allowed; code points in the 128-255
        range are also replaced with underscores.
        """
        # High code points appear to be acceptable
        for c in name:
            if (
                ord(c) < 128
                and (c < "A" or c > "Z")
                and (c < "a" or c > "z")
                and (c < "0" or c > "9")
                and c != "_"
            ) or 128 <= ord(c) < 256:
                name = name.replace(c, "_")

        return name
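
    # Illustrative examples (derived from the rule above): ASCII punctuation
    # and Latin-1 code points (128-255) become underscores, while higher
    # unicode code points pass through unchanged.
    #
    #   _validate_variable_name("pct-change")  # -> "pct_change"
    #   _validate_variable_name("café")        # -> "caf_"   (é is code point 233)
    #   _validate_variable_name("β_mean")      # -> "β_mean" (β is code point 946)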