Skip to content

Commit

Permalink
Fix #317 - Create enrollment group for device models
Browse files Browse the repository at this point in the history
  • Loading branch information
kbeaugrand committed Feb 26, 2022
1 parent 4226454 commit cfad9f1
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using Microsoft.Azure.Devices.Shared;
using AzureIoTHub.Portal.Shared.Models.V10.DeviceModel;
using AzureIoTHub.Portal.Server.Controllers.V10;
using Microsoft.Azure.Devices.Provisioning.Service;

namespace AzureIoTHub.Portal.Server.Tests.Controllers.V10
{
Expand All @@ -35,6 +36,7 @@ public class DeviceModelsControllerTests
private Mock<IDeviceModelMapper<DeviceModel, DeviceModel>> mockDeviceModelMapper;
private Mock<IDeviceService> mockDeviceService;
private Mock<ITableClientFactory> mockTableClientFactory;
private Mock<IDeviceProvisioningServiceManager> mockDeviceProvisioningServiceManager;
private Mock<TableClient> mockDeviceTemplatesTableClient;
private Mock<TableClient> mockCommandsTableClient;

Expand All @@ -46,6 +48,7 @@ public void SetUp()
this.mockLogger = this.mockRepository.Create<ILogger<DeviceModelsController>>();
this.mockDeviceModelImageManager = this.mockRepository.Create<IDeviceModelImageManager>();
this.mockDeviceModelCommandMapper = this.mockRepository.Create<IDeviceModelCommandMapper>();
this.mockDeviceProvisioningServiceManager = this.mockRepository.Create<IDeviceProvisioningServiceManager>();
this.mockDeviceModelMapper = this.mockRepository.Create<IDeviceModelMapper<DeviceModel, DeviceModel>>();
this.mockDeviceService = this.mockRepository.Create<IDeviceService>();
this.mockTableClientFactory = this.mockRepository.Create<ITableClientFactory>();
Expand All @@ -60,7 +63,8 @@ private DeviceModelsController CreateDeviceModelsController()
this.mockDeviceModelImageManager.Object,
this.mockDeviceModelMapper.Object,
this.mockDeviceService.Object,
this.mockTableClientFactory.Object);
this.mockTableClientFactory.Object,
mockDeviceProvisioningServiceManager.Object);

return result;
}
Expand Down Expand Up @@ -318,8 +322,10 @@ public async Task Post_Should_Create_A_New_Entity()

var requestModel = new DeviceModel
{
Name = Guid.NewGuid().ToString(),
ModelId = Guid.NewGuid().ToString()
};
var mockEnrollmentGroup = this.mockRepository.Create<EnrollmentGroup>(string.Empty, new SymmetricKeyAttestation(string.Empty, string.Empty));

var mockResponse = this.mockRepository.Create<Response>();

Expand All @@ -333,6 +339,12 @@ public async Task Post_Should_Create_A_New_Entity()
It.Is<TableEntity>(x => x.RowKey == requestModel.ModelId && x.PartitionKey == LoRaWANDeviceModelsController.DefaultPartitionKey),
It.IsAny<DeviceModel>()));

this.mockDeviceProvisioningServiceManager.Setup(c => c.CreateEnrollmentGroupFormModelAsync(
It.IsAny<string>(),
It.Is<string>(x => x == requestModel.Name),
It.IsAny<TwinCollection>()))
.ReturnsAsync(mockEnrollmentGroup.Object);

// Act
var result = await deviceModelsController.Post(requestModel);

Expand All @@ -355,10 +367,12 @@ public async Task WhenEmptyModelId_Post_Should_Create_A_New_Entity()

var requestModel = new DeviceModel
{
ModelId = String.Empty
ModelId = String.Empty,
Name = Guid.NewGuid().ToString(),
};

var mockResponse = this.mockRepository.Create<Response>();
var mockEnrollmentGroup = this.mockRepository.Create<EnrollmentGroup>(string.Empty, new SymmetricKeyAttestation(string.Empty, string.Empty));

this.mockDeviceTemplatesTableClient.Setup(c => c.UpsertEntityAsync(
It.Is<TableEntity>(x => x.RowKey != requestModel.ModelId && x.PartitionKey == LoRaWANDeviceModelsController.DefaultPartitionKey),
Expand All @@ -373,6 +387,12 @@ public async Task WhenEmptyModelId_Post_Should_Create_A_New_Entity()
this.mockTableClientFactory.Setup(c => c.GetDeviceTemplates())
.Returns(mockDeviceTemplatesTableClient.Object);

this.mockDeviceProvisioningServiceManager.Setup(c => c.CreateEnrollmentGroupFormModelAsync(
It.IsAny<string>(),
It.Is<string>(x => x == requestModel.Name),
It.IsAny<TwinCollection>()))
.ReturnsAsync(mockEnrollmentGroup.Object);

// Act
var result = await deviceModelsController.Post(requestModel);

Expand Down Expand Up @@ -417,9 +437,12 @@ public async Task Put_Should_Update_The_Device_Model()

var requestModel = new DeviceModel
{
Name = Guid.NewGuid().ToString(),
ModelId = deviceModel.RowKey
};

var mockEnrollmentGroup = this.mockRepository.Create<EnrollmentGroup>(string.Empty, new SymmetricKeyAttestation(string.Empty, string.Empty));

this.mockDeviceTemplatesTableClient.Setup(c => c.UpsertEntityAsync(
It.Is<TableEntity>(x => x.RowKey == deviceModel.RowKey && x.PartitionKey == LoRaWANDeviceModelsController.DefaultPartitionKey),
It.IsAny<TableUpdateMode>(),
Expand All @@ -433,6 +456,12 @@ public async Task Put_Should_Update_The_Device_Model()
this.mockTableClientFactory.Setup(c => c.GetDeviceTemplates())
.Returns(mockDeviceTemplatesTableClient.Object);

this.mockDeviceProvisioningServiceManager.Setup(c => c.CreateEnrollmentGroupFormModelAsync(
It.IsAny<string>(),
It.Is<string>(x => x == requestModel.Name),
It.IsAny<TwinCollection>()))
.ReturnsAsync(mockEnrollmentGroup.Object);

// Act
var result = await deviceModelsController.Put(requestModel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace AzureIoTHub.Portal.Server.Controllers.V10
using AzureIoTHub.Portal.Shared.Models.V10.DeviceModel;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Azure.Devices.Shared;
using Microsoft.Extensions.Logging;

public abstract class DeviceModelsControllerBase<TListItemModel, TModel> : ControllerBase
Expand Down Expand Up @@ -52,6 +53,11 @@ public abstract class DeviceModelsControllerBase<TListItemModel, TModel> : Contr
/// </summary>
private readonly IDeviceService devicesService;

/// <summary>
/// The device provisioning service manager.
/// </summary>
private readonly IDeviceProvisioningServiceManager deviceProvisioningServiceManager;

/// <summary>
/// The device template filter.
/// </summary>
Expand All @@ -66,12 +72,14 @@ public abstract class DeviceModelsControllerBase<TListItemModel, TModel> : Contr
/// <param name="devicesService">The devices service.</param>
/// <param name="tableClientFactory">The table client factory.</param>
/// <param name="filter">The device template filter query string.</param>
/// <param name="deviceProvisioningServiceManager">The device provisioning service manager.</param>
public DeviceModelsControllerBase(
ILogger log,
IDeviceModelImageManager deviceModelImageManager,
IDeviceModelMapper<TListItemModel, TModel> deviceModelMapper,
IDeviceService devicesService,
ITableClientFactory tableClientFactory,
IDeviceProvisioningServiceManager deviceProvisioningServiceManager,
string filter)
{
this.log = log;
Expand All @@ -80,6 +88,7 @@ public DeviceModelsControllerBase(
this.deviceModelImageManager = deviceModelImageManager;
this.devicesService = devicesService;
this.filter = filter;
this.deviceProvisioningServiceManager = deviceProvisioningServiceManager;
}

/// <summary>
Expand Down Expand Up @@ -357,6 +366,10 @@ private async Task SaveEntity(TableEntity entity, TModel deviceModelObject)
await this.tableClientFactory
.GetDeviceTemplates()
.UpsertEntityAsync(entity);

var deviceModelTwin = new TwinCollection();

await this.deviceProvisioningServiceManager.CreateEnrollmentGroupFormModelAsync(deviceModelObject.ModelId, deviceModelObject.Name, deviceModelTwin);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ public class DeviceModelsController : DeviceModelsControllerBase<DeviceModel, De
/// <param name="deviceModelMapper">The device model mapper.</param>
/// <param name="devicesService">The devices service.</param>
/// <param name="tableClientFactory">The table client factory.</param>
/// <param name="deviceProvisioningServiceManager">The device provisioning service manager.</param>
public DeviceModelsController(ILogger<DeviceModelsControllerBase<DeviceModel, DeviceModel>> log,
IDeviceModelImageManager deviceModelImageManager,
IDeviceModelMapper<DeviceModel, DeviceModel> deviceModelMapper,
IDeviceService devicesService,
ITableClientFactory tableClientFactory)
: base(log, deviceModelImageManager, deviceModelMapper, devicesService, tableClientFactory, $"")
ITableClientFactory tableClientFactory,
IDeviceProvisioningServiceManager deviceProvisioningServiceManager)
: base(log, deviceModelImageManager, deviceModelMapper, devicesService, tableClientFactory, deviceProvisioningServiceManager, $"")
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ public class LoRaWANDeviceModelsController : DeviceModelsControllerBase<DeviceMo
/// <param name="deviceModelMapper">The device model mapper.</param>
/// <param name="devicesService">The devices service.</param>
/// <param name="tableClientFactory">The table client factory.</param>
/// <param name="deviceProvisioningServiceManager">The device provisioning service manager.</param>
public LoRaWANDeviceModelsController(
ILogger<DeviceModelsControllerBase<DeviceModel, LoRaDeviceModel>> log,
IDeviceModelImageManager deviceModelImageManager,
IDeviceModelMapper<DeviceModel, LoRaDeviceModel> deviceModelMapper,
IDeviceService devicesService,
ITableClientFactory tableClientFactory)
: base(log, deviceModelImageManager, deviceModelMapper, devicesService, tableClientFactory,
ITableClientFactory tableClientFactory,
IDeviceProvisioningServiceManager deviceProvisioningServiceManager)
: base(log, deviceModelImageManager, deviceModelMapper, devicesService, tableClientFactory, deviceProvisioningServiceManager,
$"{nameof(DeviceModel.SupportLoRaFeatures)} eq true")
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ public async Task<string> GetSymmetricKey(string deviceId, string deviceType)
}
catch (ProvisioningServiceClientHttpException e)
{
if (e.StatusCode == System.Net.HttpStatusCode.NotFound)
{
_ = await this.deviceProvisioningServiceManager.CreateEnrollmentGroupAsync(deviceType);
attestation = await this.deviceProvisioningServiceManager.GetAttestation(deviceType);
}

throw new InvalidOperationException("Failed to get symmetricKey.", e);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,44 @@ public DeviceProvisioningServiceManager(ProvisioningServiceClient dps, ConfigHan

public async Task<EnrollmentGroup> CreateEnrollmentGroupAsync(string deviceType)
{
string enrollmentGroupId;
TwinCollection tags;
TwinCollection desiredProperties;
var twinState = new TwinState(
tags: new TwinCollection($"{{ \"purpose\":\"{deviceType}\" }}"),
desiredProperties: new TwinCollection());

if (deviceType == "LoRa Network Server")
{
enrollmentGroupId = this.config.DPSLoRaEnrollmentGroup;
tags = new TwinCollection("{ \"purpose\":\"" + "LoRa Network Server" + "\" }");
desiredProperties = new TwinCollection("{ }");
}
else
{
enrollmentGroupId = this.config.DPSDefaultEnrollmentGroup;
tags = new TwinCollection("{ \"purpose\":\"" + "Unknown" + "\" }");
desiredProperties = new TwinCollection("{ }");
}
return await this.CreateNewEnrollmentGroup(deviceType, true, twinState);
}

public async Task<EnrollmentGroup> CreateEnrollmentGroupFormModelAsync(string modelId, string modelName, TwinCollection desiredProperties)
{
var twinState = new TwinState(
tags: new TwinCollection($"{{ \"modelId\":\"{modelId}\" }}"),
desiredProperties: new TwinCollection());

return await this.CreateNewEnrollmentGroup(modelName, false, twinState);
}

/// <summary>
/// Create
/// </summary>
/// <returns></returns>
private async Task<EnrollmentGroup> CreateNewEnrollmentGroup(string name, bool iotEdge, TwinState initialTwinState)
{
string enrollmentGroupPrimaryKey = GenerateKey();
string enrollmentGroupSecondaryKey = GenerateKey();

SymmetricKeyAttestation attestation = new SymmetricKeyAttestation(enrollmentGroupPrimaryKey, enrollmentGroupSecondaryKey);

EnrollmentGroup enrollmentGroup = new EnrollmentGroup(enrollmentGroupId, attestation)
EnrollmentGroup enrollmentGroup = new EnrollmentGroup(ComputeEnrollmentGroupName(name), attestation)
{
ProvisioningStatus = ProvisioningStatus.Enabled,
Capabilities = new DeviceCapabilities
{
IotEdge = true
IotEdge = iotEdge
},
InitialTwinState = new TwinState(tags, desiredProperties)
InitialTwinState = initialTwinState
};

var enrollmentResult = await this.dps.CreateOrUpdateEnrollmentGroupAsync(enrollmentGroup).ConfigureAwait(false);

return enrollmentResult;
return await this.dps.CreateOrUpdateEnrollmentGroupAsync(enrollmentGroup);
}

/// <summary>
Expand All @@ -66,11 +69,9 @@ public async Task<EnrollmentGroup> CreateEnrollmentGroupAsync(string deviceType)
/// <returns>AttestationMechanism.</returns>
public async Task<Attestation> GetAttestation(string deviceType)
{
var attestationMechanism = deviceType == "LoRa Network Server" ?
await this.dps.GetEnrollmentGroupAttestationAsync(this.config.DPSLoRaEnrollmentGroup) :
await this.dps.GetEnrollmentGroupAttestationAsync(this.config.DPSDefaultEnrollmentGroup);
var attetationMechanism = await this.dps.GetEnrollmentGroupAttestationAsync(ComputeEnrollmentGroupName(deviceType));

return attestationMechanism.GetAttestation();
return attetationMechanism.GetAttestation();
}

private static string GenerateKey()
Expand All @@ -80,5 +81,12 @@ private static string GenerateKey()

return Convert.ToBase64String(rnd);
}

private static string ComputeEnrollmentGroupName(string deviceType)
{
return deviceType.Trim()
.ToLowerInvariant()
.Replace(" ", "-");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,31 @@ namespace AzureIoTHub.Portal.Server.Managers
{
using System.Threading.Tasks;
using Microsoft.Azure.Devices.Provisioning.Service;
using Microsoft.Azure.Devices.Shared;

public interface IDeviceProvisioningServiceManager
{
/// <summary>
/// Gets the device symmetric key attestation for the enrollment group.
/// </summary>
/// <param name="deviceType">The device type.</param>
/// <returns>The corresponding attestation.</returns>
Task<Attestation> GetAttestation(string deviceType);

/// <summary>
/// Creates the Enrollment group fot the specified device type.
/// </summary>
/// <param name="deviceType">The device type name.</param>
/// <returns>An object representing the corresponding enrollment group.</returns>
Task<EnrollmentGroup> CreateEnrollmentGroupAsync(string deviceType);

/// <summary>
/// Create Enrolllment group for the specified device model.
/// </summary>
/// <param name="modelId">The model identifier.</param>
/// <param name="modelName">The model name.</param>
/// <param name="desiredProperties">The desired properties</param>
/// <returns></returns>
Task<EnrollmentGroup> CreateEnrollmentGroupFormModelAsync(string modelId, string modelName, TwinCollection desiredProperties);
}
}

0 comments on commit cfad9f1

Please sign in to comment.