Skip to content

Commit

Permalink
Merge pull request #440 from scalacenter/independent-scripted-parallel
Browse files Browse the repository at this point in the history
Improve and make scripted parallel
  • Loading branch information
jvican authored Nov 16, 2017
2 parents f3d2b73 + 2c9dbb7 commit 67cb2eb
Show file tree
Hide file tree
Showing 5 changed files with 377 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package sbt.internal.inc

import org.scalatest.exceptions.TestFailedException
import sbt.internal.inc
import sbt.internal.scripted._
import sbt.internal.inc.BatchScriptRunner.States

/** Defines an alternative script runner that allows batch execution. */
private[sbt] class BatchScriptRunner extends ScriptRunner {

/** Defines a method to run batched execution.
*
* @param statements The list of handlers and statements.
* @param states The states of the runner. In case it's empty, inherited apply is called.
*/
def apply(statements: List[(StatementHandler, Statement)], states: States): Unit = {
if (states.isEmpty) super.apply(statements)
else statements.foreach(st => processStatement(st._1, st._2, states))
}

def initStates(states: States, handlers: Seq[StatementHandler]): Unit =
handlers.foreach(handler => states(handler) = handler.initialState)

def cleanUpHandlers(handlers: Seq[StatementHandler], states: States): Unit = {
for (handler <- handlers; state <- states.get(handler)) {
try handler.finish(state.asInstanceOf[handler.State])
catch { case _: Exception => () }
}
}

import BatchScriptRunner.PreciseScriptedError
def processStatement(handler: StatementHandler, statement: Statement, states: States): Unit = {
val state = states(handler).asInstanceOf[handler.State]
val nextState =
try { Right(handler(statement.command, statement.arguments, state)) } catch {
case e: Exception => Left(e)
}
nextState match {
case Left(err) =>
if (statement.successExpected) {
err match {
case t: TestFailed =>
val errorMessage = s"${t.getMessage} produced by"
throw new PreciseScriptedError(statement, errorMessage, null)
case _ => throw new PreciseScriptedError(statement, "Command failed", err)
}
} else ()
case Right(s) =>
if (statement.successExpected) states(handler) = s
else throw new PreciseScriptedError(statement, "Expecting error at", null)
}
}
}

private[sbt] object BatchScriptRunner {
import scala.collection.mutable
type States = mutable.HashMap[StatementHandler, Any]

// Should be used instead of sbt.internal.scripted.TestException that doesn't show failed command
final class PreciseScriptedError(st: Statement, msg: String, e: Throwable)
extends RuntimeException(s"$msg: '${st.command} ${st.arguments.mkString(" ")}'", e) {
override def fillInStackTrace = e
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,18 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
type IncCommand = (ProjectStructure, List[String], IncInstance) => Unit

val compiler = new IncrementalCompilerImpl
def initialState: Option[IncInstance] = None
def finish(state: Option[IncInstance]): Unit = ()

def initialState: Option[IncInstance] = {
initBuildStructure()
None
}

def finish(state: Option[IncInstance]): Unit = {
// Required so that next projects re-read the project structure
buildStructure.clear()
()
}

val buildStructure: mutable.Map[String, ProjectStructure] = mutable.Map.empty
def initBuildStructure(): Unit = {
val build = initBuild
Expand All @@ -70,8 +80,6 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
}
}

initBuildStructure()

private final val RootIdentifier = "root"
def initBuild: Build = {
if ((directory / "build.json").exists) {
Expand Down Expand Up @@ -108,14 +116,14 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
private final val noLogger = Logger.Null
private[this] def onNewIncInstance(p: ProjectStructure): IncInstance = {
val scalaVersion = p.scalaVersion
val (compilerBridge, si) = IncHandler.scriptedCompilerCache.get(scalaVersion) match {
val (compilerBridge, si) = IncHandler.getCompilerCacheFor(scalaVersion) match {
case Some(alreadyInstantiated) =>
alreadyInstantiated
case None =>
val compilerBridge = getCompilerBridge(cacheDir, noLogger, scalaVersion)
val si = scalaInstance(scalaVersion, cacheDir, noLogger)
val toCache = (compilerBridge, si)
IncHandler.scriptedCompilerCache.put(scalaVersion, toCache)
IncHandler.putCompilerCache(scalaVersion, toCache)
toCache
}
val analyzingCompiler = scalaCompiler(si, compilerBridge)
Expand Down Expand Up @@ -505,7 +513,12 @@ case class ProjectStructure(

object IncHandler {
type Cached = (File, xsbti.compile.ScalaInstance)
private[internal] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
private[this] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
def getCompilerCacheFor(scalaVersion: String): Option[Cached] =
synchronized(scriptedCompilerCache.get(scalaVersion))
def putCompilerCache(scalaVersion: String, cached: Cached): Option[Cached] =
synchronized(scriptedCompilerCache.put(scalaVersion, cached))

private[internal] final val classLoaderCache = Some(
new ClassLoaderCache(new URLClassLoader(Array())))
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,80 @@ package sbt.internal.inc

import java.io.File

import sbt.internal.scripted.ScriptedRunnerImpl
import sbt.internal.scripted.{ HandlersProvider, ListTests, ScriptedTest }
import sbt.io.IO
import sbt.util.Logger

import scala.collection.parallel.ParSeq

class IncScriptedRunner {
def run(resourceBaseDirectory: File, bufferLog: Boolean, tests: Array[String]): Unit = {
IO.withTemporaryDirectory { tempDir =>
// Create a global temporary directory to store the bridge et al
val handlers = new IncScriptedHandlers(tempDir)
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers);
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers, 4)
}
}
}

object ScriptedRunnerImpl {
type TestRunner = () => Seq[Option[String]]

def run(
resourceBaseDirectory: File,
bufferLog: Boolean,
tests: Array[String],
handlersProvider: HandlersProvider,
instances: Int
): Unit = {
val globalLogger = newLogger
val logsDir = newScriptedLogsDir
val runner = new ScriptedTests(resourceBaseDirectory, bufferLog, handlersProvider, logsDir)
val scriptedTests = get(tests, resourceBaseDirectory, globalLogger)
val scriptedRunners = runner.batchScriptedRunner(scriptedTests, instances)
val parallelRunners = scriptedRunners.toParArray
// Using this deprecated value for 2.11 support
val pool = new scala.concurrent.forkjoin.ForkJoinPool(instances)
parallelRunners.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(pool)
runAllInParallel(parallelRunners)
globalLogger.info(s"Log files can be found at ${logsDir.getAbsolutePath}")
}

private val nl = IO.Newline
private val nlt = nl + "\t"
class ScriptedFailure(tests: Seq[String]) extends RuntimeException(tests.mkString(nlt, nlt, nl)) {
// We are not interested in the stack trace here, only the failing tests
override def fillInStackTrace = this
}

private def reportErrors(errors: Seq[String]): Unit =
if (errors.nonEmpty) throw new ScriptedFailure(errors) else ()

def runAllInParallel(tests: ParSeq[TestRunner]): Unit = {
reportErrors(tests.flatMap(test => test.apply().flatten.toSeq).toList)
}

def get(tests: Seq[String], baseDirectory: File, log: Logger): Seq[ScriptedTest] =
if (tests.isEmpty) listTests(baseDirectory, log) else parseTests(tests)

def listTests(baseDirectory: File, log: Logger): Seq[ScriptedTest] =
(new ListTests(baseDirectory, _ => true, log)).listTests

def parseTests(in: Seq[String]): Seq[ScriptedTest] = for (testString <- in) yield {
testString.split("/").map(_.trim) match {
case Array(group, name) => ScriptedTest(group, name)
case elems =>
sys.error(s"Expected two arguments 'group/name', obtained ${elems.mkString("/")}")
}
}

private[sbt] def newLogger: Logger = sbt.internal.util.ConsoleLogger()

private[this] val random = new java.util.Random()
private[sbt] def newScriptedLogsDir: File = {
val randomName = "scripted-logs-" + java.lang.Integer.toHexString(random.nextInt)
val logsDir = new File(IO.temporaryDirectory, randomName)
IO.createDirectory(logsDir)
logsDir
}
}
Loading

0 comments on commit 67cb2eb

Please sign in to comment.