Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Multi protobuf module emission and consumption (#2344)
Browse files Browse the repository at this point in the history
* Add compiler option (`-p`) to emit individual module protobufs
* Implement multi module combination when reading directory of protobufs

Co-authored-by: Jack Koenig <[email protected]>
(cherry picked from commit 0c1ca58)
  • Loading branch information
jared-barocsi authored and mergify-bot committed Sep 8, 2021
1 parent fe5a56e commit de96afe
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 36 deletions.
39 changes: 38 additions & 1 deletion src/main/scala/firrtl/Emitter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object EmitCircuitAnnotation extends HasShellOptions {
case "low-opt" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.OptLow),
EmitCircuitAnnotation(classOf[ProtoEmitter.Low])
EmitCircuitAnnotation(classOf[ProtoEmitter.OptLow])
)
case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)")
},
Expand Down Expand Up @@ -147,6 +147,43 @@ object EmitAllModulesAnnotation extends HasShellOptions {
helpText = "Run the specified module emitter (one file per module)",
shortOption = Some("e"),
helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>")
),
new ShellOption[String](
longOption = "emit-modules-protobuf",
toAnnotationSeq = (a: String) =>
a match {
case "chirrtl" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.Chirrtl),
EmitAllModulesAnnotation(classOf[ProtoEmitter.Chirrtl])
)
case "mhigh" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.MHigh),
EmitAllModulesAnnotation(classOf[ProtoEmitter.MHigh])
)
case "high" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.High),
EmitAllModulesAnnotation(classOf[ProtoEmitter.High])
)
case "middle" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.Middle),
EmitAllModulesAnnotation(classOf[ProtoEmitter.Middle])
)
case "low" =>
Seq(RunFirrtlTransformAnnotation(new ProtoEmitter.Low), EmitAllModulesAnnotation(classOf[ProtoEmitter.Low]))
case "low-opt" =>
Seq(
RunFirrtlTransformAnnotation(new ProtoEmitter.OptLow),
EmitAllModulesAnnotation(classOf[ProtoEmitter.OptLow])
)
case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)")
},
helpText = "Run the specified module emitter (one protobuf per module)",
shortOption = Some("p"),
helpValueName = Some("<chirrtl|mhigh|high|middle|low|low-opt>")
)
)

Expand Down
59 changes: 59 additions & 0 deletions src/main/scala/firrtl/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,65 @@ object Utils extends LazyLogging {
map.view.map({ case (k, vs) => k -> vs.toList }).toList
}

// For a given module, returns a Seq of all instantiated modules inside of it
private[firrtl] def collectInstantiatedModules(mod: Module, map: Map[String, DefModule]): Seq[DefModule] = {
// Use list instead of set to maintain order
val modules = mutable.ArrayBuffer.empty[DefModule]
def onStmt(stmt: Statement): Unit = stmt match {
case DefInstance(_, _, name, _) => modules += map(name)
case _: WDefInstanceConnector => throwInternalError(s"unrecognized statement: $stmt")
case other => other.foreach(onStmt)
}
onStmt(mod.body)
modules.distinct.toSeq
}

/** Checks if two circuits are equal regardless of their ordering of module definitions */
def orderAgnosticEquality(a: Circuit, b: Circuit): Boolean =
a.copy(modules = a.modules.sortBy(_.name)) == b.copy(modules = b.modules.sortBy(_.name))

/** Combines several separate circuit modules (typically emitted by -e or -p compiler options) into a single circuit */
def combine(circuits: Seq[Circuit]): Circuit = {
def dedup(modules: Seq[DefModule]): Seq[Either[Module, DefModule]] = {
// Left means module with no ExtModules, Right means child modules or lone ExtModules
val module: Option[Module] = {
val found: Seq[Module] = modules.collect { case m: Module => m }
assert(
found.size <= 1,
s"Module definitions should have unique names, found ${found.size} definitions named ${found.head.name}"
)
found.headOption
}
val extModules: Seq[ExtModule] = modules.collect { case e: ExtModule => e }.distinct

// If the module is a lone module (no extmodule references in any other file)
if (extModules.isEmpty && !module.isEmpty)
Seq(Left(module.get))
// If a module has extmodules, but no other file contains the implementation
else if (!extModules.isEmpty && module.isEmpty)
extModules.map(Right(_))
// Otherwise there is a module implementation with extmodule references
else
Seq(Right(module.get))
}

// 1. Combine modules
val grouped: Seq[(String, Seq[DefModule])] = groupByIntoSeq(circuits.flatMap(_.modules))({
case mod: Module => mod.name
case ext: ExtModule => ext.defname
})
val deduped: Iterable[Either[Module, DefModule]] = grouped.flatMap { case (_, insts) => dedup(insts) }

// 2. Determine top
val top = {
val found = deduped.collect { case Left(m) => m }
assert(found.size == 1, s"There should only be 1 top module, got: ${found.map(_.name).mkString(", ")}")
found.head
}
val res = deduped.collect { case Right(m) => m }
ir.Circuit(NoInfo, top +: res.toSeq, top.name)
}

object True {
private val _True = UIntLiteral(1, IntWidth(1))

Expand Down
12 changes: 0 additions & 12 deletions src/main/scala/firrtl/backends/firrtl/FirrtlEmitter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em
val outputSuffix: String = form.outputSuffix

private def emitAllModules(circuit: Circuit): Seq[EmittedFirrtlModule] = {
// For a given module, returns a Seq of all modules instantited inside of it
def collectInstantiatedModules(mod: Module, map: Map[String, DefModule]): Seq[DefModule] = {
// Use list instead of set to maintain order
val modules = mutable.ArrayBuffer.empty[DefModule]
def onStmt(stmt: Statement): Unit = stmt match {
case DefInstance(_, _, name, _) => modules += map(name)
case _: WDefInstanceConnector => throwInternalError(s"unrecognized statement: $stmt")
case other => other.foreach(onStmt)
}
onStmt(mod.body)
modules.distinct.toSeq
}
val modMap = circuit.modules.map(m => m.name -> m).toMap
// Turn each module into it's own circuit with it as the top and all instantied modules as ExtModules
circuit.modules.collect {
Expand Down
38 changes: 33 additions & 5 deletions src/main/scala/firrtl/backends/proto/ProtoBufEmitter.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
// SPDX-License-Identifier: Apache-2.0
package firrtl.backends.proto

import firrtl.{AnnotationSeq, CircuitState, DependencyAPIMigration, Transform}
import firrtl.ir
import firrtl._
import firrtl.ir._
import firrtl.annotations.NoTargetAnnotation
import firrtl.options.CustomFileEmission
import firrtl.options.Viewer.view
import firrtl.proto.ToProto
import firrtl.stage.{FirrtlOptions, Forms}
import firrtl.stage.TransformManager.TransformDependency
import firrtl.traversals.Foreachers._
import java.io.{ByteArrayOutputStream, Writer}
import scala.collection.mutable.ArrayBuffer
import Utils.{collectInstantiatedModules, throwInternalError}

/** This object defines Annotations that are used by Protocol Buffer emission.
*/
Expand Down Expand Up @@ -59,10 +62,35 @@ sealed abstract class ProtoBufEmitter(prereqs: Seq[TransformDependency])
override def optionalPrerequisiteOf = Seq.empty
override def invalidates(a: Transform) = false

override def execute(state: CircuitState) =
state.copy(annotations = state.annotations :+ Annotation.ProtoBufSerialization(state.circuit, Some(outputSuffix)))
private def emitAllModules(circuit: Circuit): Seq[Annotation.ProtoBufSerialization] = {
val modMap = circuit.modules.map(m => m.name -> m).toMap
// Turn each module into it's own circuit with it as the top and all instantied modules as ExtModules
circuit.modules.collect {
case m: Module =>
val instModules = collectInstantiatedModules(m, modMap)
val extModules = instModules.map {
case Module(info, name, ports, _) => ExtModule(info, name, ports, name, Seq.empty)
case ext: ExtModule => ext
}
val newCircuit = Circuit(m.info, extModules :+ m, m.name)
Annotation.ProtoBufSerialization(newCircuit, Some(outputSuffix))
}
}

override def execute(state: CircuitState) = {
val newAnnos = state.annotations.flatMap {
case EmitCircuitAnnotation(a) if this.getClass == a =>
Seq(
Annotation.ProtoBufSerialization(state.circuit, Some(outputSuffix))
)
case EmitAllModulesAnnotation(a) if this.getClass == a =>
emitAllModules(state.circuit)
case _ => Seq()
}
state.copy(annotations = newAnnos ++ state.annotations)
}

override def emit(state: CircuitState, writer: Writer): Unit = {
def emit(state: CircuitState, writer: Writer): Unit = {
val ostream = new java.io.ByteArrayOutputStream
ToProto.writeToStream(ostream, state.circuit)
writer.write(ostream.toString())
Expand Down
25 changes: 25 additions & 0 deletions src/main/scala/firrtl/proto/FromProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import collection.JavaConverters._
import FirrtlProtos._
import com.google.protobuf.CodedInputStream
import Firrtl.Statement.{Formal, ReadUnderWrite}
import firrtl.ir.DefModule
import Utils.combine
import java.io.FileNotFoundException
import firrtl.options.OptionsException
import java.nio.file.NotDirectoryException

object FromProto {

Expand All @@ -33,6 +38,26 @@ object FromProto {
proto.FromProto.convert(pb)
}

/** Deserialize all the ProtoBuf representations of [[ir.Circuit]] in @dir
*
* @param dir directory containing ProtoBuf representation(s)
* @return Deserialized FIRRTL Circuit
* @throws java.io.FileNotFoundException if dir does not exist
* @throws java.nio.file.NotDirectoryException if dir exists but is not a directory
*/
def fromDirectory(dir: String): ir.Circuit = {
val d = new File(dir)
if (!d.exists) {
throw new FileNotFoundException(s"Specified directory '$d' does not exist!")
}
if (!d.isDirectory) {
throw new NotDirectoryException(s"'$d' is not a directory!")
}

val fileList = d.listFiles.filter(_.isFile).toList
combine(fileList.map(f => fromInputStream(new FileInputStream(f))))
}

// Convert from ProtoBuf message repeated Statements to FIRRRTL Block
private def compressStmts(stmts: scala.collection.Seq[ir.Statement]): ir.Statement = stmts match {
case scala.collection.Seq() => ir.EmptyStmt
Expand Down
41 changes: 39 additions & 2 deletions src/main/scala/firrtl/stage/FirrtlAnnotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import firrtl._
import firrtl.ir.Circuit
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{Dependency, HasShellOptions, OptionsException, ShellOption, Unserializable}
import java.io.FileNotFoundException
import java.nio.file.NoSuchFileException
import java.io.{File, FileNotFoundException}
import java.nio.file.{NoSuchFileException, NotDirectoryException}

import firrtl.stage.TransformManager.TransformDependency

Expand Down Expand Up @@ -64,6 +64,43 @@ object FirrtlFileAnnotation extends HasShellOptions {

}

/** Read a directory of ProtoBufs
* - set with `-I/--input-directory`
*
* TODO: Does not currently support FIRRTL files.
* @param dir input directory name
*/
case class FirrtlDirectoryAnnotation(dir: String) extends NoTargetAnnotation with CircuitOption {

def toCircuit(info: Parser.InfoMode): FirrtlCircuitAnnotation = {
val circuit =
try {
proto.FromProto.fromDirectory(dir)
} catch {
case a @ (_: FileNotFoundException | _: NoSuchFileException) =>
throw new OptionsException(s"Directory '$dir' not found! (Did you misspell it?)", a)
case _: NotDirectoryException =>
throw new OptionsException(s"Directory '$dir' is not a directory")
}
FirrtlCircuitAnnotation(circuit)
}

}

object FirrtlDirectoryAnnotation extends HasShellOptions {

val options = Seq(
new ShellOption[String](
longOption = "input-directory",
toAnnotationSeq = a => Seq(FirrtlDirectoryAnnotation(a)),
helpText = "A directory of FIRRTL files",
shortOption = Some("I"),
helpValueName = Some("<directory>")
)
)

}

/** An explicit output file the emitter will write to
* - set with `-o/--output-file`
* @param file output filename
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/firrtl/stage/FirrtlCli.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ trait FirrtlCli { this: Shell =>
parser.note("FIRRTL Compiler Options")
Seq(
FirrtlFileAnnotation,
FirrtlDirectoryAnnotation,
OutputFileAnnotation,
InfoModeAnnotation,
FirrtlSourceAnnotation,
Expand Down
37 changes: 21 additions & 16 deletions src/main/scala/firrtl/stage/phases/Checks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,39 @@ class Checks extends Phase {
* @throws firrtl.options.OptionsException if any checks fail
*/
def transform(annos: AnnotationSeq): AnnotationSeq = {
val inF, inS, eam, ec, outF, emitter, im, inC = collection.mutable.ListBuffer[Annotation]()
val inF, inS, inD, eam, ec, outF, emitter, im, inC = collection.mutable.ListBuffer[Annotation]()
annos.foreach(_ match {
case a: FirrtlFileAnnotation => a +=: inF
case a: FirrtlSourceAnnotation => a +=: inS
case a: EmitAllModulesAnnotation => a +=: eam
case a: EmitCircuitAnnotation => a +=: ec
case a: OutputFileAnnotation => a +=: outF
case a: InfoModeAnnotation => a +=: im
case a: FirrtlCircuitAnnotation => a +=: inC
case a: FirrtlFileAnnotation => a +=: inF
case a: FirrtlSourceAnnotation => a +=: inS
case a: FirrtlDirectoryAnnotation => a +=: inD
case a: EmitAllModulesAnnotation => a +=: eam
case a: EmitCircuitAnnotation => a +=: ec
case a: OutputFileAnnotation => a +=: outF
case a: InfoModeAnnotation => a +=: im
case a: FirrtlCircuitAnnotation => a +=: inC
case a @ RunFirrtlTransformAnnotation(_: firrtl.Emitter) => a +=: emitter
case _ =>
})

/* At this point, only a FIRRTL Circuit should exist */
if (inF.isEmpty && inS.isEmpty && inC.isEmpty) {
throw new OptionsException(s"""|Unable to determine FIRRTL source to read. None of the following were found:
| - an input file: -i, --input-file, FirrtlFileAnnotation
| - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation
| - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin)
if (inF.isEmpty && inS.isEmpty && inD.isEmpty && inC.isEmpty) {
throw new OptionsException(
s"""|Unable to determine FIRRTL source to read. None of the following were found:
| - an input file: -i, --input-file, FirrtlFileAnnotation
| - an input dir: -I, --input-directory, FirrtlDirectoryAnnotation
| - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation
| - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin
)
}

/* Only one FIRRTL input can exist */
if (inF.size + inS.size + inC.size > 1) {
throw new OptionsException(
s"""|Multiply defined input FIRRTL sources. More than one of the following was found:
| - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation
| - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation
| - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin
| - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation
| - an input dir (${inD.size} times): -I, --input-directory, FirrtlDirectoryAnnotation
| - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation
| - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin
)
}

Expand Down
Loading

0 comments on commit de96afe

Please sign in to comment.