Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

smt: add support for write-first memories #1948

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
}
// turn memories into state
val memoryEncoding = new MemoryEncoding(makeRandom)
val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace)
val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name)))
// replace pseudo assigns for memory outputs
val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap
Expand Down Expand Up @@ -248,7 +248,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
}
}

private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging {
type Connects = Iterable[(String, BVExpr)]
def onMemory(
defMem: ir.DefMemory,
Expand Down Expand Up @@ -303,37 +303,46 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
readers.map { r =>
// combinatorial read
if (defMem.readUnderWrite != ir.ReadUnderWrite.New) {
//logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
// s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.")
logger.warn(
s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored."
)
}
// since we do a combinatorial read, the "old" data is the current data
val data = r.readOld()
val data = r.read()
r.data.name -> data
}
} else { Seq() }
val readPortStates = if (defMem.readLatency == 1) {
val readPortSignalsAndStates = if (defMem.readLatency == 1) {
readers.map { r =>
// we create a register for the read port data
val next = defMem.readUnderWrite match {
defMem.readUnderWrite match {
case ir.ReadUnderWrite.New =>
throw new UnsupportedFeatureException(
s"registered read ports that return the new value (${m.name}.${r.name})"
)
// the thing that makes this hard is to properly handle write conflicts
// create a state to save the address and the enable signal
val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width)
val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width)
val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev)
val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr)))
(Seq(signal), states)
case ir.ReadUnderWrite.Undefined =>
// check for potential read/write conflicts in which case we need to return an arbitrary value
val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
if (anyWriteToTheSameAddress == False) { r.readOld() }
val next = if (anyWriteToTheSameAddress == False) { r.read() }
else {
val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld())
BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read())
}
case ir.ReadUnderWrite.Old => r.readOld()
(Seq(), Seq(State(r.data, init = None, next = Some(next))))
case ir.ReadUnderWrite.Old =>
// we create a register for the read port data
(Seq(), Seq(State(r.data, init = None, next = Some(r.read()))))
}
State(r.data, init = None, next = Some(next))
}
} else { Seq() }

(state +: readPortStates, readPortSignals)
val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1)
val readPortStates = readPortSignalsAndStates.flatMap(_._2)

(state +: readPortStates, allReadPortSignals)
}

private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
Expand Down Expand Up @@ -385,7 +394,7 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val enIsTrue: Boolean = inputs(en.name) == True
def makeRandomData(suffix: String): BVExpr =
makeRandom(memory.name + "_" + name + suffix, memory.dataWidth)
def readOld(): BVExpr = {
def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = {
val canBeOutOfRange = !memory.fullAddressRange
val canBeDisabled = !enIsTrue
val data = ArrayRead(memory.sym, addr)
Expand Down Expand Up @@ -467,7 +476,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
// keeps track of unused memory (data) outputs so that we can see where they are first used
private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]()
// ensure unique names for assert/assume signals
private val namespace = Namespace()
private[firrtl] val namespace = Namespace()

private[firrtl] def onPort(p: ir.Port): Unit = {
if (isAsyncReset(p.tpe)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class MemorySpec extends EndToEndSMTBaseSpec {
| ${cmds.mkString("\n ")}
|""".stripMargin

"Registered test memory" should "return written data after two cycles" taggedAs (RequiresZ3) in {
"Registered read-first memory" should "return written data after two cycles" taggedAs (RequiresZ3) in {
val cmds =
"""node past_past_valid = geq(cycle, UInt(2))
|reg past_in: UInt<8>, clock
Expand All @@ -61,6 +61,30 @@ class MemorySpec extends EndToEndSMTBaseSpec {
test(registeredTestMem("Mem00", cmds, "old"), MCSuccess, kmax = 3)
}

"Registered read-first memory" should "not return written data after one cycle" taggedAs (RequiresZ3) in {
val cmds =
"""
|reg past_in: UInt<8>, clock
|past_in <= in
|
|assume(clock, eq(read_addr, write_addr), UInt(1), "read_addr = write_addr")
|assert(clock, eq(out, past_in), past_valid, "out = past(in)")
|""".stripMargin
test(registeredTestMem("Mem00", cmds, "old"), MCFail(1), kmax = 3)
}

"Registered write-first memory" should "return written data after one cycle" taggedAs (RequiresZ3) in {
val cmds =
"""
|reg past_in: UInt<8>, clock
|past_in <= in
|
|assume(clock, eq(read_addr, write_addr), UInt(1), "read_addr = write_addr")
|assert(clock, eq(out, past_in), past_valid, "out = past(in)")
|""".stripMargin
test(registeredTestMem("Mem00", cmds, "new"), MCSuccess, kmax = 3)
}

Comment on lines +64 to +87
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love formal tests 🙂

private def readOnlyMem(pred: String, num: Int) =
s"""circuit Mem0$num:
| module Mem0$num:
Expand Down