Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuya] Improve Netty handlers #564

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.AttributeKey;

/**
* The {@link TuyaDevice} handles the device connection
Expand All @@ -57,22 +58,27 @@
*/
@NonNullByDefault
public class TuyaDevice implements ChannelFutureListener {
public static final AttributeKey<String> DEVICE_ID_ATTR = AttributeKey.valueOf("deviceId");
public static final AttributeKey<ProtocolVersion> PROTOCOL_ATTR = AttributeKey.valueOf("protocol");
public static final AttributeKey<byte[]> SESSION_RANDOM_ATTR = AttributeKey.valueOf("sessionRandom");
public static final AttributeKey<byte[]> SESSION_KEY_ATTR = AttributeKey.valueOf("sessionKey");

private final Logger logger = LoggerFactory.getLogger(TuyaDevice.class);

private final Bootstrap bootstrap = new Bootstrap();
private final DeviceStatusListener deviceStatusListener;
private final String deviceId;
private final byte[] deviceKey;

private final String address;
private final ProtocolVersion protocolVersion;
private final KeyStore keyStore;
private @Nullable Channel channel;

public TuyaDevice(Gson gson, DeviceStatusListener deviceStatusListener, EventLoopGroup eventLoopGroup,
String deviceId, byte[] deviceKey, String address, String protocolVersion) {
this.address = address;
this.deviceId = deviceId;
this.keyStore = new KeyStore(deviceKey);
this.deviceKey = deviceKey;
this.deviceStatusListener = deviceStatusListener;
this.protocolVersion = ProtocolVersion.fromString(protocolVersion);
bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class);
Expand All @@ -83,20 +89,17 @@ protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("idleStateHandler",
new IdleStateHandler(TCP_CONNECTION_TIMEOUT, TCP_CONNECTION_HEARTBEAT_INTERVAL, 0));
pipeline.addLast("messageEncoder",
new TuyaEncoder(gson, deviceId, keyStore, TuyaDevice.this.protocolVersion));
pipeline.addLast("messageDecoder",
new TuyaDecoder(gson, deviceId, keyStore, TuyaDevice.this.protocolVersion));
pipeline.addLast("heartbeatHandler", new HeartbeatHandler(deviceId));
pipeline.addLast("deviceHandler", new TuyaMessageHandler(deviceId, keyStore, deviceStatusListener));
pipeline.addLast("userEventHandler", new UserEventHandler(deviceId));
pipeline.addLast("messageEncoder", new TuyaEncoder(gson));
pipeline.addLast("messageDecoder", new TuyaDecoder(gson));
pipeline.addLast("heartbeatHandler", new HeartbeatHandler());
pipeline.addLast("deviceHandler", new TuyaMessageHandler(deviceStatusListener));
pipeline.addLast("userEventHandler", new UserEventHandler());
}
});
connect();
}

public void connect() {
keyStore.reset(); // reset session key
bootstrap.connect(address, 6668).addListener(this);
}

Expand Down Expand Up @@ -147,12 +150,22 @@ public void dispose() {
public void operationComplete(@NonNullByDefault({}) ChannelFuture channelFuture) throws Exception {
if (channelFuture.isSuccess()) {
Channel channel = channelFuture.channel();
this.channel = channel;
channel.attr(DEVICE_ID_ATTR).set(deviceId);
channel.attr(PROTOCOL_ATTR).set(protocolVersion);
// session key is device key before negotiation
channel.attr(SESSION_KEY_ATTR).set(deviceKey);

if (protocolVersion == V3_4) {
byte[] sessionRandom = CryptoUtil.generateRandom(16);
channel.attr(SESSION_RANDOM_ATTR).set(sessionRandom);
this.channel = channel;

// handshake for session key required
MessageWrapper<?> m = new MessageWrapper<>(SESS_KEY_NEG_START, keyStore.getRandom());
MessageWrapper<?> m = new MessageWrapper<>(SESS_KEY_NEG_START, sessionRandom);
channel.writeAndFlush(m);
} else {
this.channel = channel;

// no handshake for 3.1/3.3
requestStatus();
}
Expand All @@ -164,37 +177,4 @@ public void operationComplete(@NonNullByDefault({}) ChannelFuture channelFuture)
deviceStatusListener.connectionStatus(false);
}
}

public static class KeyStore {
private final byte[] deviceKey;
private byte[] sessionKey;
private byte[] random;

public KeyStore(byte[] deviceKey) {
this.deviceKey = deviceKey;
this.sessionKey = deviceKey;
this.random = CryptoUtil.generateRandom(16).clone();
}

public void reset() {
this.sessionKey = this.deviceKey;
this.random = CryptoUtil.generateRandom(16).clone();
}

public byte[] getDeviceKey() {
return sessionKey;
}

public byte[] getSessionKey() {
return sessionKey;
}

public void setSessionKey(byte[] sessionKey) {
this.sessionKey = sessionKey;
}

public byte[] getRandom() {
return random;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,25 @@ private void activate() {
protected void initChannel(DatagramChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast("udpDecoder", new DatagramToByteBufDecoder());
pipeline.addLast("messageDecoder", new TuyaDecoder(gson, "udpListener",
new TuyaDevice.KeyStore(TUYA_UDP_KEY), ProtocolVersion.V3_1));
pipeline.addLast("messageDecoder", new TuyaDecoder(gson));
pipeline.addLast("discoveryHandler",
new DiscoveryMessageHandler(deviceInfos, deviceListeners));
pipeline.addLast("userEventHandler", new UserEventHandler("udpListener"));
pipeline.addLast("userEventHandler", new UserEventHandler());
}
});

ChannelFuture futureEncrypted = b.bind(6667).addListener(this).sync();
encryptedChannel = futureEncrypted.channel();
encryptedChannel.attr(TuyaDevice.DEVICE_ID_ATTR).set("udpListener");
encryptedChannel.attr(TuyaDevice.PROTOCOL_ATTR).set(ProtocolVersion.V3_1);
encryptedChannel.attr(TuyaDevice.SESSION_KEY_ATTR).set(TUYA_UDP_KEY);

ChannelFuture futureRaw = b.bind(6666).addListener(this).sync();
rawChannel = futureRaw.channel();
rawChannel.attr(TuyaDevice.DEVICE_ID_ATTR).set("udpListener");
rawChannel.attr(TuyaDevice.PROTOCOL_ATTR).set(ProtocolVersion.V3_1);
rawChannel.attr(TuyaDevice.SESSION_KEY_ATTR).set(TUYA_UDP_KEY);

} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.slf4j.LoggerFactory;
import org.smarthomej.binding.tuya.internal.local.CommandType;
import org.smarthomej.binding.tuya.internal.local.MessageWrapper;
import org.smarthomej.binding.tuya.internal.local.TuyaDevice;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -37,18 +38,19 @@
@NonNullByDefault
public class HeartbeatHandler extends ChannelDuplexHandler {
private final Logger logger = LoggerFactory.getLogger(HeartbeatHandler.class);
private final String deviceId;
private int heartBeatMissed = 0;

public HeartbeatHandler(String deviceId) {
this.deviceId = deviceId;
}

@Override
public void userEventTriggered(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDefault({}) Object evt)
throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent e = (IdleStateEvent) evt;
if (!ctx.channel().hasAttr(TuyaDevice.DEVICE_ID_ATTR)) {
logger.warn("{}: Failed to retrieve deviceId from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(TuyaDevice.DEVICE_ID_ATTR).get();

if (evt instanceof IdleStateEvent e) {
if (IdleState.READER_IDLE.equals(e.state())) {
logger.warn("{}{}: Did not receive a message from for {} seconds. Connection seems to be dead.",
deviceId, Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""),
Expand All @@ -75,8 +77,14 @@ public void userEventTriggered(@NonNullByDefault({}) ChannelHandlerContext ctx,
@Override
public void channelRead(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDefault({}) Object msg)
throws Exception {
if (msg instanceof MessageWrapper<?>) {
MessageWrapper<?> m = (MessageWrapper<?>) msg;
if (!ctx.channel().hasAttr(TuyaDevice.DEVICE_ID_ATTR)) {
logger.warn("{}: Failed to retrieve deviceId from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(TuyaDevice.DEVICE_ID_ATTR).get();

if (msg instanceof MessageWrapper<?> m) {
if (CommandType.HEART_BEAT.equals(m.commandType)) {
logger.trace("{}{}: Received pong", deviceId,
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.smarthomej.binding.tuya.internal.local.CommandType.UDP_NEW;
import static org.smarthomej.binding.tuya.internal.local.ProtocolVersion.V3_3;
import static org.smarthomej.binding.tuya.internal.local.ProtocolVersion.V3_4;
import static org.smarthomej.binding.tuya.internal.local.TuyaDevice.*;

import java.nio.ByteBuffer;
import java.util.Arrays;
Expand All @@ -35,7 +36,6 @@
import org.smarthomej.binding.tuya.internal.local.CommandType;
import org.smarthomej.binding.tuya.internal.local.MessageWrapper;
import org.smarthomej.binding.tuya.internal.local.ProtocolVersion;
import org.smarthomej.binding.tuya.internal.local.TuyaDevice;
import org.smarthomej.binding.tuya.internal.local.dto.DiscoveryMessage;
import org.smarthomej.binding.tuya.internal.local.dto.TcpStatusPayload;
import org.smarthomej.binding.tuya.internal.util.CryptoUtil;
Expand All @@ -58,16 +58,10 @@
public class TuyaDecoder extends ByteToMessageDecoder {
private final Logger logger = LoggerFactory.getLogger(TuyaDecoder.class);

private final TuyaDevice.KeyStore keyStore;
private final ProtocolVersion version;
private final Gson gson;
private final String deviceId;

public TuyaDecoder(Gson gson, String deviceId, TuyaDevice.KeyStore keyStore, ProtocolVersion version) {
public TuyaDecoder(Gson gson) {
this.gson = gson;
this.keyStore = keyStore;
this.version = version;
this.deviceId = deviceId;
}

@Override
Expand All @@ -78,6 +72,17 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
return;
}

if (!ctx.channel().hasAttr(DEVICE_ID_ATTR) || !ctx.channel().hasAttr(PROTOCOL_ATTR)
|| !ctx.channel().hasAttr(SESSION_KEY_ATTR)) {
logger.warn(
"{}: Failed to retrieve deviceId, protocol or sessionKey from ChannelHandlerContext. This is a bug.",
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""));
return;
}
String deviceId = ctx.channel().attr(DEVICE_ID_ATTR).get();
ProtocolVersion protocol = ctx.channel().attr(PROTOCOL_ATTR).get();
byte[] sessionKey = ctx.channel().attr(SESSION_KEY_ATTR).get();

// we need to take a copy first so the buffer stays intact if we exit early
ByteBuf inCopy = in.copy();
byte[] bytes = new byte[inCopy.readableBytes()];
Expand Down Expand Up @@ -111,20 +116,20 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
if ((returnCode & 0xffffff00) != 0) {
// rewind if no return code is present
buffer.position(buffer.position() - 4);
payload = version == V3_4 ? new byte[payloadLength - 32] : new byte[payloadLength - 8];
payload = protocol == V3_4 ? new byte[payloadLength - 32] : new byte[payloadLength - 8];
} else {
payload = version == V3_4 ? new byte[payloadLength - 32 - 8] : new byte[payloadLength - 8 - 4];
payload = protocol == V3_4 ? new byte[payloadLength - 32 - 8] : new byte[payloadLength - 8 - 4];
}

buffer.get(payload);

if (version == V3_4 && commandType != UDP && commandType != UDP_NEW) {
if (protocol == V3_4 && commandType != UDP && commandType != UDP_NEW) {
byte[] fullMessage = new byte[buffer.position()];
buffer.position(0);
buffer.get(fullMessage);
byte[] expectedHmac = new byte[32];
buffer.get(expectedHmac);
byte[] calculatedHmac = CryptoUtil.hmac(fullMessage, keyStore.getSessionKey());
byte[] calculatedHmac = CryptoUtil.hmac(fullMessage, sessionKey);
if (!Arrays.equals(expectedHmac, calculatedHmac)) {
logger.warn("{}{}: Checksum failed for message: calculated {}, found {}", deviceId,
Objects.requireNonNullElse(ctx.channel().remoteAddress(), ""),
Expand All @@ -150,8 +155,8 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
return;
}

if (Arrays.equals(Arrays.copyOfRange(payload, 0, version.getBytes().length), version.getBytes())) {
if (version == V3_3) {
if (Arrays.equals(Arrays.copyOfRange(payload, 0, protocol.getBytes().length), protocol.getBytes())) {
if (protocol == V3_3) {
// Remove 3.3 header
payload = Arrays.copyOfRange(payload, 15, payload.length);
} else {
Expand All @@ -165,13 +170,14 @@ public void decode(@NonNullByDefault({}) ChannelHandlerContext ctx, @NonNullByDe
m = new MessageWrapper<>(commandType,
Objects.requireNonNull(gson.fromJson(new String(payload), DiscoveryMessage.class)));
} else {
byte[] decodedMessage = version == V3_4 ? CryptoUtil.decryptAesEcb(payload, keyStore.getSessionKey(), true)
: CryptoUtil.decryptAesEcb(payload, keyStore.getDeviceKey(), false);
byte[] decodedMessage = protocol == V3_4 ? CryptoUtil.decryptAesEcb(payload, sessionKey, true)
: CryptoUtil.decryptAesEcb(payload, sessionKey, false);
if (decodedMessage == null) {
return;
}
if (Arrays.equals(Arrays.copyOfRange(decodedMessage, 0, version.getBytes().length), version.getBytes())) {
if (version == V3_4) {

if (Arrays.equals(Arrays.copyOfRange(decodedMessage, 0, protocol.getBytes().length), protocol.getBytes())) {
if (protocol == V3_4) {
// Remove 3.4 header
decodedMessage = Arrays.copyOfRange(decodedMessage, 15, decodedMessage.length);
}
Expand Down
Loading
Loading