diff --git a/pyboy/core/jit.py b/pyboy/core/jit.py index 9b0816216..fd5a60602 100644 --- a/pyboy/core/jit.py +++ b/pyboy/core/jit.py @@ -51,6 +51,48 @@ def patched_validate_type_visibility(self, type, pos, env): """ +# NOTE: print('\n'.join([f"0x{i:02X}, # {x}" for i,x in enumerate(pyboy.core.opcodes.CPU_COMMANDS) if "JR" in x])) +jr_instruction = [ + 0x18, # JR r8 + 0x20, # JR NZ,r8 + 0x28, # JR Z,r8 + 0x30, # JR NC,r8 + 0x38, # JR C,r8 +] +jp_instruction = [ + 0xC2, # JP NZ,a16 + 0xC3, # JP a16 + 0xCA, # JP Z,a16 + 0xD2, # JP NC,a16 + 0xDA, # JP C,a16 +] +boundary_instruction = [ + 0xC7, # RST 00H + 0xCF, # RST 08H + 0xD7, # RST 10H + 0xDF, # RST 18H + 0xE7, # RST 20H + 0xEF, # RST 28H + 0xF7, # RST 30H + 0xFF, # RST 38H + 0xC4, # CALL NZ,a16 + 0xCC, # CALL Z,a16 + 0xCD, # CALL a16 + 0xD4, # CALL NC,a16 + 0xDC, # CALL C,a16 + 0xC0, # RET NZ + 0xC8, # RET Z + 0xC9, # RET + 0xD0, # RET NC + 0xD8, # RET C + 0xD9, # RETI + 0xE9, # JP (HL) + 0x76, # HALT + 0x10, # STOP + 0xFB, # EI + 0xDB, # Breakpoint/hook +] + def threaded_processor(jit): while not jit.thread_stop: @@ -166,39 +208,63 @@ def emit_code(self, code_block, func_name): code_text = "" if not cythonmode: code_text += f"def {func_name}(cpu, cycles_target):\n\t" - code_text += "flag = 0\n\tt = 0\n\ttr = 0\n\tv = 0" + code_text += "flag = 0\n\tt = 0\n\tv = 0\n\t_cycles0 = cpu.cycles\n\t_target = _cycles0 + cycles_target" else: code_text += f"cdef public void {func_name}(_cpu.CPU cpu, int64_t cycles_target) noexcept nogil:" - code_text += "\n\tcdef uint8_t flag\n\tcdef int t\n\tcdef int v" + code_text += "\n\tcdef uint8_t flag\n\tcdef int t\n\tcdef int v\n\tcdef int64_t _cycles0 = cpu.cycles\n\tcdef int64_t _target = _cycles0 + cycles_target" code_text += """ \tcdef uint16_t FLAGC = 4 \tcdef uint16_t FLAGH = 5 \tcdef uint16_t FLAGN = 6 \tcdef uint16_t FLAGZ = 7""" - for i, (opcode, length, pc, literal1, literal2) in enumerate(code_block): + def emit_opcode(indent, opcode, length, pc, literal1, literal2): opcode_handler = opcodes_gen[opcode] opcode_name = opcode_handler.name.split()[0] - code_text += "\n\t\n\t" + "# " + opcode_handler.name + f" (PC: 0x{pc:04x})\n\t" + preamble = f"\n\t\n\t" + "# " + opcode_handler.name + f" (PC: 0x{pc:04x})\n\t" if length == 2: v = literal1 - code_text += f"v = 0x{v:02x} # {v}\n\t" + preamble += f"v = 0x{v:02x} # {v}\n\t" elif length == 3: v = (literal2 << 8) + literal1 - code_text += f"v = 0x{v:04x} # {v}\n\t" + preamble += f"v = 0x{v:04x} # {v}\n\t" tmp_code = opcode_handler.functionhandlers[opcode_name]()._code_body() if "if" in tmp_code: # Return early on jump - tmp_code = tmp_code.replace("else:", "\treturn\n\telse:") + tmp_code = tmp_code.replace("else:", f"\treturn\n\telse:") elif "cpu.mb.setitem" in tmp_code: # Return early on state-altering writes - tmp_code += "\n\tif cpu.bail: return" - code_text += tmp_code + tmp_code += f"\n\tif cpu.bail: return" + return (preamble + tmp_code).replace("\t", indent) + + for i, (opcode, length, pc, literal1, literal2) in enumerate(code_block): + if opcode < 0x200: # Regular opcode + code_text += emit_opcode("\t", opcode, length, pc, literal1, literal2) + elif opcode == 0x200: # Loop body + loop_body_cycles, jump_to, jump_from, _block = length, pc, literal1, literal2 + # breakpoint() + code_text += f"\n\n\twhile True: # Loop body (PC: 0x{jump_to:04X} to 0x{jump_from:04X})" + for i, (opcode, length, pc, literal1, literal2) in enumerate(_block[:-1]): + code_text += emit_opcode("\t\t", opcode, length, pc, literal1, literal2) + + # Loop condition + opcode, length, pc, literal1, literal2 = _block[-1] + loop_condition = emit_opcode("\t\t", opcode, length, pc, literal1, literal2) + loop_condition = loop_condition.replace( + "return", + f'if cpu.cycles + {loop_body_cycles} < _target:\n\t\t\t\t\tcpu.jit_jump=False;continue\n\t\t\t\telse:\n\t\t\t\t\tcpu.jit_jump=False;return' + ) + loop_condition += "\n\t\tbreak" + code_text += loop_condition + elif opcode == 0x201: # Remainder of block + remainder_cycles = length + code_text += f'\n\tif cpu.cycles + {remainder_cycles} < _target: return' code_text += "\n\treturn\n\n" # opcodes[7].functionhandlers[opcodes[7].name.split()[0]]().branch_op # if .getitem in code, commit timer.tick(cycles); cycles = 0 + return code_text def getitem_bank(self, bank, i): @@ -208,55 +274,21 @@ def getitem_bank(self, bank, i): return self.cartridge.rombanks[bank, i - 0x4000] def collect_block(self, block_id, cycles_target): - boundary_instruction = [ - 0xC7, # RST 00H - 0xCF, # RST 08H - 0xD7, # RST 10H - 0xDF, # RST 18H - 0xE7, # RST 20H - 0xEF, # RST 28H - 0xF7, # RST 30H - 0xFF, # RST 38H - 0xC4, # CALL NZ,a16 - 0xCC, # CALL Z,a16 - 0xCD, # CALL a16 - 0xD4, # CALL NC,a16 - 0xDC, # CALL C,a16 - 0xC0, # RET NZ - 0xC8, # RET Z - 0xC9, # RET - 0xD0, # RET NC - 0xD8, # RET C - 0xD9, # RETI - 0x18, # JR r8 - 0x20, # JR NZ,r8 - 0x28, # JR Z,r8 - 0x30, # JR NC,r8 - 0x38, # JR C,r8 - 0xC2, # JP NZ,a16 - 0xC3, # JP a16 - 0xCA, # JP Z,a16 - 0xD2, # JP NC,a16 - 0xDA, # JP C,a16 - 0xE9, # JP (HL) - 0x76, # HALT - 0x10, # STOP - 0xFB, # EI - 0xDB, # Breakpoint/hook - ] code_block = [] - pc = block_id >> 8 - assert pc < 0x8000 + PC = block_id >> 8 + _PC = PC + assert PC < 0x8000 rom_bank = block_id & 0xFF + has_internal_jump = False block_max_cycles = 0 while True: # for _ in range(200): # while block_max_cycles < 200: - opcode = self.getitem_bank(rom_bank, pc) + opcode = self.getitem_bank(rom_bank, PC) if opcode == 0xCB: # Extension code - pc += 1 - opcode = self.getitem_bank(rom_bank, pc) + PC += 1 + opcode = self.getitem_bank(rom_bank, PC) opcode += 0x100 # Internally shifting look-up table opcode_length = opcodes.OPCODE_LENGTHS[opcode] opcode_max_cycles = opcodes.OPCODE_MAX_CYCLES[opcode] @@ -264,14 +296,105 @@ def collect_block(self, block_id, cycles_target): if (block_max_cycles + opcode_max_cycles > cycles_target): break block_max_cycles += opcode_max_cycles - code_block.append( - (opcode, opcode_length, pc, self.getitem_bank(rom_bank, pc + 1), self.getitem_bank(rom_bank, pc + 2)) - ) - pc += opcode_length + l1, l2 = self.getitem_bank(rom_bank, PC + 1), self.getitem_bank(rom_bank, PC + 2) + code_block.append((opcode, opcode_length, PC, l1, l2)) + PC += opcode_length + + is_jr = opcode in jr_instruction + is_jp = opcode in jp_instruction + if opcode in boundary_instruction: break + elif is_jr or is_jp: + # We assume it's the ending instruction? Or is the validation at the top? + if not has_internal_jump: + if is_jr: + jump_to = PC + ((l1 ^ 0x80) - 0x80) + else: + jump_to = ((l2 << 8) | l1) + + if _PC <= jump_to < PC: # Detect internal jump + has_internal_jump = True + else: + # The jump is to somewhere else + break + else: + # Expected jump away + # TODO: Just one loop? + break + + return code_block, block_max_cycles, has_internal_jump + + def print_block(self, code_block): + def opcode_translate(opcode): + if opcode == 0x200: + return "loop block" + else: + return pyboy.core.opcodes.CPU_COMMANDS[opcode] + + print( + "\n".join( + f"0x{opcode:02X} {opcode_translate(opcode)}\tlen: {opcode_length}\tPC: {pc:04X}\tlit1: {l1:02X}\tlit2: {l2:02X}\tlit: {(l2<<8) | l1:04X}\t r8: {pc + ((l1 ^ 0x80) - 0x80):04X}" + for opcode, opcode_length, pc, l1, l2 in code_block + ) + ) - return code_block, block_max_cycles + def check_no_overlap(self, ranges): + if len(ranges) == 1: + return True + + # Sort the ranges by the starting points + ranges.sort(key=lambda x: x[0]) + + # Traverse through the ranges to check for overlap + for i in range(1, len(ranges)): + # If the start of the current range is less than the end of the previous range, there's an overlap + if ranges[i][0] < ranges[i - 1][1]: + return False + + return True + + def optimize_block(self, raw_code_block, raw_block_max_cycles, has_internal_jump): + if not has_internal_jump: + return raw_code_block + + # _, _, PC, _, _ = raw_code_block[0] + jumps = [] + for opcode, opcode_length, PC, l1, l2 in raw_code_block: + is_jr = opcode in jr_instruction + is_jp = opcode in jp_instruction + if is_jp or is_jr: + if is_jr: + jump_to = PC + ((l1 ^ 0x80) - 0x80) + elif is_jp: + jump_to = ((l2 << 8) | l1) + + jumps.append((jump_to, PC)) # Sorted as (start, end) + + if not self.check_no_overlap(jumps): + return raw_code_block + + new_block = [] + _block = [] + current_jump = jumps.pop() + for i, (opcode, opcode_length, pc, l1, l2) in enumerate(raw_code_block): + if current_jump and current_jump[0] <= pc < current_jump[1]: + # Collect body + _block.append((opcode, opcode_length, pc, l1, l2)) + elif current_jump and pc == current_jump[1]: + # Add loop block + _block.append((opcode, opcode_length, pc, l1, l2)) + loop_body_cycles = sum(opcodes.OPCODE_MAX_CYCLES[opcode] for opcode, _, _, _, _, in _block) + new_block.append((0x200, loop_body_cycles, current_jump[0], current_jump[1], _block)) + _block = [] + current_jump = jumps.pop() if jumps else None + + remainder_cycles = sum(opcodes.OPCODE_MAX_CYCLES[opcode] for opcode, _, _, _, _, in raw_code_block[i:]) + new_block.append((0x201, remainder_cycles, None, None, None)) + else: + # Add regular opcode + new_block.append((opcode, opcode_length, pc, l1, l2)) + return new_block def invalidate(self, bank, address): # Invalidate any JIT block that crosses this bank and adress. @@ -328,7 +451,8 @@ def process(self): # logger.critical("analyze: %x, %d, %d", block_id, cycles_target, interrupt_master_enable) - code_block, block_max_cycles = self.collect_block(block_id, cycles_target) + raw_code_block, block_max_cycles, has_internal_jump = self.collect_block(block_id, cycles_target) + code_block = self.optimize_block(raw_code_block, block_max_cycles, has_internal_jump) if block_max_cycles < 100: self.cycles[block_id] = -1 # Don't retry