#!/usr/bin/env python3

import sys
import struct

# When True, tile rectangles are printed with the Y axis flipped.
FLIP_Y = False

# Framebuffer geometry and hierarchy mask; overridden by "width", "height"
# and "mask" input lines.
fb_width = 160
fb_height = 160
hierarchy_mask = 0xffff

# Offset added to base_ptr to find the first tile header.  Changed to 0x40 by
# an "addr" line and recomputed later for "vaheap" dumps.
HEAP_OFS = 0x8000

base_ptr = 0
heap_ptr = 0
# Exactly one of these is selected by the input: bifrost is the default,
# "addr" switches to midgard and "vaheap" switches to valhall.
midgard = False
bifrost = True
valhall = False
size = None

# `data` is a bytearray while parsing so that appending is done in place
# (repeated `bytes +=` copies the whole buffer each time); it is converted
# back to immutable bytes once parsing is finished.
data = bytearray()
# Contents of the first buffer, saved when a "heap" line starts the second.
bak_data = b''
# Always defined so that later lookups never hit an unbound name when the
# input contains no "heap" line.
heap_data = b''
# Bytes of the most recent dump line(s) not yet appended to `data`.
cur_data = b''

# TODO: More robust looping..
for line in sys.stdin.read().split("\n"):
    # Echo every input line, including blank ones.
    print(line)
    split = line.split(" ")
    key = split[0]
    # str.split(" ") always yields at least one element, so an empty first
    # token is the only "nothing to do" case.
    if key == "":
        continue

    if key == "width":
        fb_width = int(split[1])
    elif key == "height":
        fb_height = int(split[1])
    elif key == "mask":
        hierarchy_mask = int(split[1], 0)
    elif key == "vaheap":
        base_ptr = int(split[1], 16)
        bifrost = False
        valhall = True
    elif key == "addr":
        base_ptr = int(split[1], 16)
        bifrost = False
        midgard = True
        HEAP_OFS = 0x40
    elif key == "heap":
        # Everything parsed so far becomes the first buffer; subsequent dump
        # lines fill a fresh one.
        heap_ptr = int(split[1], 16)
        data += cur_data
        cur_data = b''
        bak_data = bytes(data)
        data = bytearray()
    elif key == "size":
        size = int(split[1], 0)
    else:
        # Dump line: a hexadecimal offset followed by hexadecimal byte values.
        offset = int(split[0], 16)
        if offset > len(data):
            data += cur_data
            cur_data = b''
            # Zero-fill any gap up to the new offset.  The multiplication
            # yields an empty string when the difference is negative.
            data += b'\0' * (offset - len(data))
        # Empty tokens (repeated spaces) and "*" markers carry no byte.
        cur_data += bytes(int(d, 16) for d in split[1:] if d not in ("", "*"))

data += cur_data
data = bytes(data)

if heap_ptr:
    # With a "heap" line present, `data` is the first buffer and `heap_data`
    # the one that followed it.
    data, heap_data = bak_data, data

if size is None:
    size = len(data)

def int7(val, signed=True):
    """Return the low 7 bits of val.

    When signed is true the 7-bit field is read as two's complement, giving a
    result in -64..63; otherwise the result is in 0..127.
    """
    low = val & 0x7f
    if not signed:
        return low
    return low - 0x80 if low >= 0x40 else low

def int8(val, signed=True):
    """Return the low 8 bits of val.

    When signed is true the byte is read as two's complement, giving a result
    in -128..127; otherwise the result is in 0..255.
    """
    low = val & 0xff
    if signed and low & 0x80:
        low -= 0x100
    return low

def fetch(ptr, size):
    """Return up to size bytes of dump data located at address ptr.

    In Midgard mode ptr is an absolute address that is looked up first in the
    buffer starting at base_ptr (`data`) and then, if a heap was loaded, in
    the buffer starting at heap_ptr (`heap_data`).  Otherwise ptr is an
    offset into `data`; in Valhall mode base_ptr is subtracted first.

    The result may be shorter than size when the read runs past the end of
    the buffer.  An address that cannot be resolved yields b"" in every mode,
    so callers can always test the length of the result.
    """
    if midgard:
        if base_ptr <= ptr < base_ptr + len(data):
            base = ptr - base_ptr
            return data[base:base+size]
        # heap_data only exists when a heap address was supplied, so check
        # heap_ptr first to avoid touching an unbound name.
        if heap_ptr and heap_ptr <= ptr < heap_ptr + len(heap_data):
            base = ptr - heap_ptr
            return heap_data[base:base+size]
        # Address is in neither buffer.
        return b""

    if valhall:
        ptr -= base_ptr
    if ptr < 0:
        # A negative index would silently slice from the end of the buffer.
        return b""
    return data[ptr:ptr+size]

def print_draw(ptr):
    """Fetch the 128-byte draw structure at ptr and print its known fields.

    The structure is read as sixteen 64-bit words.  Each entry of the field
    table is (name, width in bits, "word:bit", type), where the word index
    counts 32-bit words.  Fields of type "float" are printed as a 32-bit
    float; all others are printed in hex.  Afterwards, every 64-bit word that
    has bits set outside the fields listed for it is reported.
    """
    raw = fetch(ptr, 128)
    if len(raw) < 128:
        print(" couldn't fetch draw struct")
        return
    words = struct.unpack("=16Q", raw)
    # Per 64-bit word, the union of the bit ranges described by the table.
    known = [0] * len(words)

    fields = (
        ("Allow forward pixel to kill", 1, "0:0", "bool"),
        ("Allow forward pixel to be killed", 1, "0:1", "bool"),
        ("Pixel kill operation", 2, "0:2", "Pixel Kill"),
        ("ZS update operation", 2, "0:4", "Pixel Kill"),
        ("Allow primitive reorder", 1, "0:6", "bool"),
        ("Overdraw alpha0", 1, "0:7", "bool"),
        ("Overdraw alpha1", 1, "0:8", "bool"),
        ("Clean Fragment Write", 1, "0:9", "bool"),
        ("Primitive Barrier", 1, "0:10", "bool"),
        ("Evaluate per-sample", 1, "0:11", "bool"),
        ("Single-sampled lines", 1, "0:13", "bool"),
        ("Occlusion query", 2, "0:14", "Occlusion Mode"),
        ("Front face CCW", 1, "0:16", "bool"),
        ("Cull front face", 1, "0:17", "bool"),
        ("Cull back face", 1, "0:18", "bool"),
        ("Multisample enable", 1, "0:19", "bool"),
        ("Shader modifies coverage", 1, "0:20", "bool"),
        ("Alpha-to-coverage Invert", 1, "0:21", "bool"),
        ("Alpha-to-coverage", 1, "0:22", "bool"),
        ("Scissor to bounding box", 1, "0:23", "bool"),
        ("Sample mask", 16, "1:0", "uint"),
        ("Render target mask", 8, "1:16", "hex"),

        ("Packet", 1, "2:0", "bool"),
        # TODO: shr modifier
        ("Vertex array", 64, "2:0", "address"),
        ("Vertex packet stride", 16, "4:0", "uint"),
        ("Vertex attribute stride", 16, "4:16", "uint"),
        ("Unk", 16, "5:0", "uint"),

        ("Minimum Z", 32, "6:0", "float"),
        ("Maximum Z", 32, "7:0", "float"),
        ("Depth/stencil", 64, "10:0", "address"),
        ("Blend count", 4, "12:0", "uint"),
        ("Blend", 60, "12:4", "address"),
        ("Occlusion", 64, "14:0", "address"),

        ("Attribute offset", 32, "16:0", "uint"),
        ("FAU count", 8, "17:0", "uint"),
        ("Resources", 48, "24:0", "address"),
        ("Shader", 48, "26:0", "address"),
        ("Thread storage", 48, "28:0", "address"),
        ("FAU", 64, "30:0", "address"),
    )

    for label, width, position, kind in fields:
        word32, shift = map(int, position.split(":"))
        # An odd 32-bit word is the upper half of its 64-bit word.
        if word32 & 1:
            shift += 32
        index = word32 >> 1

        field_mask = (1 << width) - 1
        value = (words[index] >> shift) & field_mask
        known[index] |= field_mask << shift

        if kind == "float":
            # Reinterpret the raw 32-bit pattern as an IEEE float.
            shown = struct.unpack("=f", struct.pack("=I", value))[0]
        else:
            shown = hex(value)
        print(f"   {label}: {shown}")

    all_ones = (1 << 64) - 1
    for index, (word, mask) in enumerate(zip(words, known)):
        unknown_bits = mask ^ all_ones
        if word & unknown_bits:
            print(f"    unk at 64-bit word {index}: {hex(word)} (known mask {hex(mask)})")

def print_vertex(ptr, positions):
    """Print the four floats stored at ptr + 16 * p for each p in positions.

    An entry whose 16 bytes cannot be fetched in full is reported with its
    address instead.
    """
    for index in positions:
        where = ptr + index * 16
        raw = fetch(where, 16)
        if len(raw) >= 16:
            vx, vy, vz, vw = struct.unpack("=4f", raw)
            print(f"       <{vx} {vy} {vz} {vw}>")
        else:
            print(f"        <no data : {hex(where)}>")

# Names printed for a draw word, indexed by the 2-bit value that
# heap_interpret extracts into num_vert (the same value is used there as the
# number of vertices to print per primitive).
DRAW_TYPES = ["unk", "points", "lines", "tris"]

def heap_interpret(start, end):
    """Decode and print the stream of 32-bit words from start until end.

    Each iteration fetches four bytes at `start`, prints them in hex prefixed
    by a running count of draw words seen, and then classifies the word by
    its fourth byte (dat[3]):

    * (dat[3] & 0xe0) == 0x80: a "draw" word carrying a primitive type and an
      address.
    * Valhall only, dat[3] >> 4 == 12: a "draw offset" word that adds upper
      address bits and triggers print_draw().
    * dat[3] >> 6 == 1: a "primitive offset" word that adds scaled deltas to
      the three running indices.
    * dat[3] == 0xff / 0x00: a 64-bit / 32-bit jump to another location.
    * otherwise, if the top two bits are clear: a primitive whose vertices
      are printed with print_vertex().
    * anything else is reported as an unknown opcode.

    The loop only terminates when `start` becomes exactly equal to `end`.
    Nothing is returned; all output goes to stdout.
    """
    print(f"interpreting from {hex(start)} to {hex(end)}")

    # Number of draw words encountered so far; used as the line prefix.
    struct_count = 0

    # Whether the index deltas of the *next* word are sign-extended.  Cleared
    # for one word after a "primitive offset" word.
    signed = True

    # Running vertex-index state.  `base` persists between primitives while
    # a/b/c are zeroed after each primitive is printed.
    base = 0
    a = 0
    b = 0
    c = 0

    # Number of vertices printed per primitive; also the DRAW_TYPES index.
    num_vert = 3

    draw_ptr = 0
    # Only assigned in the Valhall path below; stays 0 otherwise.
    pos_ptr = 0

    while start != end:
        # On Midgard, a word whose address has low nine bits 0x1f8 is read as
        # a 64-bit pointer to continue from instead of being decoded.
        if midgard and start & 0x1ff == 0x1f8:
            jump = struct.unpack("=Q", fetch(start, 8))[0]
            print(f"jump mdg: {hex(jump)}")
            start = jump
            continue

        dat = fetch(start, 4)
        if dat[3] & 0xe0 == 0x80:
            struct_count += 1

        # Dump the four raw bytes in hex; the decoded form follows on the
        # same output line.
        print(f"{struct_count}:", " ".join([f"{hex(x)[2:].upper():>02}" for x in dat]), end="  ")

        masked_op = dat[3] & ~3

        up = struct.unpack("=I", dat)[0]

        # Extract the three index deltas.  The bit positions differ between
        # Valhall and the other layouts; tri0 is 8 bits wide in the latter,
        # with tri0_7 being its 7-bit form.
        if valhall:
            tri0 = tri0_7 = int7(up >> 15, signed)
            tri1 = int7(up >> 8, signed)
            tri2 = int7(up >> 1, signed)
        else:
            tri0 = int8(up >> 14, signed)
            tri0_7 = int7(up >> 14, signed)
            tri1 = int7(up >> 7, signed)
            tri2 = int7(up, signed)

        # The unsigned interpretation applies to a single word only.
        signed = True

        if dat[3] & 0xe0 == 0x80:
            # Draw word: address in the low bits (scaled by 32 on Valhall,
            # 64 otherwise) and a 2-bit primitive type in dat[3].
            res = ""
            if valhall:
                address = (up & 0x7ffffff) * 32
                num_vert = (dat[3] >> 3) & 0x3
            else:
                address = (up & 0xffffff) * 64
                num_vert = (dat[3] >> 2) & 0x3
                if dat[3] & 0x10:
                    a = 0
                    res = " reset"
            draw_ptr = address
            if valhall:
                # Vertex positions are read starting 128 bytes after the
                # draw address.
                pos_ptr = address + 128
            print(f"draw {DRAW_TYPES[num_vert]}{res}: {hex(address)}")
        elif valhall and dat[3] >> 4 == 12:
            # Draw offset word: a 16-bit value supplying bits 32 and up of
            # both the draw and position pointers.
            unk1 = up & 0x3f
            address = (up >> 6) & 0xffff
            unk2 = up >> 22
            draw_ptr += address << 32
            pos_ptr += address << 32
            print(f"draw offset: {hex(address)}, unk {hex(unk1)}, {hex(unk2)}")

            print_draw(draw_ptr)
        elif dat[3] >> 6 == 1:
            # Primitive offset word: adds high-order parts to the running
            # indices; the following word is decoded with unsigned deltas.
            # TODO: handle two of these in a row
            res = ""
            if valhall:
                # TODO: Is the mask correct?
                pf = (up >> 22) & 0x7f
                shift = 7
                if dat[3] & 0x20:
                    a = 0
                    res = " reset"
            else:
                pf = (up >> 21) & 0x7f
                shift = 8

            a += tri0_7 << shift
            b += tri1 << 7
            c += tri2 << 7
            print(f"primitive offset{res}: {hex(pf << 4)} | +{tri0_7 << shift} {tri1 << 7} {tri2 << 7}")
            signed = False
        # TODO: Jumps are located based on position, not opcode
        elif dat[3] == 0xff:
            # 64-bit jump.  The low two bits are expected to be set; the
            # target is up64 - 3 (start is set 4 lower because of the
            # unconditional `start += 4` at the bottom of the loop).
            up64 = struct.unpack("=Q", fetch(start, 8))[0]
            assert((up64 & 3) == 3)
            print(f"jump (from {hex(start+8)}-8): {hex(up64 - 3)}")
            start = up64 - 7
        elif dat[3] == 0x00:
            # 32-bit jump, relative to HEAP_OFS; same -7/+4 adjustment as
            # the 64-bit form.
            assert((up & 3) == 3)
            print(f"jump (from {hex(start+4)}-4): {hex(up - 3)}, {hex(HEAP_OFS + up - 3)}")
            start = HEAP_OFS + up - 7
        elif (masked_op & 0xc0) == 0:
            # Primitive word: the first index is relative to `base`, the
            # other two are relative to the first.
            mode = hex(dat[3] >> 2)

            pre_offset = (up >> 22) & 0xf

            unk = ""
            if valhall and up & 1:
                unk = ", unk 1"

            a += base + tri0
            b += a + tri1
            c += a + tri2
            base = a

            print(f"{mode} draw: {hex(pre_offset)} | +{tri0} {tri1} {tri2}{unk}")

            # Only the first num_vert indices are meaningful for the current
            # primitive type.
            print_vertex(pos_ptr, [a, b, c][:num_vert])

            # Any pending primitive-offset contribution has been consumed.
            a = b = c = 0

        else:
            print(f"Unknown opcode {hex(dat[3])}")

        start += 4

def level_list():
    """Return the enabled tile sizes, smallest first.

    Bit i of hierarchy_mask enables the level whose tile size is 16 << i.
    Sizes are examined from 16 upwards; scanning continues until at least one
    enabled level has been found and half the current size is no smaller
    than the shorter framebuffer side.

    An all-zero mask enables nothing, so an empty list is returned instead of
    searching forever for a first level.
    """
    if hierarchy_mask == 0:
        return []

    levels = []
    size = 16
    shortest_side = min(fb_width, fb_height)

    # TODO: Does this miss the largest level?
    while not levels or size // 2 < shortest_side:
        if (hierarchy_mask << 4) & size:
            levels.append(size)

        size *= 2

    return levels

def div_round_up(x, y):
    """Return x divided by y, rounded up to a whole number (for positive y)."""
    biased = x + y - 1
    return biased // y

def align(x, y):
    """Round x up to the nearest multiple of y."""
    multiples = div_round_up(x, y)
    return multiples * y

def tile_count(alignment=4):
    """Return the total number of tile slots over all enabled levels.

    For each tile size from level_list() the framebuffer is covered by
    ceil(fb_width / size) * ceil(fb_height / size) tiles; that per-level
    count is rounded up to a multiple of `alignment` before summing.
    """
    return sum(
        align(div_round_up(fb_width, size) * div_round_up(fb_height, size),
              alignment)
        for size in level_list()
    )

if midgard:
    # Print the first 64 bytes of the dump as sixteen 32-bit integers; words
    # 5 and 6 are re-read as big-endian.
    unpacked_header = list(struct.unpack("=16i", data[0:64]))
    # Is this really big endian?
    unpacked_header[5:7] = struct.unpack(">2i", data[20:28])
    print(f"header: {' '.join([str(x) for x in unpacked_header])}")

    # Extra is because of HEAP_OFS
    header_size = align(tile_count() + 8, 64)
elif valhall:
    # Two arrays of 8-byte entries (one per tile slot) sit at the end of the
    # dump; HEAP_OFS is moved to the start of the first.
    # TODO: Does this figure need alignment?
    HEAP_STRIDE = tile_count() * 8
    HEAP_OFS = size - HEAP_STRIDE * 2

# Address of the 8-byte header of the tile currently being visited.
pos = base_ptr + HEAP_OFS

# NOTE(review): tile_count() pads each level's tile count to a multiple of
# four, but `pos` below advances by exactly one entry per tile with no
# padding between levels -- confirm which matches the real layout.
for tile_size in level_list():
    for y in range(div_round_up(fb_height, tile_size)):
        for x in range(div_round_up(fb_width, tile_size)):
            header = fetch(pos, 8)
            # A missing or truncated header means the dump ends here.  `pos`
            # is not advanced, so every later iteration stops the same way.
            if header is None or len(header) < 8:
                break

            if midgard:
                # The header holds only the end address; the start is
                # derived from the tile's index.
                end = struct.unpack("=Q", header)[0]
                use = bool(end)
                end += 4
                start = base_ptr + header_size * 8 + (pos - base_ptr - HEAP_OFS) * 64
            elif bifrost:
                # Two 32-bit values, both relative to HEAP_OFS.
                end, start = struct.unpack("=II", header)
                use = bool(end)
                start += HEAP_OFS
                end += HEAP_OFS + 4
                end &= ~3
            else:
                # The start comes from the first array and the end from the
                # matching entry HEAP_STRIDE bytes further on.
                footer = fetch(pos + HEAP_STRIDE, 8)
                if footer is None or len(footer) < 8:
                    break
                start, end = struct.unpack("=QQ", header + footer)
                use = bool(end)
                # The upper bits are used for jump metadata
                end &= (1 << 48) - 1
                end += 4

            # A zero end value marks a tile with nothing to decode.
            if use:
                left = x * tile_size
                right = (x + 1) * tile_size
                if FLIP_Y:
                    print([left, fb_height - (y + 1) * tile_size], (right, fb_height - y * tile_size))
                else:
                    print([left, y * tile_size], (right, (y + 1) * tile_size))
                heap_interpret(start, end)

            pos += 8
