Skip to content

Commit

Permalink
Merge pull request #25 from vaslabs/feature/secureJoin
Browse files Browse the repository at this point in the history
Feature/secure join
  • Loading branch information
vaslabs authored Nov 7, 2020
2 parents cd32229 + cdc6460 commit a63e64a
Show file tree
Hide file tree
Showing 15 changed files with 109 additions and 68 deletions.
21 changes: 0 additions & 21 deletions endpoints/src/main/scala/cardgame/endpoints/Actions.scala

This file was deleted.

19 changes: 4 additions & 15 deletions endpoints/src/main/scala/cardgame/endpoints/JoiningGame.scala
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
package cardgame.endpoints

import java.util.Base64

import cardgame.model.{ClockedResponse, GameId, PlayerId}
import cardgame.json.circe._
import cardgame.model.{ClockedAction, ClockedResponse, GameId}
import sttp.model.StatusCode
import sttp.tapir._
import cardgame.json.circe._
import sttp.tapir.json.circe._
import cardgame.endpoints.codecs.rsa
object JoiningGame {

import codecs.ids._
import cardgame.endpoints.codecs.ids._
import schema.vector_clock._

val joinPlayer =
endpoint.in(
"game" / path[GameId] / "join"
).post
.in(query[PlayerId]("username"))
.in(
auth.bearer
.map(
rsa.fromString
)(rsaKey =>
Base64.getEncoder.encodeToString(rsaKey.getEncoded)
)
)
.in(jsonBody[ClockedAction])
.out(jsonBody[ClockedResponse])
.errorOut(statusCode(StatusCode.NotFound))
}
4 changes: 2 additions & 2 deletions endpoints/src/main/scala/cardgame/endpoints/admin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ object admin {

val createGame =
endpoint.in("games-admin").post
.in(auth.bearer)
.in(auth.bearer[String])
.out(stringBody("utf-8").map(GameId compose UUID.fromString)(_.value.toString))
.errorOut(statusCode(StatusCode.Forbidden))

val startGame =
endpoint.in("games-admin")
.patch
.in(auth.bearer)
.in(auth.bearer[String])
.in(query[GameId]("game"))
.in(query[DeckId]("deck"))
.in(query[String]("server"))
Expand Down
11 changes: 5 additions & 6 deletions endpoints/src/main/scala/cardgame/endpoints/codecs/ids.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
package cardgame.endpoints.codecs

import cardgame.model.{DeckId, GameId, PlayerId}
import sttp.tapir.Codec
import sttp.tapir.{Codec, CodecFormat}

object ids {
implicit val gameIdCodec = Codec.uuidPlainCodec
implicit val gameIdCodec: Codec[String, GameId, CodecFormat.TextPlain] = Codec.uuid
.map(GameId)(_.value)

implicit val deckIdCodec = Codec.uuidPlainCodec.map(DeckId)(_.value)
implicit val deckIdCodec: Codec[String, DeckId, CodecFormat.TextPlain] = Codec.uuid.map(DeckId)(_.value)

implicit val playerId = Codec.stringPlainCodecUtf8.map(
implicit val playerId: Codec[String, PlayerId, CodecFormat.TextPlain] = Codec.string.map(
PlayerId
)(_.value)
}

}
29 changes: 28 additions & 1 deletion endpoints/src/main/scala/cardgame/endpoints/codecs/rsa.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package cardgame.endpoints.codecs

import java.io.StringReader
import java.io.{StringReader, StringWriter}
import java.nio.charset.StandardCharsets
import java.security.interfaces.RSAPublicKey
import java.security.spec.X509EncodedKeySpec
import java.util.Base64

import org.bouncycastle.util.io.pem.{PemObject, PemWriter}


object rsa {
def fromString(base64: String): RSAPublicKey = {
import java.security.KeyFactory
Expand All @@ -25,4 +28,28 @@ object rsa {
val pubKeySpec = new X509EncodedKeySpec(keyBytes)
factory.generatePublic(pubKeySpec).asInstanceOf[RSAPublicKey]
}


def show(publicKey: RSAPublicKey): String = {

val pubKeySpec = new X509EncodedKeySpec(publicKey.getEncoded)

val pemObject = new PemObject("PUBLIC KEY", pubKeySpec.getEncoded)
val stringWriter = new StringWriter()
val pemWriter = new PemWriter(stringWriter)
pemWriter.writeObject(pemObject)
pemWriter.flush()
val bareKey = stringWriter.getBuffer.toString
.replaceAll("-----BEGIN PUBLIC KEY-----", "")
.replaceAll("-----END PUBLIC KEY-----", "")
.replaceAll("\n", "")
val content = {
Base64.getEncoder.encodeToString(
bareKey.getBytes(StandardCharsets.UTF_8)
)
}
stringWriter.close()
content

}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package cardgame.endpoints.schema

import cardgame.endpoints.schema
import cardgame.model.{ClockedAction, ClockedResponse, Event, PlayingGameAction}
import cardgame.model.{Action, ClockedAction, ClockedResponse, Event, PlayerId}
import sttp.tapir.SchemaType.SObjectInfo
import sttp.tapir.{Schema, SchemaType, Validator}

object vector_clock {
import schema.java_types._

implicit val actionSchema: Schema[PlayingGameAction] = Schema.derivedSchema[PlayingGameAction]
implicit val actionSchema: Schema[Action] = Schema.derivedSchema[Action]
implicit val eventSchema: Schema[Event] = Schema.derivedSchema[Event]
implicit val mapSchema: Schema[Map[String, Long]] = Schema.schemaForMap

Expand All @@ -22,4 +22,7 @@ object vector_clock {

implicit val clockedResponseValidation: Validator[ClockedResponse] = Validator.pass

implicit lazy val clockActionValidator: Validator[ClockedAction] = Validator.pass

implicit lazy val playerIdSchema: Schema[PlayerId] = Schema(SchemaType.SString)
}
11 changes: 7 additions & 4 deletions endpoints/src/main/scala/cardgame/json/circe.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ package cardgame.json

import java.net.URI
import java.security.interfaces.RSAPublicKey
import java.util.Base64

import cardgame.endpoints.codecs.rsa
import cardgame.model.{Action, CardId, ClockedAction, ClockedResponse, DeckId, Event, Game, HiddenCard, PlayerId}
import io.circe.{Codec, Decoder, Encoder, Json, KeyDecoder, KeyEncoder}
import io.circe.generic.auto._
import io.circe.generic.semiauto._
import io.circe.syntax._
import cats.implicits._
import sun.security.rsa.RSAPublicKeyImpl

import scala.util.Try
object circe {
Expand All @@ -34,12 +33,12 @@ object circe {

implicit val rsaPublicKeyEncoder: Encoder[RSAPublicKey] = Encoder.encodeString.contramap {
rsaPublicKey =>
Base64.getEncoder.encodeToString(rsaPublicKey.getEncoded)
rsa.show(rsaPublicKey)
}

implicit val rsaPublicKeyDecoder: Decoder[RSAPublicKey] = Decoder.decodeString.emapTry {
value =>
Try(RSAPublicKeyImpl.newKey(Base64.getDecoder.decode(value)))
Try(rsa.fromString(value))
}

implicit val playerIdKeyEncoder: KeyEncoder[PlayerId] = KeyEncoder.encodeKeyString.contramap(_.value)
Expand Down Expand Up @@ -90,5 +89,9 @@ object circe {
).mapN(
(clock, serverClock, signature, action) => ClockedAction(action, clock, serverClock, signature)
)
}.handleErrorWith {
df =>
println(df)
Decoder.failed(df)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package cardgame.endpoints.codecs

import cardgame.json.circe._
import cardgame.model.{Authorise, ClockedAction, PlayerId}
import org.scalatest.flatspec.AnyFlatSpec
import io.circe.parser._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers
class ActionJsonSpec extends AnyFlatSpec with Matchers {

Expand Down
5 changes: 4 additions & 1 deletion engine/src/main/scala/cardgame/engine/GameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ object GameOps {
): (Game, Event) = {
gameAction match {
case jg: JoinGame =>
join(jg.player)
if (isIdempotent(jg.player.id)(oldClock, newClock))
join(jg.player)
else
game -> InvalidAction(jg.playerId)
case authorisationTicket: Authorise =>
authorise(authorisationTicket)
case EndGame =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ object GameProcessor {
ctx.system.eventStream ! EventStream.Publish(UserResponse(ClockedResponse(event, remoteClockCopy, updateLocalClock)))
behavior(gameAffected, randomizer, updateLocalClock, remoteClockCopy)(validSignature)

case c: Command if c.action.action.isInstanceOf[JoiningGameAction] || validSignature(game, c.action) =>
case c: Command if validSignature(game, c.action) =>

val remoteClock = RemoteClock.of(c.action.vectorClock)
val newRemoteClock = remoteClockCopy |+| remoteClock
Expand All @@ -46,6 +46,13 @@ object GameProcessor {
}

behavior(gameAffected, randomizer, updateLocalClock, newRemoteClock)(validSignature)
case c: Command =>
c match {
case rc: ReplyCommand =>
rc.replyTo ! Left(())
case _ =>
}
Behaviors.same
case Get(playerId, replyTo) =>
replyTo ! Right(personalise(playerId, game))
Behaviors.same
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GameProcessorAuthenticationSpec extends AnyWordSpec with Matchers with Bef
userProbe.expectMessageType[Right[Unit, ClockedResponse]].value.event mustBe InvalidAction(joiningPlayer)
}

"a player claiming the same public key is accepted" in {
"a replay of player successful join is rejected" in {
gameProcessor ! GameProcessor.RunCommand(
userProbe.ref,
ClockedAction(
Expand All @@ -72,6 +72,20 @@ class GameProcessorAuthenticationSpec extends AnyWordSpec with Matchers with Bef
)
)

userProbe.expectMessageType[Right[Unit, ClockedResponse]].value.event mustBe InvalidAction(joiningPlayer)
}

"a player can rejoin with an updated logical timestamp" in {
gameProcessor ! GameProcessor.RunCommand(
userProbe.ref,
ClockedAction(
JoinGame(JoiningPlayer(joiningPlayer, keyPair1.getPublic.asInstanceOf[RSAPublicKey])),
Map("player1" -> 2),
0L,
"",
)
)

userProbe.expectMessageType[Right[Unit, ClockedResponse]].value.event mustBe PlayerJoined(joiningPlayer)
}

Expand Down
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object Dependencies {
val core = "3.1.1"
}
object tapir {
val core = "0.12.23"
val core = "0.16.16"
}

object circe {
Expand Down
12 changes: 6 additions & 6 deletions service/src/main/scala/cardgame/events/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import akka.actor.typed.ActorRef
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.Source
import akka.stream.typed.scaladsl.ActorSource
import cardgame.model.{ClockedAction, ClockedResponse, Game, GameCompleted, JoiningGameAction, PlayerId, PlayingGameAction, StartedGame, StartingGame}
import cardgame.model._
import cardgame.processor.ActiveGames
import org.bouncycastle.util.encoders.Hex

Expand Down Expand Up @@ -41,8 +41,8 @@ package object events {
}
testSignature.getOrElse(false)
}
import io.circe.syntax._
import cardgame.json.circe._
import io.circe.syntax._

def validateSignature(game: Game, action: ClockedAction): Boolean = {
val plainText = {
Expand All @@ -51,10 +51,10 @@ package object events {
val verify = (game, action.action) match {
case (sg: StartedGame, a: PlayingGameAction) =>
sg.players.find(_.id == a.player).exists(p => verifySignature(plainText, action.signature, p.publicKey))
case (sg: StartingGame, a: JoiningGameAction) =>
sg.playersJoined.find(_.id == a.playerId).exists(
p => verifySignature(plainText, action.signature, p.publicKey)
)
case (StartingGame(players), Authorise(playerId)) =>
players.find(_.id == playerId).exists(p => verifySignature(plainText, action.signature, p.publicKey))
case (StartingGame(_), JoinGame(JoiningPlayer(_ , publicKey))) =>
verifySignature(plainText, action.signature, publicKey)
case _ =>
true
}
Expand Down
11 changes: 7 additions & 4 deletions service/src/main/scala/cardgame/routes/routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ import akka.http.scaladsl.server.Directives._
import akka.stream.scaladsl.{Flow, Sink}
import akka.stream.typed.scaladsl.ActorSink
import akka.util.Timeout
import cardgame.endpoints._
import cardgame.endpoints.{JoiningGame, _}
import cardgame.events._
import cardgame.json.circe._
import cardgame.model.{ClockedAction, GameId, PlayerId, RemoteClock}
import cardgame.model._
import cardgame.processor.ActiveGames
import cardgame.processor.ActiveGames.api._
import cardgame.processor.ActiveGames.{DoGameAction, Ignore}
Expand All @@ -21,14 +21,14 @@ import io.circe.parser._
import io.circe.syntax._
import sttp.tapir.server.akkahttp._

import scala.concurrent.Future
import scala.concurrent.duration._

class Routes(activeGames: ActorRef[ActiveGames.Protocol])(implicit scheduler: Scheduler) {

implicit val timeout = Timeout(5 seconds)
private val PlayerIdMatcher = RemainingPath.map(_.toString).map(PlayerId)


val websocketRoute = get {
path("live" / "actions" / PathGameId / PlayerIdMatcher) {
(gameId, playerId) =>
Expand All @@ -37,7 +37,10 @@ class Routes(activeGames: ActorRef[ActiveGames.Protocol])(implicit scheduler: Sc
}

val startingGame = JoiningGame.joinPlayer.toRoute {
case (gameId, playerId, publicKey) => activeGames.joinGame(gameId, playerId, "", RemoteClock.zero, publicKey)
case (gameId, ClockedAction(JoinGame(player), vectorClock, _, signature)) =>
activeGames.joinGame(gameId, player.id, signature, RemoteClock.of(vectorClock), player.publicKey)
case _ =>
Future.successful(Left(()))
} ~ View.gameStatus.toRoute {
case (gameId, playerId) => activeGames.getGame(gameId, playerId)
} ~ get {
Expand Down
Loading

0 comments on commit a63e64a

Please sign in to comment.