Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve and make scripted parallel #440

Merged
merged 2 commits into from
Nov 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package sbt.internal.inc

import org.scalatest.exceptions.TestFailedException
import sbt.internal.inc
import sbt.internal.scripted._
import sbt.internal.inc.BatchScriptRunner.States

/** Defines an alternative script runner that allows batch execution. */
private[sbt] class BatchScriptRunner extends ScriptRunner {

/** Defines a method to run batched execution.
*
* @param statements The list of handlers and statements.
* @param states The states of the runner. In case it's empty, inherited apply is called.
*/
def apply(statements: List[(StatementHandler, Statement)], states: States): Unit = {
if (states.isEmpty) super.apply(statements)
else statements.foreach(st => processStatement(st._1, st._2, states))
}

def initStates(states: States, handlers: Seq[StatementHandler]): Unit =
handlers.foreach(handler => states(handler) = handler.initialState)

def cleanUpHandlers(handlers: Seq[StatementHandler], states: States): Unit = {
for (handler <- handlers; state <- states.get(handler)) {
try handler.finish(state.asInstanceOf[handler.State])
catch { case _: Exception => () }
}
}

import BatchScriptRunner.PreciseScriptedError
def processStatement(handler: StatementHandler, statement: Statement, states: States): Unit = {
val state = states(handler).asInstanceOf[handler.State]
val nextState =
try { Right(handler(statement.command, statement.arguments, state)) } catch {
case e: Exception => Left(e)
}
nextState match {
case Left(err) =>
if (statement.successExpected) {
err match {
case t: TestFailed =>
val errorMessage = s"${t.getMessage} produced by"
throw new PreciseScriptedError(statement, errorMessage, null)
case _ => throw new PreciseScriptedError(statement, "Command failed", err)
}
} else ()
case Right(s) =>
if (statement.successExpected) states(handler) = s
else throw new PreciseScriptedError(statement, "Expecting error at", null)
}
}
}

private[sbt] object BatchScriptRunner {
import scala.collection.mutable
type States = mutable.HashMap[StatementHandler, Any]

// Should be used instead of sbt.internal.scripted.TestException that doesn't show failed command
final class PreciseScriptedError(st: Statement, msg: String, e: Throwable)
extends RuntimeException(s"$msg: '${st.command} ${st.arguments.mkString(" ")}'", e) {
override def fillInStackTrace = e
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,18 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
type IncCommand = (ProjectStructure, List[String], IncInstance) => Unit

val compiler = new IncrementalCompilerImpl
def initialState: Option[IncInstance] = None
def finish(state: Option[IncInstance]): Unit = ()

def initialState: Option[IncInstance] = {
initBuildStructure()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is required so that batch execution reusing the same IncHandler works.

None
}

def finish(state: Option[IncInstance]): Unit = {
// Required so that next projects re-read the project structure
buildStructure.clear()
()
}

val buildStructure: mutable.Map[String, ProjectStructure] = mutable.Map.empty
def initBuildStructure(): Unit = {
val build = initBuild
Expand All @@ -70,8 +80,6 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
}
}

initBuildStructure()

private final val RootIdentifier = "root"
def initBuild: Build = {
if ((directory / "build.json").exists) {
Expand Down Expand Up @@ -108,14 +116,14 @@ final class IncHandler(directory: File, cacheDir: File, scriptedLog: ManagedLogg
private final val noLogger = Logger.Null
private[this] def onNewIncInstance(p: ProjectStructure): IncInstance = {
val scalaVersion = p.scalaVersion
val (compilerBridge, si) = IncHandler.scriptedCompilerCache.get(scalaVersion) match {
val (compilerBridge, si) = IncHandler.getCompilerCacheFor(scalaVersion) match {
case Some(alreadyInstantiated) =>
alreadyInstantiated
case None =>
val compilerBridge = getCompilerBridge(cacheDir, noLogger, scalaVersion)
val si = scalaInstance(scalaVersion, cacheDir, noLogger)
val toCache = (compilerBridge, si)
IncHandler.scriptedCompilerCache.put(scalaVersion, toCache)
IncHandler.putCompilerCache(scalaVersion, toCache)
toCache
}
val analyzingCompiler = scalaCompiler(si, compilerBridge)
Expand Down Expand Up @@ -505,7 +513,12 @@ case class ProjectStructure(

object IncHandler {
type Cached = (File, xsbti.compile.ScalaInstance)
private[internal] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
private[this] final val scriptedCompilerCache = new mutable.WeakHashMap[String, Cached]()
def getCompilerCacheFor(scalaVersion: String): Option[Cached] =
synchronized(scriptedCompilerCache.get(scalaVersion))
def putCompilerCache(scalaVersion: String, cached: Cached): Option[Cached] =
synchronized(scriptedCompilerCache.put(scalaVersion, cached))

private[internal] final val classLoaderCache = Some(
new ClassLoaderCache(new URLClassLoader(Array())))
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,80 @@ package sbt.internal.inc

import java.io.File

import sbt.internal.scripted.ScriptedRunnerImpl
import sbt.internal.scripted.{ HandlersProvider, ListTests, ScriptedTest }
import sbt.io.IO
import sbt.util.Logger

import scala.collection.parallel.ParSeq

class IncScriptedRunner {
def run(resourceBaseDirectory: File, bufferLog: Boolean, tests: Array[String]): Unit = {
IO.withTemporaryDirectory { tempDir =>
// Create a global temporary directory to store the bridge et al
val handlers = new IncScriptedHandlers(tempDir)
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers);
ScriptedRunnerImpl.run(resourceBaseDirectory, bufferLog, tests, handlers, 4)
}
}
}

object ScriptedRunnerImpl {
type TestRunner = () => Seq[Option[String]]

def run(
resourceBaseDirectory: File,
bufferLog: Boolean,
tests: Array[String],
handlersProvider: HandlersProvider,
instances: Int
): Unit = {
val globalLogger = newLogger
val logsDir = newScriptedLogsDir
val runner = new ScriptedTests(resourceBaseDirectory, bufferLog, handlersProvider, logsDir)
val scriptedTests = get(tests, resourceBaseDirectory, globalLogger)
val scriptedRunners = runner.batchScriptedRunner(scriptedTests, instances)
val parallelRunners = scriptedRunners.toParArray
// Using this deprecated value for 2.11 support
val pool = new scala.concurrent.forkjoin.ForkJoinPool(instances)
parallelRunners.tasksupport = new scala.collection.parallel.ForkJoinTaskSupport(pool)
runAllInParallel(parallelRunners)
globalLogger.info(s"Log files can be found at ${logsDir.getAbsolutePath}")
}

private val nl = IO.Newline
private val nlt = nl + "\t"
class ScriptedFailure(tests: Seq[String]) extends RuntimeException(tests.mkString(nlt, nlt, nl)) {
// We are not interested in the stack trace here, only the failing tests
override def fillInStackTrace = this
}

private def reportErrors(errors: Seq[String]): Unit =
if (errors.nonEmpty) throw new ScriptedFailure(errors) else ()

def runAllInParallel(tests: ParSeq[TestRunner]): Unit = {
reportErrors(tests.flatMap(test => test.apply().flatten.toSeq).toList)
}

def get(tests: Seq[String], baseDirectory: File, log: Logger): Seq[ScriptedTest] =
if (tests.isEmpty) listTests(baseDirectory, log) else parseTests(tests)

def listTests(baseDirectory: File, log: Logger): Seq[ScriptedTest] =
(new ListTests(baseDirectory, _ => true, log)).listTests

def parseTests(in: Seq[String]): Seq[ScriptedTest] = for (testString <- in) yield {
testString.split("/").map(_.trim) match {
case Array(group, name) => ScriptedTest(group, name)
case elems =>
sys.error(s"Expected two arguments 'group/name', obtained ${elems.mkString("/")}")
}
}

private[sbt] def newLogger: Logger = sbt.internal.util.ConsoleLogger()

private[this] val random = new java.util.Random()
private[sbt] def newScriptedLogsDir: File = {
val randomName = "scripted-logs-" + java.lang.Integer.toHexString(random.nextInt)
val logsDir = new File(IO.temporaryDirectory, randomName)
IO.createDirectory(logsDir)
logsDir
}
}
Loading