# -*- coding: utf-8 -*-
#
# This file is part of Python-ASN1. Python-ASN1 is free software that is
# made available under the MIT license. Consult the file "LICENSE" that is
# distributed together with this file for the exact licensing terms.
#
# Python-ASN1 is copyright (c) 2007-2016 by the Python-ASN1 authors. See the
# file "AUTHORS" for a complete overview.
"""
This module provides an ASN.1 encoder and decoder.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import collections
import re
from builtins import bytes
from builtins import int
from builtins import range
from builtins import str
from contextlib import contextmanager
from enum import IntEnum
from numbers import Number
__version__ = "2.7.0"


class Numbers(IntEnum):
    """ASN.1 universal class tag numbers (ITU-T X.680).

    The first group of members is the set the encoder and decoder know how
    to convert to and from Python values. The second group names the
    remaining universal tags so callers need not use magic numbers. The
    encoder and decoder give those no special treatment: values written with
    them must already be encoded ``bytes``, and decoded values are returned
    as raw ``bytes``.
    """

    Boolean = 0x01
    Integer = 0x02
    BitString = 0x03
    OctetString = 0x04
    Null = 0x05
    ObjectIdentifier = 0x06
    Enumerated = 0x0a
    UTF8String = 0x0c
    Sequence = 0x10
    Set = 0x11
    PrintableString = 0x13
    IA5String = 0x16
    UTCTime = 0x17
    GeneralizedTime = 0x18
    UnicodeString = 0x1e

    # Remaining universal tag numbers (no dedicated encode/decode support).
    ObjectDescriptor = 0x07
    External = 0x08
    Real = 0x09
    EmbeddedPDV = 0x0b
    RelativeOID = 0x0d
    NumericString = 0x12
    TeletexString = 0x14
    VideotexString = 0x15
    GraphicString = 0x19
    VisibleString = 0x1a
    GeneralString = 0x1b
    UniversalString = 0x1c
    CharacterString = 0x1d

    # Aliases using the standard X.680 names (IntEnum duplicates become aliases).
    T61String = 0x14
    BMPString = 0x1e
class Types(IntEnum):
    """Primitive/constructed flag of an ASN.1 identifier octet (mask 0x20)."""

    # Flag set: the content is a sequence of nested encodings.
    Constructed = 1 << 5
    # Flag clear: the content octets directly hold the value.
    Primitive = 0
class Classes(IntEnum):
    """Tag class of an ASN.1 identifier octet (the two high bits, mask 0xc0)."""

    Universal = 0 << 6
    Application = 1 << 6
    Context = 2 << 6
    Private = 3 << 6
#: Named tuple ``(nr, typ, cls)`` describing an ASN.1 tag, as returned by
#: `Decoder.peek()` and `Decoder.read()`.
Tag = collections.namedtuple('Tag', ['nr', 'typ', 'cls'])
class Error(Exception):
    """ASN.1 encoding or decoding error.

    Raised for malformed data and for invalid use of the API, e.g. calling
    encoder/decoder methods before ``start()``.
    """
class Encoder(object):
    """ASN.1 encoder. Uses DER encoding.

    Output is accumulated on a stack of byte-chunk lists: ``m_stack[0]``
    holds the top-level output, and every `Encoder.enter()` pushes a new list
    collecting the contents of the constructed type being built. The
    matching `Encoder.leave()` pops that list and emits it, prefixed by its
    length, into the enclosing level.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # None means start() has not been called yet.
        self.m_stack = None

    def start(self):  # type: () -> None
        """This method instructs the encoder to start encoding a new ASN.1
        output. This method may be called at any time to reset the encoder,
        and resets the current output (if any).
        """
        self.m_stack = [[]]

    def enter(self, nr, cls=None):  # type: (int, int) -> None
        """This method starts the construction of a constructed type.

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if cls is None:
            cls = Classes.Universal
        self._emit_tag(nr, Types.Constructed, cls)
        # Contents go to a fresh buffer so leave() can measure their length.
        self.m_stack.append([])

    def leave(self):  # type: () -> None
        """This method completes the construction of a constructed type and
        writes the encoded representation to the output buffer.

        Raises:
            `Error`: If the encoder is not started or no constructed type
            is open.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        value = b''.join(self.m_stack[-1])
        del self.m_stack[-1]
        # The tag was already emitted by enter(); now emit length + contents.
        self._emit_length(len(value))
        self._emit(value)

    @contextmanager
    def construct(self, nr, cls=None):  # type: (int, int) -> None
        """This method - context manager calls enter and leave methods,
        for better code mapping.

        Usage:
            ```
            with encoder.construct(asn1.Numbers.Sequence):
                encoder.write(1)
                with encoder.construct(asn1.Numbers.Sequence):
                    encoder.write('foo')
                    encoder.write('bar')
                encoder.write(2)
            ```

            encoder.output() will result following structure:

                SEQUENCE:
                    INTEGER: 1
                    SEQUENCE:
                        STRING: foo
                        STRING: bar
                    INTEGER: 2

        Args:
            nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration.

            cls (int): This optional parameter specifies the class
                of the constructed type. The default class to use is the
                universal class. Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`
        """
        self.enter(nr, cls)
        yield
        self.leave()

    def write(self, value, nr=None, typ=None, cls=None):  # type: (object, int, int, int) -> None
        """This method encodes one ASN.1 tag and writes it to the output buffer.

        Note:
            Normally, ``value`` will be the only parameter to this method.
            In this case Python-ASN1 will autodetect the correct ASN.1 type from
            the type of ``value``, and will output the encoded value based on this
            type.

        Args:
            value (any): The value of the ASN.1 tag to write. Python-ASN1 will
                try to autodetect the correct ASN.1 type from the type of
                ``value``.

            nr (int): If the desired ASN.1 type cannot be autodetected or is
                autodetected wrongly, the ``nr`` parameter can be provided to
                specify the ASN.1 type to be used. Use ``Numbers`` enumeration.

            typ (int): This optional parameter can be used to write constructed
                types to the output by setting it to indicate the constructed
                encoding type. In this case, ``value`` must already be valid ASN.1
                encoded data as plain Python bytes. This is not normally how
                constructed types should be encoded though, see `Encoder.enter()`
                and `Encoder.leave()` for the recommended way of doing this.
                Use ``Types`` enumeration.

            cls (int): This parameter can be used to override the class of the
                ``value``. The default class is the universal class.
                Use ``Classes`` enumeration.

        Returns:
            None

        Raises:
            `Error`: If the encoder is not started, if ``nr`` is missing for
            a non-universal class, or if the ASN.1 type of ``value`` cannot
            be autodetected.
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if typ is None:
            typ = Types.Primitive
        if cls is None:
            cls = Classes.Universal
        if cls != Classes.Universal and nr is None:
            raise Error('Please specify a tag number (nr) when using classes Application, Context or Private')
        if nr is None:
            # bool must be tested before int because bool subclasses int.
            if isinstance(value, bool):
                nr = Numbers.Boolean
            elif isinstance(value, int):
                nr = Numbers.Integer
            elif isinstance(value, str):
                nr = Numbers.PrintableString
            elif isinstance(value, bytes):
                nr = Numbers.OctetString
            elif value is None:
                nr = Numbers.Null
            else:
                # Previously fell through with nr=None and crashed with a
                # TypeError inside _emit_tag; report it as a proper Error.
                raise Error('Cannot determine Number for value type {}'.format(type(value)))
        value = self._encode_value(cls, nr, value)
        self._emit_tag(nr, typ, cls)
        self._emit_length(len(value))
        self._emit(value)

    def output(self):  # type: () -> bytes
        """This method returns the encoded ASN.1 data as plain Python ``bytes``.
        This method can be called multiple times, also during encoding.
        In the latter case the data that has been encoded so far is
        returned.

        Note:
            It is an error to call this method if the encoder is still
            constructing a constructed type, i.e. if `Encoder.enter()` has been
            called more times than `Encoder.leave()`.

        Returns:
            bytes: The DER encoded ASN.1 data.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('Encoder not initialized. Call start() first.')
        if len(self.m_stack) != 1:
            raise Error('Stack is not empty.')
        output = b''.join(self.m_stack[0])
        return output

    def _emit_tag(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a tag, choosing the short or long form by tag number."""
        if nr < 31:
            self._emit_tag_short(nr, typ, cls)
        else:
            self._emit_tag_long(nr, typ, cls)

    def _emit_tag_short(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a short-form tag (tag number < 31) in a single octet."""
        assert nr < 31
        self._emit(bytes([nr | typ | cls]))

    def _emit_tag_long(self, nr, typ, cls):  # type: (int, int, int) -> None
        """Emit a long-form tag (tag number >= 31).

        The first octet has all five tag-number bits set (0x1f); the number
        follows in base-128, most significant group first, with bit 8 set on
        every octet except the last.
        """
        head = bytes([typ | cls | 0x1f])
        self._emit(head)
        values = [(nr & 0x7f)]
        nr >>= 7
        while nr:
            values.append((nr & 0x7f) | 0x80)
            nr >>= 7
        values.reverse()
        for val in values:
            self._emit(bytes([val]))

    def _emit_length(self, length):  # type: (int) -> None
        """Emit length octets, choosing the short or long form."""
        if length < 128:
            self._emit_length_short(length)
        else:
            self._emit_length_long(length)

    def _emit_length_short(self, length):  # type: (int) -> None
        """Emit the short length form (< 128 octets)."""
        assert length < 128
        self._emit(bytes([length]))

    def _emit_length_long(self, length):  # type: (int) -> None
        """Emit the long length form (>= 128 octets): a count octet with
        bit 8 set, followed by the big-endian length bytes."""
        values = []
        while length:
            values.append(length & 0xff)
            length >>= 8
        values.reverse()
        # really for correctness as this should not happen anytime soon
        assert len(values) < 127
        head = bytes([0x80 | len(values)])
        self._emit(head)
        for val in values:
            self._emit(bytes([val]))

    def _emit(self, s):  # type: (bytes) -> None
        """Emit raw bytes into the innermost open buffer."""
        assert isinstance(s, bytes)
        self.m_stack[-1].append(s)

    def _encode_value(self, cls, nr, value):  # type: (int, int, any) -> bytes
        """Encode a value according to its universal tag number.

        Non-universal classes and unrecognized tag numbers pass ``value``
        through unchanged, so it must already be ``bytes`` in those cases.
        """
        if cls != Classes.Universal:
            return value
        if nr in (Numbers.Integer, Numbers.Enumerated):
            return self._encode_integer(value)
        if nr in (Numbers.OctetString, Numbers.PrintableString,
                  Numbers.UTF8String, Numbers.IA5String,
                  Numbers.UnicodeString, Numbers.UTCTime,
                  Numbers.GeneralizedTime):
            return self._encode_octet_string(value)
        if nr == Numbers.BitString:
            return self._encode_bit_string(value)
        if nr == Numbers.Boolean:
            return self._encode_boolean(value)
        if nr == Numbers.Null:
            return self._encode_null()
        if nr == Numbers.ObjectIdentifier:
            return self._encode_object_identifier(value)
        return value

    @staticmethod
    def _encode_boolean(value):  # type: (bool) -> bytes
        """Encode a boolean: 0xff for true (as DER requires), 0x00 for false."""
        return bytes(b'\xff') if value else bytes(b'\x00')

    @staticmethod
    def _encode_integer(value):  # type: (int) -> bytes
        """Encode an integer as minimal big-endian two's complement."""
        if value < 0:
            value = -value
            negative = True
            # -0x80 still fits one octet in two's complement, hence 0x80.
            limit = 0x80
        else:
            negative = False
            limit = 0x7f
        # Collect magnitude bytes, least significant first.
        values = []
        while value > limit:
            values.append(value & 0xff)
            value >>= 8
        values.append(value & 0xff)
        if negative:
            # create two's complement
            for i in range(len(values)):  # Invert bits
                values[i] = 0xff - values[i]
            for i in range(len(values)):  # Add 1
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i != len(values) - 1
                values[i] = 0x00
        if negative and values[len(values) - 1] == 0x7f:  # Two's complement corner case
            # Top octet would read as positive; prepend a sign octet.
            values.append(0xff)
        values.reverse()
        return bytes(values)

    @staticmethod
    def _encode_octet_string(value):  # type: (object) -> bytes
        """Encode an octet string (also used for the string/time types).

        ``str`` values are encoded as UTF-8; ``bytes`` pass through.

        Raises:
            `Error`: If ``value`` is neither ``str`` nor ``bytes``.
        """
        # Use the primitive encoding
        if isinstance(value, str):
            return value.encode('utf-8')
        if isinstance(value, bytes):
            return value
        # Explicit check instead of assert, which is stripped under -O.
        raise Error('Expecting str or bytes instance.')

    @staticmethod
    def _encode_bit_string(value):  # type: (object) -> bytes
        """Encode a bitstring. Assumes no unused bits in the last octet.

        Raises:
            `Error`: If ``value`` is not ``bytes``.
        """
        # Use the primitive encoding
        if not isinstance(value, bytes):
            raise Error('Expecting bytes instance.')
        # Leading octet = number of unused bits (always 0 here).
        return b'\x00' + value

    @staticmethod
    def _encode_null():  # type: () -> bytes
        """Encode a Null value (empty contents)."""
        return bytes(b'')

    _re_oid = re.compile(r'^[0-9]+(\.[0-9]+)+$')

    def _encode_object_identifier(self, oid):  # type: (str) -> bytes
        """Encode an object identifier given in dotted form, e.g. '1.2.840'.

        The first two arcs X.Y are packed into one subidentifier 40 * X + Y
        (ITU-T X.690, clause 8.19). X must be 0, 1 or 2; Y must be below 40
        when X is 0 or 1, but is unrestricted when X is 2 (e.g. '2.999').

        Raises:
            `Error`: If ``oid`` is not a legal object identifier.
        """
        if not self._re_oid.match(oid):
            raise Error('Illegal object identifier')
        cmps = list(map(int, oid.split('.')))
        if cmps[0] > 2 or (cmps[0] < 2 and cmps[1] > 39):
            raise Error('Illegal object identifier')
        cmps = [40 * cmps[0] + cmps[1]] + cmps[2:]
        # Build base-128 groups back to front, then reverse once at the end;
        # bit 8 marks every group except the last of each subidentifier.
        result = []
        for cmp_data in reversed(cmps):
            result.append(cmp_data & 0x7f)
            while cmp_data > 0x7f:
                cmp_data >>= 7
                result.append(0x80 | (cmp_data & 0x7f))
        result.reverse()
        return bytes(result)
class Decoder(object):
    """ASN.1 decoder. Understands BER (and DER which is a subset).

    Only the definite-length form is supported; the BER indefinite-length
    form is rejected with `Error`.

    Input is tracked on a stack of ``[offset, data]`` pairs: the bottom entry
    is the whole input, and every `Decoder.enter()` pushes the contents of
    the constructed value being entered.
    """

    def __init__(self):  # type: () -> None
        """Constructor."""
        # None means start() has not been called yet.
        self.m_stack = None
        # Tag cached by peek() so the following read()/enter() does not re-read it.
        self.m_tag = None

    def start(self, data):  # type: (bytes) -> None
        """This method instructs the decoder to start decoding the ASN.1 input
        ``data``, which must be a passed in as plain Python bytes.
        This method may be called at any time to start a new decoding job.
        If this method is called while currently decoding another input, that
        decoding context is discarded.

        Note:
            It is not necessary to specify the encoding because the decoder
            assumes the input is in BER or DER format.

        Args:
            data (bytes): ASN.1 input, in BER or DER format, to be decoded.

        Returns:
            None

        Raises:
            `Error`
        """
        if not isinstance(data, bytes):
            raise Error('Expecting bytes instance.')
        self.m_stack = [[0, bytes(data)]]
        self.m_tag = None

    def peek(self):  # type: () -> Tag
        """This method returns the current ASN.1 tag (i.e. the tag that a
        subsequent `Decoder.read()` call would return) without updating the
        decoding offset. In case no more data is available from the input,
        this method returns ``None`` to signal end-of-file.

        This method is useful if you don't know whether the next tag will be a
        primitive or a constructed tag. Depending on the return value of `peek`,
        you would decide to either issue a `Decoder.read()` in case of a primitive
        type, or an `Decoder.enter()` in case of a constructed type.

        Note:
            Because this method does not advance the current offset in the input,
            calling it multiple times in a row will return the same value for all
            calls.

        Returns:
            `Tag`: The current ASN.1 tag.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        if self.m_tag is None:
            self.m_tag = self._read_tag()
        return self.m_tag

    def read(self, tagnr=None):  # type: (Number) -> (Tag, any)
        """This method decodes one ASN.1 tag from the input and returns it as a
        ``(tag, value)`` tuple. ``tag`` is a 3-tuple ``(nr, typ, cls)``,
        while ``value`` is a Python object representing the ASN.1 value.
        The offset in the input is increased so that the next `Decoder.read()`
        call will return the next tag. In case no more data is available from
        the input, this method returns ``None`` to signal end-of-file.

        Args:
            tagnr (int): Optional tag number that selects how the value is
                decoded, instead of the number of the tag actually read.
                The returned `Tag` still carries the number from the input.

        Returns:
            `Tag`, value: The current ASN.1 tag and its value.

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if self._end_of_input():
            return None
        tag = self.peek()
        length = self._read_length()
        if tagnr is None:
            tagnr = tag.nr
        value = self._read_value(tag.cls, tagnr, length)
        self.m_tag = None
        return tag, value

    def eof(self):  # type: () -> bool
        """Return True if we are at the end of input.

        Returns:
            bool: True if all input has been decoded, and False otherwise.
        """
        return self._end_of_input()

    def enter(self):  # type: () -> None
        """This method enters the constructed type that is at the current
        decoding offset.

        Note:
            It is an error to call `Decoder.enter()` if the to be decoded ASN.1 tag
            is not of a constructed type.

        Returns:
            None

        Raises:
            `Error`: If no input is selected, the input is exhausted, or the
            current tag is not constructed.
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        tag = self.peek()
        if tag is None:
            # peek() signals end of input with None; previously this crashed
            # with AttributeError on ``tag.typ``.
            raise Error('Premature end of input.')
        if tag.typ != Types.Constructed:
            raise Error('Cannot enter a non-constructed tag.')
        length = self._read_length()
        bytes_data = self._read_bytes(length)
        self.m_stack.append([0, bytes_data])
        self.m_tag = None

    def leave(self):  # type: () -> None
        """This method leaves the last constructed type that was
        `Decoder.enter()`-ed.

        Note:
            It is an error to call `Decoder.leave()` without a matching,
            earlier call to `Decoder.enter()`.

        Returns:
            None

        Raises:
            `Error`
        """
        if self.m_stack is None:
            raise Error('No input selected. Call start() first.')
        if len(self.m_stack) == 1:
            raise Error('Tag stack is empty.')
        del self.m_stack[-1]
        self.m_tag = None

    def _read_tag(self):  # type: () -> Tag
        """Read a tag (identifier octets) from the input."""
        byte = self._read_byte()
        cls = byte & 0xc0
        typ = byte & 0x20
        nr = byte & 0x1f
        if nr == 0x1f:  # Long form of tag encoding
            # Base-128 groups follow; bit 8 set means "more groups follow".
            nr = 0
            while True:
                byte = self._read_byte()
                nr = (nr << 7) | (byte & 0x7f)
                if not byte & 0x80:
                    break
        return Tag(nr=nr, typ=typ, cls=cls)

    def _read_length(self):  # type: () -> int
        """Read a definite length from the input.

        Raises:
            `Error`: On the reserved form 0xff or the indefinite form 0x80.
        """
        byte = self._read_byte()
        if byte & 0x80:
            count = byte & 0x7f
            if count == 0x7f:
                raise Error('ASN1 syntax error')
            if count == 0:
                # 0x80 is the BER indefinite-length form (contents end with an
                # end-of-contents marker). It used to be silently read as
                # length 0, mis-parsing everything after it.
                raise Error('Indefinite length encoding is not supported')
            bytes_data = self._read_bytes(count)
            length = 0
            for byte in bytes_data:
                length = (length << 8) | int(byte)
            try:
                length = int(length)
            except OverflowError:
                pass
        else:
            length = byte
        return length

    def _read_value(self, cls, nr, length):  # type: (int, int, int) -> any
        """Read ``length`` content octets and decode them by tag number.

        Non-universal classes and unrecognized numbers yield raw ``bytes``.
        """
        bytes_data = self._read_bytes(length)
        if cls != Classes.Universal:
            value = bytes_data
        elif nr == Numbers.Boolean:
            value = self._decode_boolean(bytes_data)
        elif nr in (Numbers.Integer, Numbers.Enumerated):
            value = self._decode_integer(bytes_data)
        elif nr == Numbers.OctetString:
            value = self._decode_octet_string(bytes_data)
        elif nr == Numbers.Null:
            value = self._decode_null(bytes_data)
        elif nr == Numbers.ObjectIdentifier:
            value = self._decode_object_identifier(bytes_data)
        elif nr in (Numbers.PrintableString, Numbers.IA5String,
                    Numbers.UTF8String, Numbers.UTCTime,
                    Numbers.GeneralizedTime):
            value = self._decode_printable_string(bytes_data)
        elif nr == Numbers.BitString:
            value = self._decode_bitstring(bytes_data)
        else:
            value = bytes_data
        return value

    def _read_byte(self):  # type: () -> int
        """Return the next input byte, or raise an error on end-of-input."""
        index, input_data = self.m_stack[-1]
        try:
            byte = input_data[index]
        except IndexError:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += 1
        return byte

    def _read_bytes(self, count):  # type: (int) -> bytes
        """Return the next ``count`` bytes of input. Raise error on
        end-of-input."""
        index, input_data = self.m_stack[-1]
        bytes_data = input_data[index:index + count]
        if len(bytes_data) != count:
            raise Error('Premature end of input.')
        self.m_stack[-1][0] += count
        return bytes_data

    def _end_of_input(self):  # type: () -> bool
        """Return True if we are at the end of the innermost input."""
        index, input_data = self.m_stack[-1]
        assert not index > len(input_data)
        return index == len(input_data)

    @staticmethod
    def _decode_boolean(bytes_data):  # type: (bytes) -> bool
        """Decode a boolean value (any non-zero octet is true)."""
        if len(bytes_data) != 1:
            raise Error('ASN1 syntax error')
        if bytes_data[0] == 0:
            return False
        return True

    @staticmethod
    def _decode_integer(bytes_data):  # type: (bytes) -> int
        """Decode a big-endian two's complement integer.

        Raises:
            `Error`: On empty or non-minimally encoded content.
        """
        values = [int(b) for b in bytes_data]
        if not values:
            # Content must have at least one octet; used to raise IndexError.
            raise Error('ASN1 syntax error')
        # check if the integer is normalized
        if len(values) > 1 and (values[0] == 0xff and values[1] & 0x80 or values[0] == 0x00 and not (values[1] & 0x80)):
            raise Error('ASN1 syntax error')
        negative = values[0] & 0x80
        if negative:
            # make positive by taking two's complement
            for i in range(len(values)):
                values[i] = 0xff - values[i]
            for i in range(len(values) - 1, -1, -1):
                values[i] += 1
                if values[i] <= 0xff:
                    break
                assert i > 0
                values[i] = 0x00
        value = 0
        for val in values:
            value = (value << 8) | val
        if negative:
            value = -value
        try:
            value = int(value)
        except OverflowError:
            pass
        return value

    @staticmethod
    def _decode_octet_string(bytes_data):  # type: (bytes) -> bytes
        """Decode an octet string (returned unchanged)."""
        return bytes_data

    @staticmethod
    def _decode_null(bytes_data):  # type: (bytes) -> any
        """Decode a Null value (content must be empty)."""
        if len(bytes_data) != 0:
            raise Error('ASN1 syntax error')
        return None

    @staticmethod
    def _decode_object_identifier(bytes_data):  # type: (bytes) -> str
        """Decode an object identifier into dotted form, e.g. '1.2.840'.

        The first subidentifier packs the first two arcs as 40 * X + Y
        (ITU-T X.690, clause 8.19): values below 80 give X = value // 40,
        Y = value % 40; values of 80 and above always mean X = 2,
        Y = value - 80 (e.g. '2.999').

        Raises:
            `Error`: On empty, non-minimal or truncated content.
        """
        result = []
        value = 0
        for i in range(len(bytes_data)):
            byte = int(bytes_data[i])
            # A subidentifier may not start with a 0x80 padding octet.
            if value == 0 and byte == 0x80:
                raise Error('ASN1 syntax error')
            value = (value << 7) | (byte & 0x7f)
            if not byte & 0x80:
                result.append(value)
                value = 0
        # A continuation bit on the final octet means the last
        # subidentifier is truncated; it used to be silently dropped.
        if len(result) == 0 or int(bytes_data[-1]) & 0x80:
            raise Error('ASN1 syntax error')
        first = result[0]
        if first < 80:
            arcs = [first // 40, first % 40]
        else:
            arcs = [2, first - 80]
        result = arcs + result[1:]
        result = list(map(str, result))
        return str('.'.join(result))

    @staticmethod
    def _decode_printable_string(bytes_data):  # type: (bytes) -> str
        """Decode a character/time string as UTF-8."""
        return bytes_data.decode('utf-8')

    @staticmethod
    def _decode_bitstring(bytes_data):  # type: (bytes) -> bytes
        """Decode a bitstring.

        The first content octet holds the number of unused (padding) bits in
        the last octet. The remaining octets are returned with that padding
        shifted out: the whole bit string is shifted right by that many bits,
        so the last significant bit lands in bit 0 of the last byte.

        Raises:
            `Error`: On empty content, an unused-bit count above 7, or a
            non-zero unused-bit count with no content octets.
        """
        if len(bytes_data) == 0:
            raise Error('ASN1 syntax error')
        num_unused_bits = bytes_data[0]
        if not (0 <= num_unused_bits <= 7):
            raise Error('ASN1 syntax error')
        if num_unused_bits == 0:
            return bytes_data[1:]
        if len(bytes_data) == 1:
            # An empty bit string must declare zero unused bits (X.690 8.6.2).
            raise Error('ASN1 syntax error')
        # Shift off unused bits
        remaining = bytearray(bytes_data[1:])
        bitmask = (1 << num_unused_bits) - 1
        # Bits shifted out of one byte enter the top of the next one, i.e.
        # at position 8 - n (the old code used n, which corrupted the output
        # and overflowed the byte range for large n).
        carry_shift = 8 - num_unused_bits
        removed_bits = 0
        for i in range(len(remaining)):
            byte = int(remaining[i])
            remaining[i] = (byte >> num_unused_bits) | (removed_bits << carry_shift)
            removed_bits = byte & bitmask
        return bytes(remaining)