Source code for asn1

# -*- coding: utf-8 -*-
#
# This file is part of Python-ASN1. Python-ASN1 is free software that is
# made available under the MIT license. Consult the file "LICENSE" that is
# distributed together with this file for the exact licensing terms.
#
# Python-ASN1 is copyright (c) 2007-2016 by the Python-ASN1 authors. See the
# file "AUTHORS" for a complete overview.

"""
This module provides ASN.1 encoder and decoder.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import re
from builtins import bytes
from builtins import int
from builtins import range
from builtins import str
from contextlib import contextmanager
from enum import IntEnum
from numbers import Number

__version__ = "2.7.0"


class Numbers(IntEnum):
    """Tag numbers of the ASN.1 types known to this module.

    Members of this enumeration are what the ``nr`` arguments of the
    encoder expect and what the ``nr`` field of a decoded tag holds.
    """

    Boolean = 0x01
    Integer = 0x02
    BitString = 0x03
    OctetString = 0x04
    Null = 0x05
    ObjectIdentifier = 0x06
    Enumerated = 0x0A
    UTF8String = 0x0C
    Sequence = 0x10
    Set = 0x11
    PrintableString = 0x13
    IA5String = 0x16
    UTCTime = 0x17
    GeneralizedTime = 0x18
    UnicodeString = 0x1E
class Types(IntEnum):
    """Encoding type bit of a tag: constructed or primitive.

    The value is OR-ed into the leading tag octet by the encoder and masked
    out of it (``0x20``) by the decoder.
    """

    Constructed = 0x20
    Primitive = 0x00
class Classes(IntEnum):
    """Tag class bits of a tag.

    The value is OR-ed into the leading tag octet by the encoder and masked
    out of it (``0xc0``) by the decoder.
    """

    Universal = 0x00
    Application = 0x40
    Context = 0x80
    Private = 0xC0
Tag = collections.namedtuple('Tag', ['nr', 'typ', 'cls'])
"""A named tuple to represent ASN.1 tags as returned by `Decoder.peek()` and
`Decoder.read()`."""
class Error(Exception):
    """ASN.1 encoding or decoding error.

    Raised by the encoder and the decoder in this module when they are
    misused or when the data cannot be processed.
    """
class Encoder(object):
    """ASN.1 encoder. Uses DER encoding.

    The encoder keeps a stack of output buffers (``m_stack``). Each buffer
    is a list of ``bytes`` chunks. The bottom entry collects the top-level
    output, and every `enter()` pushes a fresh buffer that `leave()` later
    pops, joins and emits (preceded by its length) into the buffer below.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # ``None`` means "start() has not been called yet".
        self.m_stack = None

    def start(self):  # type: () -> None
        """This method instructs the encoder to start encoding a new ASN.1
        output. This method may be called at any time to reset the encoder,
        and resets the current output (if any).
        """
        self.m_stack = [[]]

    def enter(self, nr, cls=None):  # type: (int, int) -> None
        """This method starts the construction of a constructed type.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if cls is None:
            cls = Classes.Universal
        # The tag goes into the current buffer; the content is collected in
        # a new buffer because its length is only known in leave().
        self._emit_tag(nr, Types.Constructed, cls)
        self.m_stack.append([])

    def leave(self):  # type: () -> None
        """This method completes the construction of a constructed type and
        writes the encoded representation to the output buffer.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        value = b''.join(self.m_stack[-1])
        del self.m_stack[-1]
        self._emit_length(len(value))
        self._emit(value)

    @contextmanager
    def construct(self, nr, cls=None):  # type: (int, int) -> None
        """This method - context manager calls enter and leave methods,
        for better code mapping.

        Usage:
        ```
        with encoder.construct(asn1.Numbers.Sequence):
            encoder.write(1)
            with encoder.construct(asn1.Numbers.Sequence):
                encoder.write('foo')
                encoder.write('bar')
            encoder.write(2)
        ```
        encoder.output() will result following structure:
        SEQUENCE:
            INTEGER: 1
            SEQUENCE:
                STRING: foo
                STRING: bar
            INTEGER: 2

        Note:
            `leave()` is only called when the ``with`` body finishes
            normally; if the body raises, the constructed type stays open.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`
        """
        self.enter(nr, cls)
        yield
        self.leave()

    def write(self, value, nr=None, typ=None, cls=None):  # type: (object, int, int, int) -> None
        """This method encodes one ASN.1 tag and writes it to the output
        buffer.

        Note:
            Normally, ``value`` will be the only parameter to this method.
            In this case Python-ASN1 will autodetect the correct ASN.1 type
            from the type of ``value``, and will output the encoded value
            based on this type.

        Args:
            value (any): The value of the ASN.1 tag to write. Python-ASN1
                will try to autodetect the correct ASN.1 type from the type
                of ``value``.

            nr (int): If the desired ASN.1 type cannot be autodetected or is
                autodetected wrongly, the ``nr`` parameter can be provided to
                specify the ASN.1 type to be used. Use ``Numbers``
                enumeration.

            typ (int): This optional parameter can be used to write
                constructed types to the output by setting it to indicate
                the constructed encoding type. In this case, ``value`` must
                already be valid ASN.1 encoded data as plain Python bytes.
                This is not normally how constructed types should be encoded
                though, see `Encoder.enter()` and `Encoder.leave()` for the
                recommended way of doing this. Use ``Types`` enumeration.

            cls (int): This parameter can be used to override the class of
                the ``value``. The default class is the universal class.
                Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: If the encoder has not been started, if ``nr`` is
                missing for a non-universal class, or if ``nr`` is missing
                and the type of ``value`` cannot be autodetected.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if typ is None:
            typ = Types.Primitive
        if cls is None:
            cls = Classes.Universal
        if cls != Classes.Universal and nr is None:
            raise Error('Please specify a tag number (nr) when using classes Application, Context or Private')
        if nr is None:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(value, bool):
                nr = Numbers.Boolean
            elif isinstance(value, int):
                nr = Numbers.Integer
            elif isinstance(value, str):
                nr = Numbers.PrintableString
            elif isinstance(value, bytes):
                nr = Numbers.OctetString
            elif value is None:
                nr = Numbers.Null
            else:
                # Without this branch ``nr`` stayed None and the failure
                # surfaced later as an unrelated TypeError in _emit_tag().
                raise Error('Cannot determine the ASN.1 type for a value of type %s. '
                            'Please specify a tag number (nr).' % type(value).__name__)
        value = self._encode_value(cls, nr, value)
        self._emit_tag(nr, typ, cls)
        self._emit_length(len(value))
        self._emit(value)

    def output(self):  # type: () -> bytes
        """This method returns the encoded ASN.1 data as plain Python
        ``bytes``. This method can be called multiple times, also during
        encoding. In the latter case the data that has been encoded so far
        is returned.

        Note:
            It is an error to call this method if the encoder is still
            constructing a constructed type, i.e. if `Encoder.enter()` has
            been called more times than `Encoder.leave()`.

        Returns:
            bytes: The DER encoded ASN.1 data.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) != 1:
            raise Error('Stack is not empty.')
        output = b''.join(self.m_stack[0])
        return output

    def _emit_tag(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag, choosing the short or long form from the tag number."""
        if nr < 31:
            self._emit_tag_short(nr, typ, cls)
        else:
            self._emit_tag_long(nr, typ, cls)

    def _emit_tag_short(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a short tag (tag number < 31) as a single octet."""
        assert nr < 31
        self._emit(bytes([nr | typ | cls]))

    def _emit_tag_long(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a long tag (tag number >= 31).

        The leading octet has all five number bits set (0x1f); the tag
        number follows in base-128 groups, most significant group first,
        with bit 0x80 set on every group except the last one.
        """
        head = bytes([typ | cls | 0x1f])
        self._emit(head)
        # Built least significant group first, then reversed.
        values = [(nr & 0x7f)]
        nr >>= 7
        while nr:
            values.append((nr & 0x7f) | 0x80)
            nr >>= 7
        values.reverse()
        for val in values:
            self._emit(bytes([val]))

    def _emit_length(self, length):  # type: (int) -> None
        """Emit length octets, choosing the short or long form."""
        if length < 128:
            self._emit_length_short(length)
        else:
            self._emit_length_long(length)

    def _emit_length_short(self, length):  # type: (int) -> None
        """Emit the short length form (< 128 octets)."""
        assert length < 128
        self._emit(bytes([length]))

    def _emit_length_long(self, length):  # type: (int) -> None
        """Emit the long length form (>= 128 octets).

        The first octet is ``0x80 | count``; it is followed by ``count``
        octets holding the length, most significant octet first.
        """
        values = []
        while length:
            values.append(length & 0xff)
            length >>= 8
        values.reverse()
        # really for correctness as this should not happen anytime soon
        assert len(values) < 127
        head = bytes([0x80 | len(values)])
        self._emit(head)
        for val in values:
            self._emit(bytes([val]))

    def _emit(self, s):  # type: (bytes) -> None
        """Emit raw bytes by appending them to the innermost buffer."""
        assert isinstance(s, bytes)
        self.m_stack[-1].append(s)

    def _encode_value(self, cls, nr, value):  # type: (int, int, any) -> bytes
        """Encode a value according to its tag number.

        Values of non-universal classes, and universal values whose tag
        number has no dedicated encoder here, are returned unchanged.
        """
        if cls != Classes.Universal:
            return value
        if nr in (Numbers.Integer, Numbers.Enumerated):
            return self._encode_integer(value)
        if nr in (Numbers.OctetString, Numbers.PrintableString,
                  Numbers.UTF8String, Numbers.IA5String,
                  Numbers.UnicodeString, Numbers.UTCTime,
                  Numbers.GeneralizedTime):
            return self._encode_octet_string(value)
        if nr == Numbers.BitString:
            return self._encode_bit_string(value)
        if nr == Numbers.Boolean:
            return self._encode_boolean(value)
        if nr == Numbers.Null:
            return self._encode_null()
        if nr == Numbers.ObjectIdentifier:
            return self._encode_object_identifier(value)
        return value

    @staticmethod
    def _encode_boolean(value):  # type: (bool) -> bytes
        """Encode a boolean: 0xff for a true value, 0x00 otherwise."""
        return bytes(b'\xff') if value else bytes(b'\x00')

    @staticmethod
    def _encode_integer(value):  # type: (int) -> bytes
        """Encode an integer as a minimal-length two's complement value,
        most significant octet first.
        """
        if value < 0:
            value = -value
            negative = True
            limit = 0x80
        else:
            negative = False
            limit = 0x7f
        # Collect the magnitude, least significant octet first. The loop
        # stops once the remainder fits below the sign bit.
        values = []
        while value > limit:
            values.append(value & 0xff)
            value >>= 8
        values.append(value & 0xff)
        if negative:
            # create two's complement
            for i in range(len(values)):  # Invert bits
                values[i] = 0xff - values[i]
            for i in range(len(values)):  # Add 1
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i != len(values) - 1
                values[i] = 0x00
        if negative and values[len(values) - 1] == 0x7f:  # Two's complement corner case
            # The top octet lost its sign bit; add a 0xff sign octet.
            values.append(0xff)
        values.reverse()
        return bytes(values)

    @staticmethod
    def _encode_octet_string(value):  # type: (object) -> bytes
        """Encode an octetstring. Text is encoded as UTF-8, bytes are
        passed through unchanged.
        """
        # Use the primitive encoding
        assert isinstance(value, str) or isinstance(value, bytes)
        if isinstance(value, str):
            return value.encode('utf-8')
        else:
            return value

    @staticmethod
    def _encode_bit_string(value):  # type: (object) -> bytes
        """Encode a bitstring. Assumes no unused bits: the leading
        "number of unused bits" octet is always written as zero.
        """
        # Use the primitive encoding
        assert isinstance(value, bytes)
        return b'\x00' + value

    @staticmethod
    def _encode_null():  # type: () -> bytes
        """Encode a Null value (empty content)."""
        return bytes(b'')

    _re_oid = re.compile(r'^[0-9]+(\.[0-9]+)+$')

    def _encode_object_identifier(self, oid):  # type: (str) -> bytes
        """Encode an object identifier given in dotted-decimal notation.

        The first two arcs are combined into one sub-identifier
        (``40 * first + second``); every sub-identifier is then written in
        base 128, most significant group first, with bit 0x80 set on every
        octet except the last one of the sub-identifier.

        Raises:
            `Error`: If ``oid`` is not dotted-decimal with at least two
                arcs, if the first arc is greater than 2, or if the first
                arc is 0 or 1 and the second arc is greater than 39.
        """
        if not self._re_oid.match(oid):
            raise Error('Illegal object identifier')
        cmps = list(map(int, oid.split('.')))
        # Only arcs 0, 1 and 2 exist at the root. The second arc is limited
        # to 0..39 under roots 0 and 1 but is unbounded under root 2.
        if cmps[0] > 2 or (cmps[0] < 2 and cmps[1] > 39):
            raise Error('Illegal object identifier')
        cmps = [40 * cmps[0] + cmps[1]] + cmps[2:]
        result = []
        for cmp_data in cmps:
            # Groups of one sub-identifier, least significant first.
            groups = [cmp_data & 0x7f]
            cmp_data >>= 7
            while cmp_data:
                groups.append(0x80 | (cmp_data & 0x7f))
                cmp_data >>= 7
            result.extend(reversed(groups))
        return bytes(result)
class Decoder(object):
    """ASN.1 decoder. Understands BER (and DER which is a subset).

    The decoder keeps a stack (``m_stack``) of ``[offset, data]`` pairs.
    The bottom entry is the input given to `start()`; every `enter()`
    pushes the content of a constructed type and `leave()` pops it again.
    ``m_tag`` caches the tag read by `peek()` until it is consumed.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # ``None`` means "start() has not been called yet".
        self.m_stack = None
        self.m_tag = None

    def start(self, data):  # type: (bytes) -> None
        """This method instructs the decoder to start decoding the ASN.1
        input ``data``, which must be a passed in as plain Python bytes.
        This method may be called at any time to start a new decoding job.
        If this method is called while currently decoding another input,
        that decoding context is discarded.

        Note:
            It is not necessary to specify the encoding because the decoder
            assumes the input is in BER or DER format.

        Args:
            data (bytes): ASN.1 input, in BER or DER format, to be decoded.

        Returns:
            None

        Raises:
            `Error`
        """
        if not isinstance(data, bytes):
            raise Error('Expecting bytes instance.')
        self.m_stack = [[0, bytes(data)]]
        self.m_tag = None

    def peek(self):  # type: () -> Tag
        """This method returns the current ASN.1 tag (i.e. the tag that a
        subsequent `Decoder.read()` call would return) without updating the
        decoding offset. In case no more data is available from the input,
        this method returns ``None`` to signal end-of-file.

        This method is useful if you don't know whether the next tag will be
        a primitive or a constructed tag. Depending on the return value of
        `peek`, you would decide to either issue a `Decoder.read()` in case
        of a primitive type, or an `Decoder.enter()` in case of a
        constructed type.

        Note:
            Because this method does not advance the current offset in the
            input, calling it multiple times in a row will return the same
            value for all calls.

        Returns:
            `Tag`: The current ASN.1 tag.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        # Reading the tag consumes its octets, so it is cached until a
        # read()/enter()/leave() clears it.
        if self.m_tag is None:
            self.m_tag = self._read_tag()
        return self.m_tag

    def read(self, tagnr=None):  # type: (Number) -> (Tag, any)
        """This method decodes one ASN.1 tag from the input and returns it
        as a ``(tag, value)`` tuple. ``tag`` is a 3-tuple
        ``(nr, typ, cls)``, while ``value`` is a Python object representing
        the ASN.1 value. The offset in the input is increased so that the
        next `Decoder.read()` call will return the next tag. In case no more
        data is available from the input, this method returns ``None`` to
        signal end-of-file.

        Args:
            tagnr (int): Optional tag number used to select how the value
                is decoded instead of the number found in the tag itself.
                The returned ``tag`` is not affected by it.

        Returns:
            `Tag`, value: The current ASN.1 tag and its value.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        tag = self.peek()
        length = self._read_length()
        if tagnr is None:
            tagnr = tag.nr
        value = self._read_value(tag.cls, tagnr, length)
        self.m_tag = None
        return tag, value

    def eof(self):  # type: () -> bool
        """Return True if we are at the end of input.

        Returns:
            bool: True if all input has been decoded, and False otherwise.

        Raises:
            `Error`: If `start()` has not been called.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        return self._end_of_input()

    def enter(self):  # type: () -> None
        """This method enters the constructed type that is at the current
        decoding offset.

        Note:
            It is an error to call `Decoder.enter()` if the to be decoded
            ASN.1 tag is not of a constructed type.

        Returns:
            None

        Raises:
            `Error`: If no input is selected, if the input is exhausted, or
                if the current tag is not of a constructed type.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        tag = self.peek()
        if tag is None:
            # peek() signals end of input with None; report it as a decoder
            # error instead of failing on ``None.typ``.
            raise Error('Premature end of input.')
        if tag.typ != Types.Constructed:
            raise Error('Cannot enter a non-constructed tag.')
        length = self._read_length()
        bytes_data = self._read_bytes(length)
        self.m_stack.append([0, bytes_data])
        self.m_tag = None

    def leave(self):  # type: () -> None
        """This method leaves the last constructed type that was
        `Decoder.enter()`-ed.

        Note:
            It is an error to call `Decoder.leave()` if the current ASN.1
            tag is not of a constructed type.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        del self.m_stack[-1]
        self.m_tag = None

    def _read_tag(self):  # type: () -> Tag
        """Read a tag from the input."""
        byte = self._read_byte()
        cls = byte & 0xc0
        typ = byte & 0x20
        nr = byte & 0x1f
        if nr == 0x1f:  # Long form of tag encoding
            # The number follows in base-128 groups; bit 0x80 marks
            # "more groups follow".
            nr = 0
            while True:
                byte = self._read_byte()
                nr = (nr << 7) | (byte & 0x7f)
                if not byte & 0x80:
                    break
        return Tag(nr=nr, typ=typ, cls=cls)

    def _read_length(self):  # type: () -> int
        """Read a length from the input.

        Note:
            A leading octet of exactly 0x80 (a long form with zero length
            octets) yields a length of 0.
        """
        byte = self._read_byte()
        if byte & 0x80:
            count = byte & 0x7f
            if count == 0x7f:
                raise Error('ASN1 syntax error')
            bytes_data = self._read_bytes(count)
            length = 0
            for byte in bytes_data:
                length = (length << 8) | int(byte)
            try:
                length = int(length)
            except OverflowError:
                pass
        else:
            length = byte
        return length

    def _read_value(self, cls, nr, length):  # type: (int, int, int) -> any
        """Read ``length`` octets from the input and decode them according
        to ``nr``. Non-universal classes and tag numbers without a
        dedicated decoder are returned as raw bytes.
        """
        bytes_data = self._read_bytes(length)
        if cls != Classes.Universal:
            value = bytes_data
        elif nr == Numbers.Boolean:
            value = self._decode_boolean(bytes_data)
        elif nr in (Numbers.Integer, Numbers.Enumerated):
            value = self._decode_integer(bytes_data)
        elif nr == Numbers.OctetString:
            value = self._decode_octet_string(bytes_data)
        elif nr == Numbers.Null:
            value = self._decode_null(bytes_data)
        elif nr == Numbers.ObjectIdentifier:
            value = self._decode_object_identifier(bytes_data)
        elif nr in (Numbers.PrintableString, Numbers.IA5String,
                    Numbers.UTF8String, Numbers.UTCTime,
                    Numbers.GeneralizedTime):
            value = self._decode_printable_string(bytes_data)
        elif nr == Numbers.BitString:
            value = self._decode_bitstring(bytes_data)
        else:
            value = bytes_data
        return value

    def _read_byte(self):  # type: () -> int
        """Return the next input byte, or raise an error on end-of-input."""
        index, input_data = self.m_stack[-1]
        try:
            byte = input_data[index]
        except IndexError:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += 1
        return byte

    def _read_bytes(self, count):  # type: (int) -> bytes
        """Return the next ``count`` bytes of input. Raise error on
        end-of-input.
        """
        index, input_data = self.m_stack[-1]
        bytes_data = input_data[index:index + count]
        # Slicing never raises; a short slice means the input ran out.
        if len(bytes_data) != count:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += count
        return bytes_data

    def _end_of_input(self):  # type: () -> bool
        """Return True if we are at the end of input."""
        index, input_data = self.m_stack[-1]
        assert not index > len(input_data)
        return index == len(input_data)

    @staticmethod
    def _decode_boolean(bytes_data):  # type: (bytes) -> bool
        """Decode a boolean value: a single octet, zero meaning False."""
        if len(bytes_data) != 1:
            raise Error('ASN1 syntax error')
        if bytes_data[0] == 0:
            return False
        return True

    @staticmethod
    def _decode_integer(bytes_data):  # type: (bytes) -> int
        """Decode an integer value from its two's complement encoding.

        Raises:
            `Error`: If the content is empty or is not minimally encoded.
        """
        values = [int(b) for b in bytes_data]
        if not values:
            # An integer needs at least one content octet; without this
            # check ``values[0]`` below raised IndexError.
            raise Error('ASN1 syntax error')
        # check if the integer is normalized
        if len(values) > 1 and (values[0] == 0xff and values[1] & 0x80 or
                                values[0] == 0x00 and not (values[1] & 0x80)):
            raise Error('ASN1 syntax error')
        negative = values[0] & 0x80
        if negative:
            # make positive by taking two's complement
            for i in range(len(values)):
                values[i] = 0xff - values[i]
            for i in range(len(values) - 1, -1, -1):
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i > 0
                values[i] = 0x00
        value = 0
        for val in values:
            value = (value << 8) | val
        if negative:
            value = -value
        try:
            value = int(value)
        except OverflowError:
            pass
        return value

    @staticmethod
    def _decode_octet_string(bytes_data):  # type: (bytes) -> bytes
        """Decode an octet string (returned unchanged)."""
        return bytes_data

    @staticmethod
    def _decode_null(bytes_data):  # type: (bytes) -> any
        """Decode a Null value; the content must be empty."""
        if len(bytes_data) != 0:
            raise Error('ASN1 syntax error')
        return None

    @staticmethod
    def _decode_object_identifier(bytes_data):  # type: (bytes) -> str
        """Decode an object identifier into dotted-decimal notation.

        Each sub-identifier is stored in base 128 with bit 0x80 marking
        "more octets follow". The first sub-identifier combines the first
        two arcs as ``40 * first + second``, where ``first`` is 0, 1 or 2
        and ``second`` is only bounded (0..39) when ``first`` is 0 or 1.

        Raises:
            `Error`: If no complete sub-identifier is present or a
                sub-identifier starts with a 0x80 padding octet.
        """
        result = []
        value = 0
        for i in range(len(bytes_data)):
            byte = int(bytes_data[i])
            # A sub-identifier must not start with an empty (0x80) group.
            if value == 0 and byte == 0x80:
                raise Error('ASN1 syntax error')
            value = (value << 7) | (byte & 0x7f)
            if not byte & 0x80:
                result.append(value)
                value = 0
        if len(result) == 0:
            raise Error('ASN1 syntax error')
        # Split the first sub-identifier. Values of 80 and above all belong
        # to root arc 2, so plain ``// 40`` and ``% 40`` would be wrong.
        first = result[0]
        if first < 40:
            head = [0, first]
        elif first < 80:
            head = [1, first - 40]
        else:
            head = [2, first - 80]
        result = head + result[1:]
        result = list(map(str, result))
        return str('.'.join(result))

    @staticmethod
    def _decode_printable_string(bytes_data):  # type: (bytes) -> str
        """Decode a printable string by decoding the content as UTF-8."""
        return bytes_data.decode('utf-8')

    @staticmethod
    def _decode_bitstring(bytes_data):  # type: (bytes) -> bytes
        """Decode a bitstring.

        The first content octet gives the number of unused bits (0..7) at
        the end of the last octet. The remaining octets are returned
        shifted right by that many bits, so the unused bits are dropped and
        the result keeps the same number of octets.

        Raises:
            `Error`: If the content is empty or the unused-bit count is
                outside 0..7.
        """
        if len(bytes_data) == 0:
            raise Error('ASN1 syntax error')
        num_unused_bits = bytes_data[0]
        if not (0 <= num_unused_bits <= 7):
            raise Error('ASN1 syntax error')
        if num_unused_bits == 0:
            return bytes_data[1:]
        # Shift off unused bits
        remaining = bytearray(bytes_data[1:])
        bitmask = (1 << num_unused_bits) - 1
        # The low bits shifted out of one octet become the high bits of the
        # next one, so they are re-inserted at position 8 - num_unused_bits.
        carry_shift = 8 - num_unused_bits
        removed_bits = 0
        for i in range(len(remaining)):
            byte = int(remaining[i])
            remaining[i] = (byte >> num_unused_bits) | (removed_bits << carry_shift)
            removed_bits = byte & bitmask
        return bytes(remaining)