import re
import bits
# Token separator: runs of spaces, tabs, newlines *and dots*. Because '.'
# is a separator, a dot never survives inside a token produced by
# white.split().
white = re.compile(r"[ \t\.\n]+")
# Bare, optionally negative, lowercase hexadecimal literal without a 0x
# prefix, e.g. '12', '-12', 'cafe'. Words made only of a-f (e.g. 'add')
# match as well. Note: this name shadows the builtin hex() in this module.
hex = re.compile(r"^\-?[0-9a-f]+$")
# Plain unsigned decimal digits.
# NOTE(review): not referenced in the code visible here; presumably used
# elsewhere -- confirm before removing.
num = re.compile(r"^\d+$")
# Label reference with an optional signed decimal offset, e.g. 'lbl+2'.
# Group 1 is the label (may not contain '[', '+' or '-'); group 2 is the
# offset string such as '+2' / '-4', or None. Sliced references such as
# 'lbl[11:0]' therefore do NOT match.
ref_re = re.compile(r"^([^\[+-]+)(?:([+-]\d+))?$")
# Immediate-field tag: mode ('imm', 'off' or 'csr'), a decimal size
# (apparently a bit width, cf. 'label[11:0]/imm12' -- confirm) and an
# optional 'hi'/'lo' split suffix, e.g. 'imm12', 'off20hi'.
immediate_re = re.compile(r"^(imm|off|csr)(\d+)(hi|lo)?$")
# parsing
def parse_part(part):
    """Turn one slash-separated token into a tuple ``(value, *tags)``.

    The first field is the value. When it looks like a (possibly negative)
    lowercase hexadecimal literal it becomes an int. Otherwise it stays a
    string, e.g. a label or label reference. Note that a word made only of
    hex digits (such as 'cafe') is therefore read as a number. All
    remaining fields are returned unchanged as tag strings.

    >>> parse_part('0')
    (0,)
    >>> parse_part('00')
    (0,)
    >>> parse_part('12')
    (18,)
    >>> parse_part('-12')
    (-18,)
    >>> parse_part('00/with/tag')
    (0, 'with', 'tag')
    >>> parse_part('-12/and/tag')
    (-18, 'and', 'tag')
    >>> parse_part('cafe/and/tag')
    (51966, 'and', 'tag')
    >>> parse_part('label/tag*')
    ('label', 'tag*')
    >>> parse_part('$label/tag*')
    ('$label', 'tag*')
    >>> parse_part('label:suff/tag*')
    ('label:suff', 'tag*')
    >>> parse_part('$label:suff/tag*')
    ('$label:suff', 'tag*')
    >>> parse_part('label[3:1]/tag*')
    ('label[3:1]', 'tag*')
    >>> parse_part('$label:suff[11:0]/tag*')
    ('$label:suff[11:0]', 'tag*')
    """
    value, *tags = part.split("/")
    if hex.match(value):
        value = int(value, 16)
    return (value, *tags)
def parse_instr(line):
    """Split an instruction-line into a list of parsed parts.

    Tokens are separated by the ``white`` pattern. Empty tokens (from
    leading/trailing separators) are dropped, and every remaining token
    goes through parse_part.

    >>> parse_instr('ff/op 0/subop/add 1/rd/x1 label[11:0]/imm12')
    [(255, 'op'), (0, 'subop', 'add'), (1, 'rd', 'x1'), ('label[11:0]', 'imm12')]
    """
    tokens = (token for token in white.split(line) if token != "")
    return list(map(parse_part, tokens))
def parse_segment(line):
    """Parse a segment-line of the form ``== NAME [ADDR]``.

    ADDR is read as hexadecimal via ``int(..., 16)``, which also accepts a
    ``0x`` prefix. As in parse_instr, empty tokens produced by leading or
    trailing separators are ignored. So ``'== text '`` parses like
    ``'== text'`` instead of failing on an empty address field.

    Returns ``(name, address)`` or ``(name,)``. Raises ValueError when the
    line does not have exactly two or three fields, or when ADDR is not
    valid hexadecimal.

    >>> parse_segment('== code 0x8000')
    ('code', 32768)
    >>> parse_segment('== text')
    ('text',)
    >>> parse_segment('== text ')
    ('text',)
    """
    # Drop empty strings produced by separators at either end of the line;
    # otherwise they would be counted as (bogus) fields.
    parts = [part for part in white.split(line) if part != ""]
    if len(parts) == 3:
        return (parts[1], int(parts[2], 16))
    elif len(parts) == 2:
        return (parts[1],)
    raise ValueError("invalid segment line")
def parse_special(line):
    """Parse a special-line into ``[parsed_head, rest...]``.

    Only the first separator run splits the line. The head token goes
    through parse_part, and the remainder (if any) is kept as one untouched
    string.

    >>> parse_special('@@/test "hello/123 asd"')
    [('@@', 'test'), '"hello/123 asd"']
    >>> parse_special('@@/test @@/123 @@/uuu')
    [('@@', 'test'), '@@/123 @@/uuu']
    """
    head, *remainder = white.split(line, maxsplit=1)
    return [parse_part(head), *remainder]
def parse_label(line):
    """Return the label name of a label-line by dropping its final character
    (the trailing ':').

    >>> parse_label('some_label:')
    'some_label'
    >>> parse_label('$some:label:')
    '$some:label'
    """
    name = line[: len(line) - 1]
    return name
def parse_immediate(particle):
    """Decode an immediate-field tag such as ``imm12`` or ``off20hi``.

    Returns a dict with ``mode`` and integer ``size``, plus ``split`` when
    a ``hi``/``lo`` suffix is present. Returns None if *particle* is not an
    immediate tag.

    >>> parse_immediate('imm12')
    {'mode': 'imm', 'size': 12}
    >>> parse_immediate('off20hi')
    {'mode': 'off', 'size': 20, 'split': 'hi'}
    >>> parse_immediate('rd') is None
    True
    """
    found = immediate_re.match(particle)
    if found is None:
        return None
    mode, width, half = found.groups()
    result = dict(mode=mode, size=int(width))
    if half is not None:
        result["split"] = half
    return result
def parse_reference(part):
    """Parse a label-reference part into its components.

    ``part[0]`` is a label, optionally followed by a signed decimal offset
    (``lbl+2``, ``lbl-4``). ``part[1]`` is an immediate-field tag as
    understood by parse_immediate (``imm32``, ``off20hi``, ...).

    Returns a dict with ``label``, ``mode``, ``size``, ``offset`` and, when
    the tag has a ``hi``/``lo`` suffix, ``split``.

    Raises ValueError if the part has no tag, if the reference does not
    match ``ref_re`` (e.g. a sliced ``lbl[11:0]``), or if the tag is not
    an immediate tag.

    >>> parse_reference(('lbl', 'imm32'))
    {'label': 'lbl', 'mode': 'imm', 'size': 32, 'offset': 0}
    >>> parse_reference(('lbl+0', 'off32'))
    {'label': 'lbl', 'mode': 'off', 'size': 32, 'offset': 0}
    >>> parse_reference(('lbl+2', 'imm32'))
    {'label': 'lbl', 'mode': 'imm', 'size': 32, 'offset': 2}
    >>> parse_reference(('lbl+2', 'off20hi'))
    {'label': 'lbl', 'mode': 'off', 'size': 20, 'offset': 2, 'split': 'hi'}
    >>> parse_reference(('lbl+2', 'off12lo'))
    {'label': 'lbl', 'mode': 'off', 'size': 12, 'offset': 2, 'split': 'lo'}
    >>> parse_reference(('lbl[11:0]', 'imm12'))
    Traceback (most recent call last):
        ...
    ValueError: invalid label reference 'lbl[11:0]'
    """
    if len(part) < 2:
        raise ValueError("reference {!r} has no field tag".format(part))
    ref, field = part[0], part[1]

    # Fail with a descriptive error instead of calling .groups() on None.
    ref_match = ref_re.match(ref)
    if ref_match is None:
        raise ValueError("invalid label reference {!r}".format(ref))
    imm = parse_immediate(field)
    if imm is None:
        raise ValueError("invalid immediate tag {!r}".format(field))

    label, off = ref_match.groups()
    result = {
        "label": label,
        "mode": imm["mode"],
        "size": imm["size"],
        "offset": int(off or 0),
    }
    if "split" in imm:
        result["split"] = imm["split"]
    return result
def classify(line):
    """Return the kind of a cleaned line.

    Checks are applied in a fixed order, and the first match wins:
    empty string -> 'empty', leading '==' -> 'segment', trailing ':' ->
    'label', leading '@@/' -> 'special', anything else -> 'instr'.

    >>> classify('')
    'empty'
    >>> classify('== code')
    'segment'
    >>> classify('== text 0x1000')
    'segment'
    >>> classify('some_label:')
    'label'
    >>> classify('$some:label:')
    'label'
    >>> classify('ff/op 0/subop/add 1/rd/x1 label[11:0]/imm12')
    'instr'
    >>> classify('ff/8 0/3 2/5')
    'instr'
    >>> classify('@@/8')
    'special'
    >>> classify('@@/test 123 "asdx"')
    'special'
    """
    if line == "":
        return "empty"
    ordered_checks = (
        ("segment", lambda text: text.startswith("==")),
        ("label", lambda text: text.endswith(":")),
        ("special", lambda text: text.startswith("@@/")),
    )
    for kind, matches in ordered_checks:
        if matches(line):
            return kind
    return "instr"
def parse(line):
    """Clean, classify and parse one source line.

    The line is stripped, and everything after the first ``#`` is split off
    as the comment. Both the code part and the comment are stripped again.
    The code part is classified and handed to the matching ``parse_*``
    function.

    Returns a dict with keys ``type``, ``raw`` (the stripped input line,
    including any comment), ``line`` (the code part), ``comment`` (text
    after ``#``, or None when there is no ``#``), and one key named after
    the type that holds the parsed value (None for empty lines).

    >>> parse('lbl: # note')['label']
    'lbl'
    >>> parse('lbl: # note')['comment']
    'note'
    >>> parse('== code 0x8000 # base')['segment']
    ('code', 32768)
    >>> parse('')['type']
    'empty'
    """
    raw = line.strip()
    code, sep, rest = raw.partition("#")
    # Strip again: the whitespace between code and '#' would otherwise
    # hide a label's trailing ':' from classify() and add an empty field
    # to segment lines. The comment is stripped so format() can re-add it
    # with a single " # " separator.
    clean = code.strip()
    comment = rest.strip() if sep else None

    kind = classify(clean)
    parsers = {
        "segment": parse_segment,
        "label": parse_label,
        "instr": parse_instr,
        "special": parse_special,
    }
    parser = parsers.get(kind)
    parsed = parser(clean) if parser is not None else None
    return {
        "type": kind,
        "raw": raw,
        "line": clean,
        "comment": comment,
        kind: parsed,
    }
def is_reference(part):
    """Tell whether a parsed part refers to a label.

    parse_part only produces a string value when the token is not a hex
    literal, so a string value means a label reference (with or without
    offset, slice or suffix). An int value is a literal.

    >>> is_reference(('hello',))
    True
    >>> is_reference(('hello:world',))
    True
    >>> is_reference(('hello-4',))
    True
    >>> is_reference(('hello[11:0]',))
    True
    >>> is_reference(('$label',))
    True
    >>> is_reference(('$label:extra',))
    True
    >>> is_reference(('$label:extra', 'off20'))
    True
    >>> is_reference(('$label+3[31:12]', 'off20'))
    True
    >>> is_reference(('plain', 'off20'))
    True
    >>> is_reference((0,))
    False
    >>> is_reference((1,))
    False
    >>> is_reference((1, 'disp20'))
    False
    >>> is_reference((0x13f, 'imm12'))
    False
    """
    value = part[0]
    return isinstance(value, str)
def untag(part, expect=None):
    """Return the value of a part, optionally verifying its first tag.

    *expect* may be None (no check), a single tag string, or a collection
    of acceptable tags (list, tuple, set or frozenset). A part without any
    tag fails every check.

    Raises ValueError when the first tag does not match *expect*.

    >>> untag((2, 'num'))
    2
    >>> untag(('$label', 'imm20'))
    '$label'
    >>> untag((2, 'hello', 'things'))
    2
    >>> untag((2, 'num'), expect='num')
    2
    >>> untag(('$label', 'imm20'), expect='imm20')
    '$label'
    >>> untag((2, 'num'), expect=['num', 'imm20'])
    2
    >>> untag(('$label', 'imm20'), expect=['num', 'imm20'])
    '$label'
    >>> untag((2, 'num'), expect=('num', 'imm20'))
    2
    >>> untag((2, 'imm12'), expect='imm20')
    Traceback (most recent call last):
        ...
    ValueError: expected (2, 'imm12') to be labelled imm20
    >>> untag(('$label', 'imm20'), expect='off12')
    Traceback (most recent call last):
        ...
    ValueError: expected ('$label', 'imm20') to be labelled off12
    >>> untag((2, 'imm12'), expect=['num', 'off12'])
    Traceback (most recent call last):
        ...
    ValueError: expected (2, 'imm12') to be labelled one of ['num', 'off12']
    >>> untag((2,), expect='num')
    Traceback (most recent call last):
        ...
    ValueError: expected (2,) to be labelled num
    """
    if expect is None:
        return part[0]
    # A missing tag never matches, giving ValueError instead of IndexError.
    tag = part[1] if len(part) > 1 else None
    if isinstance(expect, str):
        if tag != expect:
            raise ValueError("expected {} to be labelled {}".format(part, expect))
    elif isinstance(expect, (list, tuple, set, frozenset)):
        # Previously only lists were checked; tuples/sets were silently ignored.
        if tag not in expect:
            raise ValueError(
                "expected {} to be labelled one of {}".format(part, expect)
            )
    return part[0]
def format_immediate(imm):
    """Render an immediate descriptor back into its tag (inverse of
    parse_immediate): mode, then size, then the optional split suffix.

    >>> format_immediate({'mode': 'off', 'size': 20, 'split': 'hi'})
    'off20hi'
    >>> format_immediate({'mode': 'imm', 'size': 12})
    'imm12'
    """
    head = "{}{}".format(imm["mode"], imm["size"])
    return head + imm["split"] if "split" in imm else head
def format_part(part):
    """Render a parsed part back to text (inverse of parse_part).

    ``bits.Bitfield`` values are rendered with ``str()``. For literal parts
    the int value is written as lowercase hex without a prefix. Fields are
    joined with '/'.

    >>> format_part((0,))
    '0'
    >>> format_part((0x00,))
    '0'
    >>> format_part((16,))
    '10'
    >>> format_part((0x10,))
    '10'
    >>> format_part(('label', 'tag*'))
    'label/tag*'
    >>> format_part(('$label', 'tag*'))
    '$label/tag*'
    >>> format_part(('label:suff', 'tag'))
    'label:suff/tag'
    >>> format_part(('$label:suff', 'tag'))
    '$label:suff/tag'
    """
    if isinstance(part, bits.Bitfield):
        return str(part)
    if is_reference(part):
        fields = part
    else:
        fields = ("{:x}".format(part[0]),) + tuple(part[1:])
    return "/".join(map(str, fields))
def format(line):
    """Render a parsed line back to source text (the opposite of parse).

    - instr/data lines are rebuilt part by part with format_part.
    - label lines become ``NAME:``.
    - segment lines become ``== NAME`` plus `` 0xADDR`` when an address is
      present.
    - empty and special lines are returned as their stored ``raw`` text,
      which (as produced by parse) already contains any comment.

    For rebuilt lines a non-empty ``comment`` is appended as ``# comment``.
    A missing ``comment`` key is treated as no comment.

    Raises NotImplementedError for any other line type.

    >>> format({
    ...     'type': 'instr',
    ...     'instr': [(255, 'op'), (0, 'subop', 'add'), (1, 'rd', 'x1'), ('label', 'imm12')],
    ...     'comment': "this does things."
    ... })
    'ff/op 0/subop/add 1/rd/x1 label/imm12 # this does things.'
    >>> format({'type': 'label', 'label': '$some:label', 'comment': None})
    '$some:label:'
    >>> format({'type': 'segment', 'segment': ('code', 32768), 'comment': None})
    '== code 0x8000'
    """
    kind = line["type"]
    if kind in ("instr", "data"):
        packed = " ".join(format_part(part) for part in line[kind])
    elif kind == "label":
        packed = "{}:".format(line["label"])
    elif kind == "segment":
        name, *address = line["segment"]
        # parse_segment reads the address with int(..., 16), so write it as hex.
        packed = " ".join(["==", name] + ["0x{:x}".format(a) for a in address])
    elif kind in ("empty", "special"):
        return line["raw"]
    else:
        raise NotImplementedError("type {}".format(kind))
    comment = line.get("comment")
    if comment:
        packed = packed + " # " + comment
    return packed
def dump(line):
    """Debug representation of a parsed line: its type name immediately
    followed by the value stored under that type's key.

    >>> dump({
    ...     'type': 'instr',
    ...     'instr': [(255, 'op'), (0, 'subop', 'add'), (1, 'rd', 'x1'), ('label', 'imm12')],
    ...     'comment': "this does things."
    ... })
    "instr[(255, 'op'), (0, 'subop', 'add'), (1, 'rd', 'x1'), ('label', 'imm12')]"
    """
    kind = line["type"]
    return "{kind}{body}".format(kind=kind, body=line[kind])
def join_all(gen):
    """Join an iterable of strings with newlines (no trailing newline)."""
    return "\n".join(gen)
class SubVException(Exception):
    """Error reported for failures while parsing or processing input lines;
    the message carries the file/line context it was built with."""
class LineIterator(object):
    """Iterate over a stream of source lines, yielding parsed lines.

    Each step yields ``(segment, parsed_line)``. ``segment`` is the name
    from the most recent segment-line seen so far (None before the first
    one), and ``parsed_line`` is the dict produced by parse(). Line numbers
    start at 1 and are kept in ``self.i`` so that exception() can point at
    the offending line.
    """

    def __init__(self, stream):
        self.stream = stream
        self.iter = enumerate(self.stream, start=1)
        # number and text of the most recently read line
        self.i, self.raw_line = 0, None
        # parsed form of that line (None if parsing failed) and current segment
        self.line, self.segment = None, None

    def __iter__(self):
        return self

    def __next__(self):
        self.i, self.raw_line = next(self.iter)
        # Reset first, so a parse failure is reported using the raw text.
        self.line = None
        try:
            self.line = parse(self.raw_line)
            if self.line["type"] == "segment":
                self.segment = self.line["segment"][0]
        except Exception as e:
            raise self.exception("failed to parse line") from e
        return (self.segment, self.line)

    def exception(self, msg):
        """Return (not raise) a SubVException for *msg* annotated with the
        stream name, line number and current line, plus the parsed
        structure when the line was parsed successfully."""
        stream_name = getattr(self.stream, "name", "(unnamed)")
        if self.line:
            try:
                shown = format(self.line)
            except Exception:
                # format() raises NotImplementedError for line types it
                # cannot render; that must not mask the error being reported.
                shown = self.raw_line.strip()
            msg = msg + "\n{}:{}:  {}".format(stream_name, self.i, shown)
            msg = msg + "\nparsed as {}".format(dump(self.line))
        elif self.raw_line:
            msg = msg + "\n{}:{}:  {}".format(
                stream_name, self.i, self.raw_line.strip()
            )
        return SubVException(msg)
def with_parsed_lines(process_fn):
    """Decorator adapting a line-processing function to raw line input.

    The returned generator function takes an iterable of raw lines and
    wraps it in a LineIterator. It calls ``process_fn`` with that iterator
    (``process_fn`` must return an iterable) and re-yields everything it
    produces. A SubVException propagates unchanged. Any other Exception is
    converted into a SubVException naming the processing step and pointing
    at the line being handled when it occurred, with the original exception
    chained as the cause.
    """
    import functools

    # process_fn may lack __name__ (e.g. functools.partial); never let the
    # error path itself fail on that.
    step = getattr(process_fn, "__name__", "process")

    @functools.wraps(process_fn)
    def _wrapped(iter):  # parameter name kept for backward compatibility
        iterator = LineIterator(iter)
        try:
            yield from process_fn(iterator)
        except SubVException:
            # already carries line context
            raise
        except Exception as e:
            raise iterator.exception("failed to {} line".format(step)) from e

    return _wrapped