Skip to content

Commit

Permalink
fix a few minor bugs in mem_writer
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle-Kyle committed Jan 22, 2025
1 parent 01c8571 commit efb95b8
Showing 1 changed file with 42 additions and 55 deletions.
97 changes: 42 additions & 55 deletions angrop/chain_builder/mem_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ class MemWriter(Builder):
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._mem_write_gadgets = None
self._good_mem_write_gadgets = None

def update(self):
self._mem_write_gadgets = self._get_all_mem_write_gadgets(self.chain_builder.gadgets)
self._good_mem_write_gadgets = set()

def _set_regs(self, *args, **kwargs):
return self.chain_builder._reg_setter.run(*args, **kwargs)
Expand All @@ -45,62 +47,45 @@ def _gen_mem_write_gadgets(self, string_data):
# create a dict of bytes per write to gadgets
# assume we need intersection of addr_dependencies and data_dependencies to be 0
# TODO could allow mem_reads as long as we control the address?
possible_gadgets = self._mem_write_gadgets

while possible_gadgets:
# get the data from trying to set all the registers
registers = dict((reg, 0x41) for reg in self.arch.reg_set)
l.debug("getting reg data for mem writes")
reg_setter = self.chain_builder._reg_setter
_, _, reg_data = reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, **registers)
l.debug("trying mem_write gadgets")
# generate from the cache first
if self._good_mem_write_gadgets:
for g in self._good_mem_write_gadgets:
yield g

possible_gadgets = self._mem_write_gadgets.copy() - self._good_mem_write_gadgets

# use the graph-search to gain a rough idea about (stack_change, register setting)
registers = dict((reg, 0x41) for reg in self.arch.reg_set)
l.debug("getting reg data for mem writes")
reg_setter = self.chain_builder._reg_setter
_, _, reg_data = reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, **registers)
l.debug("trying mem_write gadgets")

# find a write gadget that induces the smallest stack_change
while possible_gadgets:
# limit the maximum size of the chain
best_stack_change = 0x400
best_gadget = None
use_partial_controllers = False
for t, vals in reg_data.items():
if vals[1] >= best_stack_change:
# regs: according to the graph search, what registers can be controlled
# vals[1]: stack_change to set those registers
for regs, vals in reg_data.items():
reg_set_stack_change = vals[1]
if reg_set_stack_change >= best_stack_change:
continue
for g in possible_gadgets:
mem_write = g.mem_writes[0]
if (set(mem_write.addr_dependencies) | set(mem_write.data_dependencies)).issubset(set(t)):
stack_change = g.stack_change + vals[1]
bytes_per_write = mem_write.data_size // 8
num_writes = (len(string_data) + bytes_per_write - 1)//bytes_per_write
stack_change *= num_writes
if stack_change < best_stack_change:
best_gadget = g
best_stack_change = stack_change

# try again using partial_controllers
best_stack_change = 0x400
if best_gadget is None:
use_partial_controllers = True
l.warning("Trying to use partial controllers for memory write")
l.debug("getting reg data for mem writes")
_, _, reg_data = self.chain_builder._reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50,
use_partial_controllers=True,
**registers)
l.debug("trying mem_write gadgets")
for t, vals in reg_data.items():
if vals[1] >= best_stack_change:
if not (mem_write.addr_dependencies | mem_write.data_dependencies).issubset(regs):
continue
for g in possible_gadgets:
mem_write = g.mem_writes[0]
# we need the addr to not be partially controlled
if (set(mem_write.addr_dependencies) | set(mem_write.data_dependencies)).issubset(set(t)) and \
len(set(mem_write.addr_dependencies) & vals[3]) == 0:
stack_change = g.stack_change + vals[1]
# only one byte at a time
bytes_per_write = 1
num_writes = (len(string_data) + bytes_per_write - 1)//bytes_per_write
stack_change *= num_writes
if stack_change < best_stack_change:
best_gadget = g
best_stack_change = stack_change

yield best_gadget, use_partial_controllers
stack_change = g.stack_change + reg_set_stack_change
bytes_per_write = mem_write.data_size // 8
num_writes = (len(string_data) + bytes_per_write - 1)//bytes_per_write
stack_change *= num_writes
if stack_change < best_stack_change:
best_gadget = g
best_stack_change = stack_change

yield best_gadget
possible_gadgets.remove(best_gadget)

@rop_utils.timeout(5)
Expand Down Expand Up @@ -132,15 +117,13 @@ def _write_to_mem(self, addr, string_data, fill_byte=b"\xff"):# pylint:disable=i
:param fill_byte: a byte to use to fill up the string if necessary
:return: a rop chain
"""

gen = self._gen_mem_write_gadgets(string_data)
gadget, use_partial_controllers = next(gen, (None, None))
while gadget:
for gadget in self._gen_mem_write_gadgets(string_data):
try:
return self._try_write_to_mem(gadget, use_partial_controllers, addr, string_data, fill_byte)
chain = self._try_write_to_mem(gadget, False, addr, string_data, fill_byte)
self._good_mem_write_gadgets.add(gadget)
return chain
except (RopException, angr.errors.SimEngineError, angr.errors.SimUnsatError):
pass
gadget, use_partial_controllers = next(gen, (None, None))

raise RopException("Fail to write data to memory :(")

Expand All @@ -165,17 +148,20 @@ def write_to_mem(self, addr, data, fill_byte=b"\xff"):
if x not in self.badbytes:
e += bytes([x])
else:
elems.append(e)
if e:
elems.append(e)
elems.append(bytes([x]))
e = b''
if e:
elems.append(e)

# do the write
offset = 0
chain = RopChain(self.project, self, badbytes=self.badbytes)
for elem in elems:
ptr = addr + offset
if self._word_contain_badbyte(ptr):
raise RopException(f"{ptr:#x} contains bad byte!")
raise RopException(f"{ptr} contains bad byte!")
if len(elem) != 1 or ord(elem) not in self.badbytes:
chain += self._write_to_mem(ptr, elem, fill_byte=fill_byte)
offset += len(elem)
Expand Down Expand Up @@ -262,6 +248,7 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll
sim_data = state.memory.load(addr_val.data, len(data))
if not state.solver.eval(sim_data == data):
raise RopException("memory write fails")

# the next pc must come from the stack
if len(state.regs.pc.variables) != 1:
raise RopException("must have only one pc variable")
Expand Down

0 comments on commit efb95b8

Please sign in to comment.