summaryrefslogtreecommitdiffstats
path: root/third_party/python/cbor2/cbor2/encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/python/cbor2/cbor2/encoder.py')
-rw-r--r--third_party/python/cbor2/cbor2/encoder.py362
1 files changed, 362 insertions, 0 deletions
diff --git a/third_party/python/cbor2/cbor2/encoder.py b/third_party/python/cbor2/cbor2/encoder.py
new file mode 100644
index 0000000000..adcb2722e5
--- /dev/null
+++ b/third_party/python/cbor2/cbor2/encoder.py
@@ -0,0 +1,362 @@
+import re
+import struct
+from collections import OrderedDict, defaultdict
+from contextlib import contextmanager
+from functools import wraps
+from datetime import datetime, date, time
+from io import BytesIO
+
+from cbor2.compat import iteritems, timezone, long, unicode, as_unicode, bytes_from_list
+from cbor2.types import CBORTag, undefined, CBORSimpleValue
+
+
+class CBOREncodeError(Exception):
+ """Raised when an error occurs while serializing an object into a CBOR datastream."""
+
+
+def shareable_encoder(func):
+ """
+ Wrap the given encoder function to gracefully handle cyclic data structures.
+
+ If value sharing is enabled, this marks the given value shared in the datastream on the
+ first call. If the value has already been passed to this method, a reference marker is
+ instead written to the data stream and the wrapped function is not called.
+
+ If value sharing is disabled, only infinite recursion protection is done.
+
+ """
+ @wraps(func)
+ def wrapper(encoder, value, *args, **kwargs):
+ value_id = id(value)
+ container, container_index = encoder._shared_containers.get(value_id, (None, None))
+ if encoder.value_sharing:
+ if container is value:
+ # Generate a reference to the previous index instead of encoding this again
+ encoder.write(encode_length(0xd8, 0x1d))
+ encode_int(encoder, container_index)
+ else:
+ # Mark the container as shareable
+ encoder._shared_containers[value_id] = (value, len(encoder._shared_containers))
+ encoder.write(encode_length(0xd8, 0x1c))
+ func(encoder, value, *args, **kwargs)
+ else:
+ if container is value:
+ raise CBOREncodeError('cyclic data structure detected but value sharing is '
+ 'disabled')
+ else:
+ encoder._shared_containers[value_id] = (value, None)
+ func(encoder, value, *args, **kwargs)
+ del encoder._shared_containers[value_id]
+
+ return wrapper
+
+
+def encode_length(major_tag, length):
+ if length < 24:
+ return struct.pack('>B', major_tag | length)
+ elif length < 256:
+ return struct.pack('>BB', major_tag | 24, length)
+ elif length < 65536:
+ return struct.pack('>BH', major_tag | 25, length)
+ elif length < 4294967296:
+ return struct.pack('>BL', major_tag | 26, length)
+ else:
+ return struct.pack('>BQ', major_tag | 27, length)
+
+
+def encode_int(encoder, value):
+ # Big integers (2 ** 64 and over)
+ if value >= 18446744073709551616 or value < -18446744073709551616:
+ if value >= 0:
+ major_type = 0x02
+ else:
+ major_type = 0x03
+ value = -value - 1
+
+ values = []
+ while value > 0:
+ value, remainder = divmod(value, 256)
+ values.insert(0, remainder)
+
+ payload = bytes_from_list(values)
+ encode_semantic(encoder, CBORTag(major_type, payload))
+ elif value >= 0:
+ encoder.write(encode_length(0, value))
+ else:
+ encoder.write(encode_length(0x20, abs(value) - 1))
+
+
+def encode_bytestring(encoder, value):
+ encoder.write(encode_length(0x40, len(value)) + value)
+
+
+def encode_bytearray(encoder, value):
+ encode_bytestring(encoder, bytes(value))
+
+
+def encode_string(encoder, value):
+ encoded = value.encode('utf-8')
+ encoder.write(encode_length(0x60, len(encoded)) + encoded)
+
+
+@shareable_encoder
+def encode_array(encoder, value):
+ encoder.write(encode_length(0x80, len(value)))
+ for item in value:
+ encoder.encode(item)
+
+
+@shareable_encoder
+def encode_map(encoder, value):
+ encoder.write(encode_length(0xa0, len(value)))
+ for key, val in iteritems(value):
+ encoder.encode(key)
+ encoder.encode(val)
+
+
+def encode_semantic(encoder, value):
+ encoder.write(encode_length(0xc0, value.tag))
+ encoder.encode(value.value)
+
+
+#
+# Semantic decoders (major tag 6)
+#
+
+def encode_datetime(encoder, value):
+ # Semantic tag 0
+ if not value.tzinfo:
+ if encoder.timezone:
+ value = value.replace(tzinfo=encoder.timezone)
+ else:
+ raise CBOREncodeError(
+ 'naive datetime encountered and no default timezone has been set')
+
+ if encoder.datetime_as_timestamp:
+ from calendar import timegm
+ timestamp = timegm(value.utctimetuple()) + value.microsecond // 1000000
+ encode_semantic(encoder, CBORTag(1, timestamp))
+ else:
+ datestring = as_unicode(value.isoformat().replace('+00:00', 'Z'))
+ encode_semantic(encoder, CBORTag(0, datestring))
+
+
+def encode_date(encoder, value):
+ value = datetime.combine(value, time()).replace(tzinfo=timezone.utc)
+ encode_datetime(encoder, value)
+
+
+def encode_decimal(encoder, value):
+ # Semantic tag 4
+ if value.is_nan():
+ encoder.write(b'\xf9\x7e\x00')
+ elif value.is_infinite():
+ encoder.write(b'\xf9\x7c\x00' if value > 0 else b'\xf9\xfc\x00')
+ else:
+ dt = value.as_tuple()
+ mantissa = sum(d * 10 ** i for i, d in enumerate(reversed(dt.digits)))
+ with encoder.disable_value_sharing():
+ encode_semantic(encoder, CBORTag(4, [dt.exponent, mantissa]))
+
+
+def encode_rational(encoder, value):
+ # Semantic tag 30
+ with encoder.disable_value_sharing():
+ encode_semantic(encoder, CBORTag(30, [value.numerator, value.denominator]))
+
+
+def encode_regexp(encoder, value):
+ # Semantic tag 35
+ encode_semantic(encoder, CBORTag(35, as_unicode(value.pattern)))
+
+
+def encode_mime(encoder, value):
+ # Semantic tag 36
+ encode_semantic(encoder, CBORTag(36, as_unicode(value.as_string())))
+
+
+def encode_uuid(encoder, value):
+ # Semantic tag 37
+ encode_semantic(encoder, CBORTag(37, value.bytes))
+
+
+#
+# Special encoders (major tag 7)
+#
+
+def encode_simple_value(encoder, value):
+ if value.value < 20:
+ encoder.write(struct.pack('>B', 0xe0 | value.value))
+ else:
+ encoder.write(struct.pack('>BB', 0xf8, value.value))
+
+
+def encode_float(encoder, value):
+ # Handle special values efficiently
+ import math
+ if math.isnan(value):
+ encoder.write(b'\xf9\x7e\x00')
+ elif math.isinf(value):
+ encoder.write(b'\xf9\x7c\x00' if value > 0 else b'\xf9\xfc\x00')
+ else:
+ encoder.write(struct.pack('>Bd', 0xfb, value))
+
+
+def encode_boolean(encoder, value):
+ encoder.write(b'\xf5' if value else b'\xf4')
+
+
+def encode_none(encoder, value):
+ encoder.write(b'\xf6')
+
+
+def encode_undefined(encoder, value):
+ encoder.write(b'\xf7')
+
+
+default_encoders = OrderedDict([
+ (bytes, encode_bytestring),
+ (bytearray, encode_bytearray),
+ (unicode, encode_string),
+ (int, encode_int),
+ (long, encode_int),
+ (float, encode_float),
+ (('decimal', 'Decimal'), encode_decimal),
+ (bool, encode_boolean),
+ (type(None), encode_none),
+ (tuple, encode_array),
+ (list, encode_array),
+ (dict, encode_map),
+ (defaultdict, encode_map),
+ (OrderedDict, encode_map),
+ (type(undefined), encode_undefined),
+ (datetime, encode_datetime),
+ (date, encode_date),
+ (type(re.compile('')), encode_regexp),
+ (('fractions', 'Fraction'), encode_rational),
+ (('email.message', 'Message'), encode_mime),
+ (('uuid', 'UUID'), encode_uuid),
+ (CBORSimpleValue, encode_simple_value),
+ (CBORTag, encode_semantic)
+])
+
+
+class CBOREncoder(object):
+ """
+ Serializes objects to a byte stream using Concise Binary Object Representation.
+
+ :param datetime_as_timestamp: set to ``True`` to serialize datetimes as UNIX timestamps
+ (this makes datetimes more concise on the wire but loses the time zone information)
+ :param datetime.tzinfo timezone: the default timezone to use for serializing naive datetimes
+ :param value_sharing: if ``True``, allows more efficient serializing of repeated values and,
+ more importantly, cyclic data structures, at the cost of extra line overhead
+ :param default: a callable that is called by the encoder with three arguments
+ (encoder, value, file object) when no suitable encoder has been found, and should use the
+ methods on the encoder to encode any objects it wants to add to the data stream
+ """
+
+ __slots__ = ('fp', 'datetime_as_timestamp', 'timezone', 'default', 'value_sharing',
+ 'json_compatible', '_shared_containers', '_encoders')
+
+ def __init__(self, fp, datetime_as_timestamp=False, timezone=None, value_sharing=False,
+ default=None):
+ self.fp = fp
+ self.datetime_as_timestamp = datetime_as_timestamp
+ self.timezone = timezone
+ self.value_sharing = value_sharing
+ self.default = default
+ self._shared_containers = {} # indexes used for value sharing
+ self._encoders = default_encoders.copy()
+
+ def _find_encoder(self, obj_type):
+ from sys import modules
+
+ for type_, enc in list(iteritems(self._encoders)):
+ if type(type_) is tuple:
+ modname, typename = type_
+ imported_type = getattr(modules.get(modname), typename, None)
+ if imported_type is not None:
+ del self._encoders[type_]
+ self._encoders[imported_type] = enc
+ type_ = imported_type
+ else: # pragma: nocover
+ continue
+
+ if issubclass(obj_type, type_):
+ self._encoders[obj_type] = enc
+ return enc
+
+ return None
+
+ @contextmanager
+ def disable_value_sharing(self):
+ """Disable value sharing in the encoder for the duration of the context block."""
+ old_value_sharing = self.value_sharing
+ self.value_sharing = False
+ yield
+ self.value_sharing = old_value_sharing
+
+ def write(self, data):
+ """
+ Write bytes to the data stream.
+
+ :param data: the bytes to write
+
+ """
+ self.fp.write(data)
+
+ def encode(self, obj):
+ """
+ Encode the given object using CBOR.
+
+ :param obj: the object to encode
+
+ """
+ obj_type = obj.__class__
+ encoder = self._encoders.get(obj_type) or self._find_encoder(obj_type) or self.default
+ if not encoder:
+ raise CBOREncodeError('cannot serialize type %s' % obj_type.__name__)
+
+ encoder(self, obj)
+
+ def encode_to_bytes(self, obj):
+ """
+ Encode the given object to a byte buffer and return its value as bytes.
+
+ This method was intended to be used from the ``default`` hook when an object needs to be
+ encoded separately from the rest but while still taking advantage of the shared value
+ registry.
+
+ """
+ old_fp = self.fp
+ self.fp = fp = BytesIO()
+ self.encode(obj)
+ self.fp = old_fp
+ return fp.getvalue()
+
+
+def dumps(obj, **kwargs):
+ """
+ Serialize an object to a bytestring.
+
+ :param obj: the object to serialize
+ :param kwargs: keyword arguments passed to :class:`~.CBOREncoder`
+ :return: the serialized output
+ :rtype: bytes
+
+ """
+ fp = BytesIO()
+ dump(obj, fp, **kwargs)
+ return fp.getvalue()
+
+
+def dump(obj, fp, **kwargs):
+ """
+ Serialize an object to a file.
+
+ :param obj: the object to serialize
+ :param fp: a file-like object
+ :param kwargs: keyword arguments passed to :class:`~.CBOREncoder`
+
+ """
+ CBOREncoder(fp, **kwargs).encode(obj)