diff options
| author | s-ol <s-ol@users.noreply.github.com> | 2020-05-28 12:42:14 +0000 |
|---|---|---|
| committer | s-ol <s-ol@users.noreply.github.com> | 2020-05-28 12:42:14 +0000 |
| commit | 95d8b69c68a17cf3fbcfb6d9f4752c9a90e9da69 (patch) | |
| tree | 2083ed2af86124e9ddbf4811a5b7650d0a43f893 | |
| parent | disassembly helper (diff) | |
| download | subv-95d8b69c68a17cf3fbcfb6d9f4752c9a90e9da69.tar.gz subv-95d8b69c68a17cf3fbcfb6d9f4752c9a90e9da69.zip | |
tests, cleanup, verify.py
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | README.md | 8 | ||||
| -rw-r--r-- | notes.md | 2 | ||||
| -rw-r--r-- | riscv.py | 229 | ||||
| -rw-r--r-- | subx.py | 102 | ||||
| -rwxr-xr-x | test.py | 41 | ||||
| -rw-r--r-- | test_riscv.py | 3 | ||||
| -rw-r--r-- | verify.py | 186 |
8 files changed, 487 insertions, 86 deletions
@@ -1,2 +1,4 @@ *.elf +*.bin +*.pyc __pycache__ @@ -6,6 +6,14 @@ This is a wip clone of [SubX][mu] for the RISC-V RV31I base ISA. $ ./test.py | ./elf.py > out.elf $ ./qemu.sh out.elf +Pipeline +-------- + +back to front: + +- `elf`: takes `hex.subx`-style input, outputs an ELF file +- `survey`: + Debugging --------- @@ -117,7 +117,7 @@ opcode: 0b0010111 / 0x17 - imm20: top 20 bits of offset added to PC - result stored in rd -## U-Format: op(imm20/21) -> rd +## J-Format: op(imm20/21) -> rd `opcode[7] rd[5] imm31:12[20]` immediates on a 32bit scale (jal) @@ -1,44 +1,199 @@ +def u(num, bits): + if num < 0: + raise ValueError("negative value not allowed: {}".format(num)) + + if num.bit_length() > bits: + raise ValueError("value too large for u{} field: {} ({} bits)", + bits, num, num.bit_length()) + + return (num, bits) + +def i(num, bits): + if num < 0: + num = (1 << bits) + num + + return u(num, bits) + +def bit_concat(*parts): + val, size = 0, 0 + for (pval, psize) in parts: + val = val | pval << size + size += psize + return (val, size) + +def bit_slice(bits, top, bottom): + (val, size) = bits + + if top < bottom: + raise ValueError("cant slice reverse range") + elif bottom < 0: + raise ValueError("negative slice index") + elif top >= size: + raise ValueError("cant slice [{}:{}] from {} bit value".format(top, bottom, size)) + + width = top - bottom + 1 + val = (val >> bottom) & ((1 << width) - 1) + return (val, width) + +def byteify(word): + (val, size) = word + if size != 32: + raise ValueError("Expected 32-bit word") + + b0 = bit_slice(word, 7, 0) + b1 = bit_slice(word, 15, 8) + b2 = bit_slice(word, 23, 16) + b3 = bit_slice(word, 31, 24) + return [b0[:1], b1[:1], b2[:1], b3[:1]] + def format_r(op, rd, r1, r2, funct3, funct7): - # opcode[7] rd[5] funct3[3] rs1[5] rs2[5] funct7[7] - b0 = (rd & 0x1) << 7 | op - b1 = (r1 & 0x1) << 7 | funct3 << 4 | rd >> 1 - b2 = (r2 & 0xf) << 4 | r1 >> 1 - b3 = funct7 << 1 | r2 >> 4 - return [(b0,), (b1,), (b2,), (b3,)] + # HH funct7 rs2 rs1 funct3 rd opcode LL + return byteify(bit_concat( + u(op, 7), + u(rd, 5), + u(funct3, 3), + u(r1, 5), + u(r2, 5), + u(funct7, 7), + )) def format_i(op, rd, r1, imm12, funct3): - # opcode[7] rd[5] funct3[3] rs1[5] imm[12] - b0 = (rd & 0x1) << 7 | op - b1 = (r1 & 0x1) << 7 | funct3 << 4 | rd >> 1 - b2 = (imm12 & 0xf) << 4 | r1 >> 1 - b3 = imm12 >> 4 - return [(b0,), (b1,), (b2,), (b3,)] + # HH imm[11:0] rs1 funct3 rd opcode LL + return byteify(bit_concat( + u(op, 7), + u(rd, 5), + u(funct3, 3), + u(r1, 5), + i(imm12, 12), + )) def format_s(op, r1, r2, imm12, funct3): - # opcode[7] rd[5] funct3[3] rs1[5] rs2[5] funct7[7] - b0 = (imm12 & 0x1) << 7 | op - b1 = (r1 & 0x1) << 7 | funct3 << 4 | (imm12 & 0x1f) >> 1 - b2 = (r2 & 0xf) << 4 | r1 >> 1 - b3 = (imm12 & 0xf0) | r2 >> 4 - return [(b0,), (b1,), (b2,), (b3,)] + # HH imm[11:5] rs2 rs1 funct3 imm[4:0] opcode LL + imm = i(imm12, 12) + imm_lo = bit_slice(imm, 4, 0) + imm_hi = bit_slice(imm, 11, 5) + return byteify(bit_concat( + u(op, 7), + imm_lo, + u(funct3, 3), + u(r1, 5), + u(r2, 5), + imm_hi + )) + + # imm12 = i(imm12, 12) + # r1 = i(r1, 5, neg=False) + # r2 = i(r2, 5, neg=False) + # + # b0 = (imm12 & 0x1) << 7 | op + # b1 = (r1 & 0x1) << 7 | funct3 << 4 | (imm12 & 0x1f) >> 1 + # b2 = (r2 & 0xf) << 4 | r1 >> 1 + # b3 = (imm12 & 0xf0) | r2 >> 4 + # return [(b0,), (b1,), (b2,), (b3,)] def format_u(op, rd, imm20): - # opcode[7] rd[5] imm31:12[20] - b0 = (rd & 0x1) << 7 | op - b1 = (imm20 & 0xf) << 4 | rd >> 1 - b2 = (imm20 >> 4) & 0xff - b3 = imm20 >> 12 - return [(b0,), (b1,), (b2,), (b3,)] - -def format_j(op, rd, imm): - # opcode[7] rd[5] imm19:12[8] imm11[1] imm10:1[10] imm20[1] - imm12 = (imm >> 11) & 0xff - imm11 = (imm >> 12) & 0x1 - imm1 = imm & 0x3ff - imm20 = imm >> 19 - - b0 = (rd & 0x1) << 7 | op # rd[1] op[7:0] - b1 = (imm12 & 0xf) << 4 | rd >> 1 # imm[15:12] rd[5:2] - b2 = (imm1 << 5) & 0xe0 | imm11 << 4 | (imm12 >> 4) # imm[3:1] imm[11] imm[19:16] - b3 = imm20 << 7 | imm1 >> 3 - return [(b0,), (b1,), (b2,), (b3,)] + # HH imm[31:12] rd opcode LL + return byteify(bit_concat( + u(op, 7), + u(rd, 5), + i(imm20, 20) + )) + +def format_j(op, rd, imm20): + # HH imm[20] imm[10:1] imm[11] imm[19:12] rd opcode LL + imm = i(imm20, 20) + imm_lo = bit_slice(imm, 9, 0) + imm_10 = bit_slice(imm, 10, 10) + imm_hi = bit_slice(imm, 18, 11) + imm_19 = bit_slice(imm, 19, 19) + return byteify(bit_concat( + u(op, 7), + u(rd, 5), + imm_hi, + imm_10, + imm_lo, + imm_19 + )) + +import unittest +class TestHelpers(unittest.TestCase): + def test_bit_concat(self): + self.assertEqual( + bit_concat((0b10, 2), (0b00, 2)), + (0b0010, 4) + ) + self.assertEqual( + bit_concat((0b1, 1), (0b0110, 4), (0b110, 3)), + (0b11001101, 8) + ) + + def test_bit_slice(self): + self.assertEqual( + bit_slice((0x7f, 8), 7, 4), + (0x7, 4) + ) + self.assertEqual( + bit_slice((0b100, 3), 2, 2), + (0b1, 1) + ) + with self.assertRaises(ValueError): + bit_slice((0xf, 4), 4, 0) + with self.assertRaises(ValueError): + bit_slice((0xf, 4), 2, 3) + with self.assertRaises(ValueError): + bit_slice((0xf, 4), 3, -1) + self.assertEqual( + bit_slice((0x12345678, 32), 7, 0), + (0x78, 8) + ) + self.assertEqual( + bit_slice((0x12345678, 32), 15, 8), + (0x56, 8) + ) + self.assertEqual( + bit_slice((0x12345678, 32), 23, 16), + (0x34, 8) + ) + + def test_byteify(self): + self.assertEqual( + byteify((0x12345678, 32)), + [(0x78,), (0x56,), (0x34,), (0x12,)] + ) + +class TestFormats(unittest.TestCase): + def test_format_u(self): + self.assertEqual( + format_u(0x37, 0x5, 0x10010), + [(183,), (2,), (1,), (16,)] + ) + + def test_format_i(self): + self.assertEqual( + format_i(0x13, 5, 0, 72, 0x0), + [(147,), (2,), (128,), (4,)] + ) + self.assertEqual( + format_i(0x13, 9, 2, 72, 0x3), + [(147,), (52,), (129,), (4,)] + ) + + def test_format_s(self): + self.assertEqual( + format_s(0x23, 3, 4, 0, 0x2), + [(35,), (160,), (65,), (0,)] + ) + self.assertEqual( + format_s(0x23, 2, 0, -4, 0x2), + [(35,), (46,), (1,), (254,)] + ) + + def test_format_j(self): + self.assertEqual( + format_j(0x6f, 0, -26), + [(111,), (240,), (223,), (252,)] + ) + self.assertEqual( + format_j(0x6f, 9, 0), + [(239,), (4,), (0,), (0,)] + ) @@ -1,40 +1,94 @@ import re white = re.compile('[ \t\.\n]+') -hex = re.compile('^(0x)?[0-9a-f]+$') +hex = re.compile(r'^\-?(0x)?[0-9a-f]+$') def parse_part(part): - part = part.split('/') - if hex.match(part[0]): - part[0] = int(part[0], 16) - return tuple(part) + part = part.split('/') + if hex.match(part[0]): + part[0] = int(part[0], 16) + return tuple(part) def parse_instr(line): - parts = white.split(line) - parts = [parse_part(part) for part in parts if part != ''] - return parts + parts = white.split(line) + parts = [parse_part(part) for part in parts if part != ''] + return parts def parse_segment(line): - parts = white.split(line) - return (parts[1], int(parts[2], 16)) + parts = white.split(line) + return (parts[1], int(parts[2], 16)) + +def is_lref(part): + return isinstance(part[0], str) + +def unlabel(part, expect=None): + if expect and part[1] != expect: + raise ValueError("expected {} to be labelled {}", part, expect) + return part[0] def format_part(part): - if not isinstance(part[0], str): - part = ('{:02x}'.format(part[0]),) + part[1:] - return '/'.join(part) + if not is_lref(part): + first = '{:02x}'.format(part[0]) + part = (first,) + part[1:] + return '/'.join([str(p) for p in part]) def format_instr(inst, comment=None): - packed = ' '.join(format_part(part) for part in inst) - if comment: - packed = packed + ' # ' + comment - return packed + packed = ' '.join(format_part(part) for part in inst) + if comment: + packed = packed + ' # ' + comment + return packed def clean(line): - return line.strip().split('#')[0] + return line.strip().split('#')[0] def classify(line): - if line.startswith('=='): # segment - return 'segment' - elif line.endswith(':'): # label - return 'label' - else: - return 'instr' + if line.startswith('=='): # segment + return 'segment' + elif line.endswith(':'): # label + return 'label' + else: + return 'instr' + +import unittest +class TestParsing(unittest.TestCase): + def test_parse_part(self): + self.assertEqual(parse_part('0'), (0,)) + self.assertEqual(parse_part('00'), (0,)) + self.assertEqual(parse_part('0x00'), (0,)) + + self.assertEqual(parse_part('12'), (0x12,)) + self.assertEqual(parse_part('0x12'), (0x12,)) + + self.assertEqual(parse_part('-12'), (-0x12,)) + self.assertEqual(parse_part('-0x12'), (-0x12,)) + + self.assertEqual(parse_part('00/with/tag'), (0, 'with', 'tag')) + self.assertEqual(parse_part('-12/and<</tag*'), (-0x12, 'and<<', 'tag*')) + + self.assertEqual(parse_part('label/tag*'), ('label', 'tag*')) + self.assertEqual(parse_part('$label/tag*'), ('$label', 'tag*')) + self.assertEqual(parse_part('label:suff/tag*'), ('label:suff', 'tag*')) + self.assertEqual(parse_part('$label:suff/tag*'), ('$label:suff', 'tag*')) + +class TestChecks(unittest.TestCase): + def test_is_lref(self): + self.assertTrue(is_lref(('hello',))) + self.assertTrue(is_lref(('hello:world',))) + self.assertTrue(is_lref(('$label',))) + self.assertTrue(is_lref(('$label:extra',))) + self.assertTrue(is_lref(('$label:extra','disp20u'))) + self.assertTrue(is_lref(('plain','disp20u'))) + + self.assertFalse(is_lref((0,))) + self.assertFalse(is_lref((1,))) + self.assertFalse(is_lref((1,'disp20u'))) + self.assertFalse(is_lref((0x13f,'imm12'))) + +class TestFormatting(unittest.TestCase): + def test_format_part(self): + self.assertEqual(format_part((0,)), '00') + self.assertEqual(format_part((0x00,)), '00') + + self.assertEqual(format_part(('label', 'tag*')), 'label/tag*') + self.assertEqual(format_part(('$label', 'tag*')), '$label/tag*') + self.assertEqual(format_part(('label:suff', 'tag')), 'label:suff/tag') + self.assertEqual(format_part(('$label:suff', 'tag')), '$label:suff/tag') @@ -2,28 +2,21 @@ from riscv import format_u, format_i, format_s, format_j from subx import format_instr -def neg(val, bits): - """compute the 2's complement of int value val""" - # if (val & (1 << (bits - 1))) != 0: # if sign bit is set e.g., 8bit: 128-255 - val = val - (1 << bits) # compute negative value - return -val - if __name__ == "__main__": - t0 = 0x5 - t1 = 0x6 - print("== code 0x80000000") - print(format_instr(format_u(0x37, t0, 0x10010), "lui t0, 0x10010")) - print(format_instr(format_i(0x13, t1, 0, 72, 0x0), "addi t1, x0, 72")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_i(0x13, t1, 0, 101, 0x0), "addi t1, x0, 101")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_i(0x13, t1, 0, 108, 0x0), "addi t1, x0, 108")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_i(0x13, t1, 0, 108, 0x0), "addi t1, x0, 108")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_i(0x13, t1, 0, 111, 0x0), "addi t1, x0, 111")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_i(0x13, t1, 0, 10, 0x0), "addi t1, x0, 10")) - print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) - print(format_instr(format_j(0x6f, 0, neg(26, 20)), "jal x0, -26")) - + t0 = 0x5 + t1 = 0x6 + print("== code 0x80000000") + print(format_instr(format_u(0x37, t0, 0x10010), "lui t0, 0x10010")) + print(format_instr(format_i(0x13, t1, 0, 72, 0x0), "addi t1, x0, 72")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_i(0x13, t1, 0, 101, 0x0), "addi t1, x0, 101")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_i(0x13, t1, 0, 108, 0x0), "addi t1, x0, 108")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_i(0x13, t1, 0, 108, 0x0), "addi t1, x0, 108")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_i(0x13, t1, 0, 111, 0x0), "addi t1, x0, 111")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_i(0x13, t1, 0, 10, 0x0), "addi t1, x0, 10")) + print(format_instr(format_s(0x23, t0, t1, 0, 0x2), "sw t1, 0(t0)")) + print(format_instr(format_j(0x6f, 0, -26), "jal x0, -26")) diff --git a/test_riscv.py b/test_riscv.py new file mode 100644 index 0000000..b945b36 --- /dev/null +++ b/test_riscv.py @@ -0,0 +1,3 @@ +import unittest +from riscv import bit_concat, bit_slice, byteify +from riscv import format_u, format_i, format_s, format_j diff --git a/verify.py b/verify.py new file mode 100644 index 0000000..8499da5 --- /dev/null +++ b/verify.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +from subx import clean, classify, parse_segment, parse_instr, is_lref, unlabel, format_instr +from riscv import u, i, bit_concat + +""" +""" + +instr_map = { + 'opr': ('r', 0x33), + 'load': ('i', 0x03), + 'opi': ('i', 0x13), + 'jalr': ('i', 0x67), + 'store': ('s', 0x23), + 'branch': ('b', 0x63), + 'lui': ('u', 0x37), + 'auipc': ('u', 0x17), + 'jal': ('j', 0x6f), +} + +def pack_u(instr): + (op, rd, imm) = instr + op = u(unlabel(op), 7) + rd = u(unlabel(rd, 'rd'), 5) + if not is_lref(imm): + imm = i(unlabel(imm, 'imm20'), 20) + + return [op, rd, imm] + +def pack_i(instr): + (op, sub, rd, rs, imm) = instr + op = u(unlabel(op), 7) + sub = u(unlabel(sub, 'subop'), 3) + rd = u(unlabel(rd, 'rd'), 5) + rs = u(unlabel(rs, 'rs'), 5) + if not is_lref(imm): + imm = i(unlabel(imm, 'imm12'), 12) + + return [op, sub, rd, rs, imm] + +def pack_s(instr): + (op, sub, rs1, rs2, imm) = instr + op = u(unlabel(op), 7) + sub = u(unlabel(sub, 'subop'), 3) + rs1 = u(unlabel(rs1, 'rs'), 5) + rs2 = u(unlabel(rs2, 'rs'), 5) + if not is_lref(imm): + imm = i(unlabel(imm, 'disp12'), 12) + + return [op, sub, rs1, rs2, imm] + +def pack_j(instr): + (op, rd, imm) = instr + op = u(unlabel(op), 7) + rd = u(unlabel(rd, 'rd'), 5) + if not is_lref(imm): + imm = i(unlabel(imm, 'disp20u'), 20) + + return [op, rd, imm] + +def pack(iter): + for line in iter: + line = clean(line) + if line == '': + continue + + type = classify(line) + + if type == 'segment' or type == 'label': + yield line + else: + instr = parse_instr(line) + op = instr[0] + if len(op) != 2: + raise ValueError("instruction without op label") + + (op, label) = op + if label not in instr_map: + raise ValueError("unknown instruction label: {}".format(label)) + (format, expected) = instr_map[label] + if op != expected: + raise ValueError("opcode {} doesn't match label {} (expected {})" + .format(op, label, expected)) + + out = None + if format == 'u': + out = pack_u(instr) + elif format == 'i': + out = pack_i(instr) + elif format == 's': + out = pack_s(instr) + elif format == 'j': + out = pack_j(instr) + else: + raise NotImplementedError() + + yield format_instr(out) + +if __name__ == '__main__': + import sys + pack(sys.stdin) + +import unittest +class TestPackers(unittest.TestCase): + def test_pack_u(self): + final = pack_u([(0x37, 'lui'), (5, 'rd', 't0'), (0x10010, 'imm20')]) + self.assertEqual( + final, + [(0x37, 7), (0x5, 5), (0x10010, 20)] + ) + self.assertEqual(bit_concat(*final)[1], 32) + + label = pack_u([(0x37, 'lui'), (5, 'rd', 't0'), ('pos', 'imm20')]) + self.assertEqual( + label, + [(0x37, 7), (0x5, 5), ('pos', 'imm20')] + ) + + def test_pack_i(self): + final = pack_i([ + (0x13, 'opi'), + (0, 'subop', 'add'), + (6, 'rd', 't1'), + (0, 'rs', 'x0'), + (0x65, 'imm12'), + ]) + self.assertEqual( + final, + [(0x13, 7), (0, 3), (6, 5), (0, 5), (0x65, 12)] + ) + self.assertEqual(bit_concat(*final)[1], 32) + +from io import StringIO +from textwrap import dedent +class TestE2E(unittest.TestCase): + def test_e2e(self): + inv = dedent('''\ + == code 0x80000000 + main: + # load 0x10010000 (UART0) into t0 + 37/lui 5/rd/t0 0x10010/imm20 + # store 0x48 (H) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 48/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # store 0x65 (e) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 65/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # store 0x6c (l) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 6c/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # store 0x6c (l) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 6c/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # store 0x6f (o) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 6f/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # store 0x0a (\\n) in UART0+0 + 13/opi 0/subop/add 6/rd/t1 0/rs/x0 0a/imm12 + 23/store 2/subop/word 5/rs/t0 6/rs/t1 0/disp12 + # jump back up to the top + 6f/jal 0/rd/x0 main/disp20u + ''') + + out = dedent('''\ + == code 0x80000000 + main: + 37/7 05/5 10010/20 + 13/7 00/3 06/5 00/5 48/12 + 23/7 02/3 05/5 06/5 00/12 + 13/7 00/3 06/5 00/5 65/12 + 23/7 02/3 05/5 06/5 00/12 + 13/7 00/3 06/5 00/5 6c/12 + 23/7 02/3 05/5 06/5 00/12 + 13/7 00/3 06/5 00/5 6c/12 + 23/7 02/3 05/5 06/5 00/12 + 13/7 00/3 06/5 00/5 6f/12 + 23/7 02/3 05/5 06/5 00/12 + 13/7 00/3 06/5 00/5 0a/12 + 23/7 02/3 05/5 06/5 00/12 + 6f/7 00/5 main/disp20u + ''') + + got = '' + for line in pack(StringIO(inv)): + got += line + '\n' + + self.assertEqual(got, out) |
