Skip to content

Commit

Permalink
First steps in transforming JIT code. While loops.
Browse files Browse the repository at this point in the history
  • Loading branch information
Baekalfen committed Oct 6, 2024
1 parent b29ec55 commit e12bf55
Showing 1 changed file with 180 additions and 56 deletions.
236 changes: 180 additions & 56 deletions pyboy/core/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,48 @@ def patched_validate_type_visibility(self, type, pos, env):
"""

# Opcode tables consulted while collecting JIT blocks.
# NOTE: Entries can be regenerated from the opcode name table, e.g.:
# print('\n'.join([f"0x{i:02X}, # {x}" for i,x in enumerate(pyboy.core.opcodes.CPU_COMMANDS) if "JR" in x]))

# Relative jumps: JR r8, JR NZ,r8, JR Z,r8, JR NC,r8, JR C,r8
jr_instruction = [0x18, 0x20, 0x28, 0x30, 0x38]

# Absolute jumps: JP NZ,a16, JP a16, JP Z,a16, JP NC,a16, JP C,a16
jp_instruction = [0xC2, 0xC3, 0xCA, 0xD2, 0xDA]

# Instructions that always end a collected block.
boundary_instruction = [
    # RST 00H, 08H, 10H, 18H, 20H, 28H, 30H, 38H
    0xC7, 0xCF, 0xD7, 0xDF, 0xE7, 0xEF, 0xF7, 0xFF,
    # CALL NZ,a16, CALL Z,a16, CALL a16, CALL NC,a16, CALL C,a16
    0xC4, 0xCC, 0xCD, 0xD4, 0xDC,
    # RET NZ, RET Z, RET, RET NC, RET C, RETI
    0xC0, 0xC8, 0xC9, 0xD0, 0xD8, 0xD9,
    0xE9,  # JP (HL)
    0x76,  # HALT
    0x10,  # STOP
    0xFB,  # EI
    0xDB,  # Breakpoint/hook
]


def threaded_processor(jit):
while not jit.thread_stop:
Expand Down Expand Up @@ -166,39 +208,63 @@ def emit_code(self, code_block, func_name):
code_text = ""
if not cythonmode:
code_text += f"def {func_name}(cpu, cycles_target):\n\t"
code_text += "flag = 0\n\tt = 0\n\ttr = 0\n\tv = 0"
code_text += "flag = 0\n\tt = 0\n\tv = 0\n\t_cycles0 = cpu.cycles\n\t_target = _cycles0 + cycles_target"
else:
code_text += f"cdef public void {func_name}(_cpu.CPU cpu, int64_t cycles_target) noexcept nogil:"
code_text += "\n\tcdef uint8_t flag\n\tcdef int t\n\tcdef int v"
code_text += "\n\tcdef uint8_t flag\n\tcdef int t\n\tcdef int v\n\tcdef int64_t _cycles0 = cpu.cycles\n\tcdef int64_t _target = _cycles0 + cycles_target"
code_text += """
\tcdef uint16_t FLAGC = 4
\tcdef uint16_t FLAGH = 5
\tcdef uint16_t FLAGN = 6
\tcdef uint16_t FLAGZ = 7"""

for i, (opcode, length, pc, literal1, literal2) in enumerate(code_block):
def emit_opcode(indent, opcode, length, pc, literal1, literal2):
opcode_handler = opcodes_gen[opcode]
opcode_name = opcode_handler.name.split()[0]
code_text += "\n\t\n\t" + "# " + opcode_handler.name + f" (PC: 0x{pc:04x})\n\t"
preamble = f"\n\t\n\t" + "# " + opcode_handler.name + f" (PC: 0x{pc:04x})\n\t"
if length == 2:
v = literal1
code_text += f"v = 0x{v:02x} # {v}\n\t"
preamble += f"v = 0x{v:02x} # {v}\n\t"
elif length == 3:
v = (literal2 << 8) + literal1
code_text += f"v = 0x{v:04x} # {v}\n\t"
preamble += f"v = 0x{v:04x} # {v}\n\t"

tmp_code = opcode_handler.functionhandlers[opcode_name]()._code_body()
if "if" in tmp_code:
# Return early on jump
tmp_code = tmp_code.replace("else:", "\treturn\n\telse:")
tmp_code = tmp_code.replace("else:", f"\treturn\n\telse:")
elif "cpu.mb.setitem" in tmp_code:
# Return early on state-altering writes
tmp_code += "\n\tif cpu.bail: return"
code_text += tmp_code
tmp_code += f"\n\tif cpu.bail: return"
return (preamble + tmp_code).replace("\t", indent)

for i, (opcode, length, pc, literal1, literal2) in enumerate(code_block):
if opcode < 0x200: # Regular opcode
code_text += emit_opcode("\t", opcode, length, pc, literal1, literal2)
elif opcode == 0x200: # Loop body
loop_body_cycles, jump_to, jump_from, _block = length, pc, literal1, literal2
# breakpoint()
code_text += f"\n\n\twhile True: # Loop body (PC: 0x{jump_to:04X} to 0x{jump_from:04X})"
for i, (opcode, length, pc, literal1, literal2) in enumerate(_block[:-1]):
code_text += emit_opcode("\t\t", opcode, length, pc, literal1, literal2)

# Loop condition
opcode, length, pc, literal1, literal2 = _block[-1]
loop_condition = emit_opcode("\t\t", opcode, length, pc, literal1, literal2)
loop_condition = loop_condition.replace(
"return",
f'if cpu.cycles + {loop_body_cycles} < _target:\n\t\t\t\t\tcpu.jit_jump=False;continue\n\t\t\t\telse:\n\t\t\t\t\tcpu.jit_jump=False;return'
)
loop_condition += "\n\t\tbreak"
code_text += loop_condition
elif opcode == 0x201: # Remainder of block
remainder_cycles = length
code_text += f'\n\tif cpu.cycles + {remainder_cycles} < _target: return'

code_text += "\n\treturn\n\n"
# opcodes[7].functionhandlers[opcodes[7].name.split()[0]]().branch_op
# if .getitem in code, commit timer.tick(cycles); cycles = 0

return code_text

def getitem_bank(self, bank, i):
Expand All @@ -208,70 +274,127 @@ def getitem_bank(self, bank, i):
return self.cartridge.rombanks[bank, i - 0x4000]

def collect_block(self, block_id, cycles_target):
    """Collect a straight-line run of instructions starting at ``block_id``.

    ``block_id`` packs the start address in its upper bits (``block_id >> 8``)
    and the ROM bank in its low byte (``block_id & 0xFF``).

    Instructions are appended until one of the following happens:

    * adding the next instruction would exceed ``cycles_target`` (that
      instruction is not included),
    * an instruction from the module-level ``boundary_instruction`` table is
      reached (it is included),
    * a JR/JP is reached that is not the first jump back into this block
      (it is included).

    The first JR/JP whose target lies inside the already collected range is
    recorded as an internal jump, and collection continues past it.

    Returns:
        A tuple ``(code_block, block_max_cycles, has_internal_jump)`` where
        ``code_block`` is a list of ``(opcode, length, pc, literal1, literal2)``
        tuples, ``block_max_cycles`` is the sum of the per-opcode maximum cycle
        counts, and ``has_internal_jump`` tells whether a jump back into the
        block was seen.
    """
    code_block = []
    block_start = block_id >> 8
    assert block_start < 0x8000
    rom_bank = block_id & 0xFF
    pc = block_start

    has_internal_jump = False
    block_max_cycles = 0
    while True:
        opcode = self.getitem_bank(rom_bank, pc)
        if opcode == 0xCB:  # Extension code
            pc += 1
            opcode = self.getitem_bank(rom_bank, pc)
            opcode += 0x100  # Internally shifting look-up table
        opcode_length = opcodes.OPCODE_LENGTHS[opcode]
        opcode_max_cycles = opcodes.OPCODE_MAX_CYCLES[opcode]
        if block_max_cycles + opcode_max_cycles > cycles_target:
            break
        block_max_cycles += opcode_max_cycles

        # The two bytes following the opcode are always recorded; whether they
        # are meaningful depends on the opcode length.
        l1 = self.getitem_bank(rom_bank, pc + 1)
        l2 = self.getitem_bank(rom_bank, pc + 2)
        code_block.append((opcode, opcode_length, pc, l1, l2))
        pc += opcode_length  # From here on, pc is the address of the next instruction

        if opcode in boundary_instruction:
            break

        is_jr = opcode in jr_instruction
        is_jp = opcode in jp_instruction
        if not (is_jr or is_jp):
            continue

        if has_internal_jump:
            # An internal jump was already recorded; this one is treated as the
            # jump away that ends the block.
            # TODO: Just one loop?
            break

        if is_jr:
            # Sign-extend the 8-bit offset; it is relative to the next instruction.
            jump_to = pc + ((l1 ^ 0x80) - 0x80)
        else:
            jump_to = (l2 << 8) | l1

        if block_start <= jump_to < pc:  # Detect internal jump
            has_internal_jump = True
        else:
            # The jump is to somewhere else
            break

    return code_block, block_max_cycles, has_internal_jump

def print_block(self, code_block):
    """Print a human-readable listing of ``code_block`` for debugging.

    Handles both raw blocks and blocks rewritten by ``optimize_block``:

    * regular entries are ``(opcode, length, pc, literal1, literal2)``,
    * ``0x200`` entries are ``(0x200, loop_body_cycles, jump_to, jump_from, body)``
      and have their body listed one tab deeper,
    * ``0x201`` entries are ``(0x201, remainder_cycles, None, None, None)``.
    """
    def describe(entry, indent=""):
        opcode, length, pc, l1, l2 = entry
        if opcode == 0x200:
            # The fourth and fifth slots hold the jump address and the nested
            # body, not literals, so they cannot go through the hex formatting.
            lines = [f"{indent}0x{opcode:02X} loop block\tcycles: {length}\tPC: {pc:04X} to {l1:04X}"]
            for inner in l2:
                lines.extend(describe(inner, indent + "\t"))
            return lines
        if opcode == 0x201:
            # Every field except the cycle count is None for this entry.
            return [f"{indent}0x{opcode:02X} remainder\tcycles: {length}"]

        # A relative jump is taken from the address of the *next* instruction.
        r8_target = pc + length + ((l1 ^ 0x80) - 0x80)
        return [
            f"{indent}0x{opcode:02X} {pyboy.core.opcodes.CPU_COMMANDS[opcode]}\tlen: {length}\tPC: {pc:04X}"
            f"\tlit1: {l1:02X}\tlit2: {l2:02X}\tlit: {(l2<<8) | l1:04X}\t r8: {r8_target:04X}"
        ]

    lines = []
    for entry in code_block:
        lines.extend(describe(entry))
    print("\n".join(lines))

return code_block, block_max_cycles
def check_no_overlap(self, ranges):
    """Return True when no two ``(start, end)`` ranges overlap.

    ``ranges`` is sorted in place by start value; callers observe the
    reordered list afterwards. Two ranges where one starts exactly at the
    other's end are not reported as overlapping.

    NOTE(review): callers pass ``(jump_to, jump_from)`` pairs where the end is
    itself part of the range -- confirm that a shared end/start address is
    really meant to be accepted.
    """
    # Sort by starting point so only neighbouring ranges need comparing.
    ranges.sort(key=lambda x: x[0])

    # An overlap exists as soon as a range starts before its predecessor ends.
    return all(current[0] >= previous[1] for previous, current in zip(ranges, ranges[1:]))

def optimize_block(self, raw_code_block, raw_block_max_cycles, has_internal_jump):
    """Rewrite backward jumps inside a collected block into loop pseudo-opcodes.

    Args:
        raw_code_block: list of ``(opcode, length, pc, literal1, literal2)``
            tuples as produced by ``collect_block``.
        raw_block_max_cycles: accepted for interface compatibility; not used.
        has_internal_jump: whether ``collect_block`` saw a jump back into the
            block. When false, the block is returned unchanged.

    Returns:
        ``raw_code_block`` itself when there is nothing to rewrite (no internal
        jump, no usable loop, or overlapping loops). Otherwise a new list where
        each loop is replaced by
        ``(0x200, loop_body_cycles, jump_to, jump_from, body)`` followed by
        ``(0x201, remainder_cycles, None, None, None)``.
    """
    if not has_internal_jump or not raw_code_block:
        return raw_code_block

    first_opcode, _, first_pc, _, _ = raw_code_block[0]
    # collect_block stores CB-prefixed opcodes as opcode + 0x100 and records the
    # address of their second byte, so such a block starts one byte earlier.
    block_start = first_pc - 1 if first_opcode >= 0x100 else first_pc

    loops = []
    for opcode, opcode_length, pc, l1, l2 in raw_code_block:
        if opcode in jr_instruction:
            # Relative to the address of the following instruction, matching
            # how collect_block computes the target.
            jump_to = pc + opcode_length + ((l1 ^ 0x80) - 0x80)
        elif opcode in jp_instruction:
            jump_to = (l2 << 8) | l1
        else:
            continue

        # Only a jump back into this block forms a loop; a jump that leaves the
        # block (typically the last instruction) stays a regular opcode.
        if block_start <= jump_to <= pc:
            loops.append((jump_to, pc))  # Sorted as (start, end)

    if not loops or not self.check_no_overlap(loops):
        return raw_code_block

    # Walk loops in address order, the same order the instructions are visited.
    loops.sort()
    pending = iter(loops)
    current_jump = next(pending, None)

    new_block = []
    loop_body = []
    for i, entry in enumerate(raw_code_block):
        pc = entry[2]
        if current_jump is not None and current_jump[0] <= pc < current_jump[1]:
            # Collect body
            loop_body.append(entry)
        elif current_jump is not None and pc == current_jump[1]:
            # The jump itself closes the loop and is kept as its last entry.
            loop_body.append(entry)
            loop_body_cycles = sum(opcodes.OPCODE_MAX_CYCLES[op] for op, *_ in loop_body)
            new_block.append((0x200, loop_body_cycles, current_jump[0], current_jump[1], loop_body))
            loop_body = []
            current_jump = next(pending, None)

            # NOTE(review): the slice starts at the closing jump, so its cycles
            # are counted in the remainder as well -- confirm this is intended.
            remainder_cycles = sum(opcodes.OPCODE_MAX_CYCLES[op] for op, *_ in raw_code_block[i:])
            new_block.append((0x201, remainder_cycles, None, None, None))
        else:
            # Add regular opcode
            new_block.append(entry)
    return new_block

def invalidate(self, bank, address):
# Invalidate any JIT block that crosses this bank and address.
Expand Down Expand Up @@ -328,7 +451,8 @@ def process(self):

# logger.critical("analyze: %x, %d, %d", block_id, cycles_target, interrupt_master_enable)

code_block, block_max_cycles = self.collect_block(block_id, cycles_target)
raw_code_block, block_max_cycles, has_internal_jump = self.collect_block(block_id, cycles_target)
code_block = self.optimize_block(raw_code_block, block_max_cycles, has_internal_jump)

if block_max_cycles < 100:
self.cycles[block_id] = -1 # Don't retry
Expand Down

0 comments on commit e12bf55

Please sign in to comment.