Skip to content

Commit

Permalink
Merge pull request #3938 from aws-amplify/3937
Browse files Browse the repository at this point in the history
* fix(auth): fix credential decoding

* update comments

* update comment

* add unit tests

* test fix

* failed build
  • Loading branch information
harsh62 authored Dec 11, 2024
2 parents 06207f4 + fb0eb55 commit 833c7bb
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ extension AWSCognitoAuthCredentialStore: AmplifyAuthCredentialStoreBehavior {
func retrieveCredential() throws -> AmplifyCredentials {
let authCredentialStoreKey = generateSessionKey(for: authConfiguration)
let authCredentialData = try keychain._getData(authCredentialStoreKey)
let awsCredential: AmplifyCredentials = try decode(data: authCredentialData)
return awsCredential
let amplifyCredential: AmplifyCredentials = try decode(data: authCredentialData)
return amplifyCredential
}

func deleteCredential() throws {
Expand Down Expand Up @@ -191,15 +191,15 @@ private extension AWSCognitoAuthCredentialStore {
do {
return try JSONEncoder().encode(object)
} catch {
throw KeychainStoreError.codingError("Error occurred while encoding AWSCredentials", error)
throw KeychainStoreError.codingError("Error occurred while encoding credentials", error)
}
}

func decode<T: Decodable>(data: Data) throws -> T {
do {
return try JSONDecoder().decode(T.self, from: data)
} catch {
throw KeychainStoreError.codingError("Error occurred while decoding AWSCredentials", error)
throw KeychainStoreError.codingError("Error occurred while decoding credentials", error)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ public enum AuthFlowType {

internal init?(rawValue: String) {
switch rawValue {
case "CUSTOM_AUTH":
case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP":
self = .customWithSRP
case "CUSTOM_AUTH_WITHOUT_SRP":
self = .customWithoutSRP
case "USER_SRP_AUTH":
self = .userSRP
case "USER_PASSWORD_AUTH":
Expand All @@ -51,8 +53,10 @@ public enum AuthFlowType {

var rawValue: String {
switch self {
case .custom, .customWithSRP, .customWithoutSRP:
return "CUSTOM_AUTH"
case .custom, .customWithSRP:
return "CUSTOM_AUTH_WITH_SRP"
case .customWithoutSRP:
return "CUSTOM_AUTH_WITHOUT_SRP"
case .userSRP:
return "USER_SRP_AUTH"
case .userPassword:
Expand All @@ -62,6 +66,24 @@ public enum AuthFlowType {
}
}

// This initializer has been added to migrate credentials that were created in the pre-passwordless era
internal static func legacyInit(rawValue: String) -> Self? {
switch rawValue {
case "userSRP":
return .userSRP
case "userPassword":
return .userPassword
case "custom":
return .custom
case "customWithSRP":
return .customWithSRP
case "customWithoutSRP":
return .customWithoutSRP
default:
return nil
}
}

public static var userAuth: AuthFlowType {
return .userAuth(preferredFirstFactor: nil)
}
Expand Down Expand Up @@ -110,27 +132,49 @@ extension AuthFlowType: Codable {

// Decoding the enum
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let container: KeyedDecodingContainer<CodingKeys>
do {
container = try decoder.container(keyedBy: CodingKeys.self)
} catch DecodingError.typeMismatch {
// The type mismatch has been added to handle a scenario where the user is migrating passwordless flows.
// Passwordless flow added a new enum case with a associated type.
// The association resulted in encoding structure changes that is different from the non-passwordless flows.
// The structure change causes the type mismatch exception and this code block tries to retrieve the legacy structure and decode it.
let legacyContainer = try decoder.singleValueContainer()
let type = try legacyContainer.decode(String.self)
guard let authFlowType = AuthFlowType.legacyInit(rawValue: type) else {
throw DecodingError.dataCorruptedError(in: legacyContainer, debugDescription: "Invalid AuthFlowType value")
}
self = authFlowType
return
} catch {
throw error
}

// Decode the type (raw value)
let type = try container.decode(String.self, forKey: .type)

// Initialize based on the type
switch type {
case "USER_SRP_AUTH":
self = .userSRP
case "CUSTOM_AUTH":
// Depending on your needs, choose either `.custom`, `.customWithSRP`, or `.customWithoutSRP`
// In this case, we'll default to `.custom`
self = .custom
case "CUSTOM_AUTH", "CUSTOM_AUTH_WITH_SRP":
self = .customWithSRP
case "CUSTOM_AUTH_WITHOUT_SRP":
self = .customWithoutSRP
case "USER_PASSWORD_AUTH":
self = .userPassword
case "USER_AUTH":
let preferredFirstFactorString = try container.decode(String.self, forKey: .preferredFirstFactor)
if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) {
self = .userAuth(preferredFirstFactor: preferredFirstFactor)
if let preferredFirstFactorString = try container.decodeIfPresent(String.self, forKey: .preferredFirstFactor) {
if let preferredFirstFactor = AuthFactorType(rawValue: preferredFirstFactorString) {
self = .userAuth(preferredFirstFactor: preferredFirstFactor)
} else {
throw DecodingError.dataCorruptedError(
forKey: .preferredFirstFactor,
in: container,
debugDescription: "Unable to decode preferredFirstFactor value")
}
} else {
throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Unable to decode preferredFirstFactor value")
self = .userAuth(preferredFirstFactor: nil)
}
default:
throw DecodingError.dataCorruptedError(forKey: .type, in: container, debugDescription: "Invalid AuthFlowType value")
Expand All @@ -152,5 +196,4 @@ extension AuthFlowType {
return .userAuth
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//


import XCTest
@testable import AWSCognitoAuthPlugin

class AuthFlowTypeTests: XCTestCase {

func testRawValue() {
XCTAssertEqual(AuthFlowType.userSRP.rawValue, "USER_SRP_AUTH")
XCTAssertEqual(AuthFlowType.customWithSRP.rawValue, "CUSTOM_AUTH_WITH_SRP")
XCTAssertEqual(AuthFlowType.customWithoutSRP.rawValue, "CUSTOM_AUTH_WITHOUT_SRP")
XCTAssertEqual(AuthFlowType.userPassword.rawValue, "USER_PASSWORD_AUTH")
XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).rawValue, "USER_AUTH")
}

func testInitWithRawValue() {
XCTAssertEqual(AuthFlowType(rawValue: "USER_SRP_AUTH"), .userSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH"), .customWithSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITH_SRP"), .customWithSRP)
XCTAssertEqual(AuthFlowType(rawValue: "CUSTOM_AUTH_WITHOUT_SRP"), .customWithoutSRP)
XCTAssertEqual(AuthFlowType(rawValue: "USER_PASSWORD_AUTH"), .userPassword)
XCTAssertEqual(AuthFlowType(rawValue: "USER_AUTH"), .userAuth(preferredFirstFactor: nil))
XCTAssertNil(AuthFlowType(rawValue: "INVALID_AUTH"))
}

func testDeprecatedCustom() {
// This test is to ensure the deprecated case is still functional
XCTAssertEqual(AuthFlowType.custom.rawValue, "CUSTOM_AUTH_WITH_SRP")
}

func testEncoding() throws {
let encoder = JSONEncoder()
let userSRP = try encoder.encode(AuthFlowType.userSRP)
XCTAssertEqual(String(data: userSRP, encoding: .utf8), "{\"type\":\"USER_SRP_AUTH\"}")

let customWithSRP = try encoder.encode(AuthFlowType.customWithSRP)
XCTAssertEqual(String(data: customWithSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}")

let customWithoutSRP = try encoder.encode(AuthFlowType.customWithoutSRP)
XCTAssertEqual(String(data: customWithoutSRP, encoding: .utf8), "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}")

let userPassword = try encoder.encode(AuthFlowType.userPassword)
XCTAssertEqual(String(data: userPassword, encoding: .utf8), "{\"type\":\"USER_PASSWORD_AUTH\"}")

let userAuth = try encoder.encode(AuthFlowType.userAuth(preferredFirstFactor: nil))
XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"preferredFirstFactor\":null") == true)
XCTAssertTrue(String(data: userAuth, encoding: .utf8)?.contains("\"type\":\"USER_AUTH\"") == true)
}

func testDecoding() throws {
let decoder = JSONDecoder()
let userSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_SRP_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userSRP, .userSRP)

let customWithSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITH_SRP\"}".data(using: .utf8)!)
XCTAssertEqual(customWithSRP, .customWithSRP)

let customWithoutSRP = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"CUSTOM_AUTH_WITHOUT_SRP\"}".data(using: .utf8)!)
XCTAssertEqual(customWithoutSRP, .customWithoutSRP)

let userPassword = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_PASSWORD_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userPassword, .userPassword)

let userAuth = try decoder.decode(AuthFlowType.self, from: "{\"type\":\"USER_AUTH\"}".data(using: .utf8)!)
XCTAssertEqual(userAuth, .userAuth(preferredFirstFactor: nil))
}

func testDecodingWithPreferredFirstFactor() throws {
let decoder = JSONDecoder()
let json = """
{
"type": "USER_AUTH",
"preferredFirstFactor": "SMS_OTP"
}
""".data(using: .utf8)!
let authFlowType = try decoder.decode(AuthFlowType.self, from: json)
XCTAssertEqual(authFlowType, .userAuth(preferredFirstFactor: .smsOTP))
}

func testDecodingLegacyStructure() throws {
let decoder = JSONDecoder()
var legacyJson = "\"userSRP\"".data(using: .utf8)!
var authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .userSRP)

legacyJson = "\"userPassword\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .userPassword)

legacyJson = "\"customWithSRP\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .customWithSRP)

legacyJson = "\"customWithoutSRP\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .customWithoutSRP)

legacyJson = "\"custom\"".data(using: .utf8)!
authFlowType = try decoder.decode(AuthFlowType.self, from: legacyJson)
XCTAssertEqual(authFlowType, .custom)
}

func testDecodingInvalidType() {
let decoder = JSONDecoder()
let invalidJson = "{\"type\":\"INVALID_AUTH\"}".data(using: .utf8)!
XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in
guard case DecodingError.dataCorrupted(let context) = error else {
return XCTFail("Expected dataCorrupted error")
}
XCTAssertEqual(context.debugDescription, "Invalid AuthFlowType value")
}
}

func testDecodingInvalidPreferredFirstFactor() {
let decoder = JSONDecoder()
let invalidJson = """
{
"type": "USER_AUTH",
"preferredFirstFactor": "INVALID_FACTOR"
}
""".data(using: .utf8)!
XCTAssertThrowsError(try decoder.decode(AuthFlowType.self, from: invalidJson)) { error in
guard case DecodingError.dataCorrupted(let context) = error else {
return XCTFail("Expected dataCorrupted error")
}
XCTAssertEqual(context.debugDescription, "Unable to decode preferredFirstFactor value")
}
}

func testGetClientFlowType() {
XCTAssertEqual(AuthFlowType.custom.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.customWithSRP.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.customWithoutSRP.getClientFlowType(), .customAuth)
XCTAssertEqual(AuthFlowType.userSRP.getClientFlowType(), .userSrpAuth)
XCTAssertEqual(AuthFlowType.userPassword.getClientFlowType(), .userPasswordAuth)
XCTAssertEqual(AuthFlowType.userAuth(preferredFirstFactor: nil).getClientFlowType(), .userAuth)
}
}

0 comments on commit 833c7bb

Please sign in to comment.