Netty私有協議棧設計
消息定義：NettyMessage由消息頭（Header）和消息主體（Body）兩部分組成
消息頭（Header）
消息主體（Body）
圖示:
圖1.png
Header:
/**
 * Header of a private-protocol message.
 * <p>
 * NettyMessageEncoder writes the scalar fields in declaration order:
 * crcCode (int), length (int), sessionID (long), type (byte), priority (byte).
 * It then writes the attachment count and entries.
 * <p>
 * Accessors are elided below ("// ..."). The codec calls getCrcCode()/setCrcCode(),
 * getLength()/setLength(), getAttachment()/setAttachment(), etc.
 */
public class Header {
private int crcCode = 0xadaf0105; // protocol identification marker; note this hex literal is a negative int in Java
private int length; // message length; on the wire the encoder overwrites it with (frame bytes - 8), i.e. excluding crcCode and the length field itself
private long sessionID; // session ID
private byte type; // message type; handlers compare it against MessageType values
private byte priority; // message priority; NOTE(review): a Java byte is signed (-128..127), so "0~255" only holds if it is read as unsigned
private Map<String, Object> attachment = new HashMap<String, Object>(); // optional attachments: String keys, values serialized with JBoss Marshalling by the codec
// ...
}
NettyMessage
/**
 * A private-protocol message: a {@link Header} plus an optional body of any type.
 * The encoder rejects messages whose header is null; a null body is encoded as "no body".
 */
public class NettyMessage {

    /** Protocol header; null until {@link #setHeader(Header)} is called. */
    private Header header;

    /** Message payload; may be null. */
    private Object body;

    /** @return the header, or null if none was set */
    public final Header getHeader() {
        return this.header;
    }

    /** @param value the header to attach to this message */
    public final void setHeader(Header value) {
        this.header = value;
    }

    /** @return the payload, or null if none was set */
    public final Object getBody() {
        return this.body;
    }

    /** @param value the payload to carry; may be null */
    public final void setBody(Object value) {
        this.body = value;
    }

    /**
     * Describes the message by its header only; the body is deliberately not included.
     */
    @Override
    public String toString() {
        return new StringBuilder("NettyMessage [header=")
                .append(this.header)
                .append(']')
                .toString();
    }
}
編解碼設計
選擇JBoss Marshalling作為Java對象序列化和反序列化的工具。
MarshallingCodeCFactory：工廠類，負責創建具體的Marshaller（序列化）與Unmarshaller（反序列化）對象
/**
 * Creates the JBoss Marshalling {@link Marshaller}/{@link Unmarshaller} instances used by the codec.
 * <p>
 * Both directions use the same provider ("serial", the Java-serialization-style factory) and
 * the same configuration version (5). Whatever one side marshals, the other can unmarshal.
 */
public class MarshallingCodeCFactory {

    /** Name of the provided marshaller factory ("serial" = Java serialization style). */
    private static final String PROVIDER_NAME = "serial";

    /** Marshalling configuration version shared by the marshaller and the unmarshaller. */
    private static final int CONFIGURATION_VERSION = 5;

    /**
     * Creates a new marshaller for writing objects.
     *
     * @return a fresh {@link Marshaller}
     * @throws IOException if the provider fails to create the marshaller
     * @throws IllegalStateException if the "serial" provider is not on the classpath
     */
    public static Marshaller buildMarshalling() throws IOException {
        return providerFactory().createMarshaller(newConfiguration());
    }

    /**
     * Creates a new unmarshaller for reading objects.
     *
     * @return a fresh {@link Unmarshaller}
     * @throws IOException if the provider fails to create the unmarshaller
     * @throws IllegalStateException if the "serial" provider is not on the classpath
     */
    public static Unmarshaller buildUnMarshalling() throws IOException {
        return providerFactory().createUnmarshaller(newConfiguration());
    }

    /** Looks up the provided factory, failing loudly instead of with a later NullPointerException. */
    private static MarshallerFactory providerFactory() {
        MarshallerFactory factory = Marshalling.getProvidedMarshallerFactory(PROVIDER_NAME);
        if (factory == null) {
            // getProvidedMarshallerFactory returns null when no provider with this name is found.
            throw new IllegalStateException("JBoss Marshalling provider '" + PROVIDER_NAME
                    + "' not found on the classpath");
        }
        return factory;
    }

    /** Builds the configuration shared by both directions. */
    private static MarshallingConfiguration newConfiguration() {
        MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(CONFIGURATION_VERSION);
        return configuration;
    }
}
輔助Marshaller/Unmarshaller讀寫Netty ByteBuf的兩個適配類（ByteOutput / ByteInput 實現）:
/**
 * {@link ByteOutput} adapter that appends every byte produced by JBoss Marshalling
 * to a Netty {@link ByteBuf} at its current writer index.
 */
public class ChannelBufferByteOutput implements ByteOutput {

    /** Destination buffer; owned by the caller. */
    private final ByteBuf target;

    /**
     * @param buffer the buffer that receives the written bytes
     */
    public ChannelBufferByteOutput(ByteBuf buffer) {
        this.target = buffer;
    }

    /** Appends the low-order byte of {@code b}. */
    @Override
    public void write(int b) throws IOException {
        target.writeByte(b);
    }

    /** Appends the whole array. */
    @Override
    public void write(byte[] bytes) throws IOException {
        target.writeBytes(bytes);
    }

    /** Appends {@code length} bytes of {@code bytes} starting at {@code srcIndex}. */
    @Override
    public void write(byte[] bytes, int srcIndex, int length) throws IOException {
        target.writeBytes(bytes, srcIndex, length);
    }

    /** Does nothing: bytes go straight into the buffer, so there is nothing to flush. */
    @Override
    public void flush() throws IOException {
    }

    /** Does nothing: the underlying buffer is left untouched. */
    @Override
    public void close() throws IOException {
    }

    /** @return the buffer holding everything written so far */
    ByteBuf getBuffer() {
        return target;
    }
}
/**
 * {@link ByteInput} adapter that lets JBoss Marshalling read from a Netty {@link ByteBuf}.
 * Reading consumes bytes by advancing the buffer's reader index; exhaustion is reported as -1.
 */
public class ChannelBufferByteInput implements ByteInput {

    /** Source buffer; its reader index is advanced as bytes are consumed. */
    private final ByteBuf byteBuf;

    /**
     * @param byteBuf the buffer to read from
     */
    public ChannelBufferByteInput(ByteBuf byteBuf) {
        this.byteBuf = byteBuf;
    }

    /**
     * @return the next byte as an unsigned value 0..255, or -1 if no bytes remain
     */
    @Override
    public int read() throws IOException {
        if (byteBuf.isReadable()) {
            return byteBuf.readByte() & 0xff;
        }
        return -1;
    }

    /** Equivalent to {@code read(bytes, 0, bytes.length)}. */
    @Override
    public int read(byte[] bytes) throws IOException {
        return read(bytes, 0, bytes.length);
    }

    /**
     * Reads up to {@code length} bytes into {@code dst} starting at {@code dstIndex}.
     *
     * @return the number of bytes read; 0 if {@code length} is 0; -1 if the buffer is exhausted
     */
    @Override
    public int read(byte[] dst, int dstIndex, int length) throws IOException {
        if (length == 0) {
            // InputStream-style contract: a zero-length request reads nothing and is not EOF.
            return 0;
        }
        int available = available();
        if (available == 0) {
            return -1;
        }
        int toRead = Math.min(available, length);
        byteBuf.readBytes(dst, dstIndex, toRead);
        return toRead;
    }

    /** @return the number of bytes still readable from the buffer */
    @Override
    public int available() throws IOException {
        return byteBuf.readableBytes();
    }

    /**
     * Skips up to {@code bytes} readable bytes.
     *
     * @return the number of bytes actually skipped; 0 when {@code bytes} is not positive
     */
    @Override
    public long skip(long bytes) throws IOException {
        if (bytes <= 0) {
            // Without this guard a negative argument would move the reader index backwards.
            return 0;
        }
        int skipped = (int) Math.min(bytes, byteBuf.readableBytes());
        byteBuf.skipBytes(skipped);
        return skipped;
    }

    /** Does nothing: the underlying buffer is left untouched. */
    @Override
    public void close() throws IOException {
    }
}
編碼器
處理流程:
圖2.png
MarshallingEncoder:
/**
 * Serializes a single object into a {@link ByteBuf} with JBoss Marshalling.
 * The output is framed as a 4-byte int length followed by that many marshalled bytes.
 */
public class MarshallingEncoder {

    /** Four zero bytes reserved in front of the payload, later overwritten with its length. */
    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];

    private Marshaller marshaller;

    /**
     * @throws IOException if the marshaller cannot be created
     */
    public MarshallingEncoder() throws IOException {
        this.marshaller = MarshallingCodeCFactory.buildMarshalling();
    }

    /**
     * Appends {@code body} to {@code out} as [int length][marshalled bytes].
     *
     * @param body the object to serialize
     * @param out  destination buffer; written starting at its current writer index
     * @throws IOException if marshalling fails
     */
    public void encode(Object body, ByteBuf out) throws IOException {
        try {
            final int lengthFieldIndex = out.writerIndex();
            out.writeBytes(LENGTH_PLACEHOLDER);
            final int payloadStart = out.writerIndex();

            marshaller.start(new ChannelBufferByteOutput(out));
            marshaller.writeObject(body);
            marshaller.finish();

            // Back-fill the reserved slot with the number of marshalled bytes.
            final int payloadLength = out.writerIndex() - payloadStart;
            out.setInt(lengthFieldIndex, payloadLength);
        } finally {
            marshaller.close();
        }
    }
}
NettyMessageEncoder:
/**
 * Encodes a {@link NettyMessage} into the private-protocol wire format:
 * <pre>
 * int  crcCode
 * int  length        back-filled: frame bytes minus 8 (excludes crcCode and this field)
 * long sessionID
 * byte type
 * byte priority
 * int  attachmentCount, then per entry: int keyLength, UTF-8 key bytes, marshalled value
 * body: marshalled object, or int 0 when the body is null
 * </pre>
 * The "minus 8" matches a LengthFieldBasedFrameDecoder configured with lengthFieldOffset 4,
 * lengthFieldLength 4 and no length adjustment.
 */
public class NettyMessageEncoder extends MessageToByteEncoder<NettyMessage> {

    /** Position of the length field relative to the start of the frame (after the 4-byte crcCode). */
    private static final int LENGTH_FIELD_OFFSET = 4;

    /** Bytes not counted by the length field: crcCode (4) + the length field itself (4). */
    private static final int UNCOUNTED_PREFIX = 8;

    private MarshallingEncoder marshallingEncoder;

    /**
     * @throws IOException if the underlying marshaller cannot be created
     */
    public NettyMessageEncoder() throws IOException {
        this.marshallingEncoder = new MarshallingEncoder();
    }

    /**
     * Writes {@code message} as one frame into {@code sendBuf}.
     *
     * @throws Exception if the message or its header is null, or if marshalling fails
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, NettyMessage message, ByteBuf sendBuf) throws Exception {
        if (message == null || message.getHeader() == null) {
            throw new Exception("編碼失敗,沒有數據信息!");
        }
        // Record where this frame begins instead of assuming index 0. The length back-fill then
        // stays correct even if sendBuf already holds data.
        final int frameStart = sendBuf.writerIndex();

        Header header = message.getHeader();
        encodeHeader(header, sendBuf);
        encodeAttachments(header.getAttachment(), sendBuf);
        encodeBody(message.getBody(), sendBuf);

        sendBuf.setInt(frameStart + LENGTH_FIELD_OFFSET,
                sendBuf.writerIndex() - frameStart - UNCOUNTED_PREFIX);
    }

    /** Writes the fixed-size header fields. */
    private static void encodeHeader(Header header, ByteBuf out) {
        out.writeInt(header.getCrcCode());
        out.writeInt(header.getLength()); // placeholder; overwritten once the whole frame is written
        out.writeLong(header.getSessionID());
        out.writeByte(header.getType());
        out.writeByte(header.getPriority());
    }

    /** Writes the attachment count followed by each (key length, UTF-8 key, marshalled value) entry. */
    private void encodeAttachments(Map<String, Object> attachment, ByteBuf out) throws IOException {
        if (attachment == null || attachment.isEmpty()) {
            // No attachments (a null map included) is encoded as a count of 0.
            out.writeInt(0);
            return;
        }
        out.writeInt(attachment.size());
        for (Map.Entry<String, Object> entry : attachment.entrySet()) {
            byte[] keyBytes = entry.getKey().getBytes(java.nio.charset.StandardCharsets.UTF_8);
            out.writeInt(keyBytes.length);
            out.writeBytes(keyBytes);
            marshallingEncoder.encode(entry.getValue(), out);
        }
    }

    /** Writes the marshalled body, or a zero length marker when there is no body. */
    private void encodeBody(Object body, ByteBuf out) throws IOException {
        if (body != null) {
            marshallingEncoder.encode(body, out);
        } else {
            // The 4-byte zero marker keeps the frame layout uniform for the decoder.
            out.writeInt(0);
        }
    }
}
解碼器
圖3.png
MarshallingDecoder
/**
 * Reads one length-prefixed object in the layout produced by MarshallingEncoder:
 * a 4-byte int length followed by that many bytes of JBoss Marshalling data.
 * <p>
 * SECURITY NOTE(review): the unmarshaller is built from the Java-serialization-style "serial"
 * provider. It deserializes bytes received from the network, and deserializing untrusted input
 * can lead to remote code execution. Consider restricting which classes may be resolved, or
 * switching to a data-only format.
 */
public class MarshallingDecoder {

    private Unmarshaller unmarshaller;

    /**
     * @throws IOException if the unmarshaller cannot be created
     */
    public MarshallingDecoder() throws IOException {
        this.unmarshaller = MarshallingCodeCFactory.buildUnMarshalling();
    }

    /**
     * Reads the length prefix, unmarshals exactly that many bytes and advances {@code in} past them.
     *
     * @param in buffer positioned at the 4-byte length prefix
     * @return the unmarshalled object
     * @throws IOException if the length prefix is negative or larger than the readable bytes
     * @throws Exception   if unmarshalling fails
     */
    public Object decode(ByteBuf in) throws Exception {
        try {
            int bodySize = in.readInt();
            // The length comes from the peer; reject impossible values with a clear message.
            if (bodySize < 0 || bodySize > in.readableBytes()) {
                throw new IOException("Invalid marshalled object length " + bodySize
                        + ", readable bytes: " + in.readableBytes());
            }
            int readIndex = in.readerIndex();
            ByteBuf buf = in.slice(readIndex, bodySize);

            this.unmarshaller.start(new ChannelBufferByteInput(buf));
            Object ret = this.unmarshaller.readObject();
            this.unmarshaller.finish();

            // slice() has its own indexes, so the source buffer's reader index must be advanced explicitly.
            in.readerIndex(readIndex + bodySize);
            return ret;
        } finally {
            this.unmarshaller.close();
        }
    }
}
NettyMessageDecoder
/**
 * Splits the inbound byte stream into frames using the header's length field.
 * Each frame is decoded into a {@link NettyMessage} in the field order written by
 * NettyMessageEncoder: crcCode, length, sessionID, type, priority, attachments, body.
 */
public class NettyMessageDecoder extends LengthFieldBasedFrameDecoder {

    private MarshallingDecoder marshallingDecoder;

    /**
     * The length value written by the encoder excludes the crcCode and length fields (frame size
     * minus 8). With no length adjustment, LengthFieldBasedFrameDecoder adds
     * lengthFieldOffset + lengthFieldLength back to get the full frame size.
     *
     * @param maxFrameLength    maximum frame size accepted, e.g. 1024*1024*5
     * @param lengthFieldOffset offset of the length field within the frame: 4 here (it follows crcCode)
     * @param lengthFieldLength size of the length field in bytes: 4 here (an int)
     * @throws IOException if the unmarshaller cannot be created
     */
    public NettyMessageDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength) throws IOException {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength);
        this.marshallingDecoder = new MarshallingDecoder();
    }

    /**
     * @return a decoded {@link NettyMessage}, or null if a complete frame is not yet available
     */
    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        ByteBuf frame = (ByteBuf) super.decode(ctx, in);
        if (frame == null) {
            return null;
        }
        try {
            return decodeFrame(frame);
        } finally {
            // The superclass hands back a retained slice. It is fully consumed here and never
            // passed downstream, so it must be released to avoid leaking the buffer.
            frame.release();
        }
    }

    /** Decodes one complete frame into a message. */
    private NettyMessage decodeFrame(ByteBuf frame) throws Exception {
        Header header = new Header();
        header.setCrcCode(frame.readInt()); // TODO: validate against the expected protocol marker
        header.setLength(frame.readInt());
        header.setSessionID(frame.readLong());
        header.setType(frame.readByte());
        header.setPriority(frame.readByte());

        int size = frame.readInt();
        if (size > 0) {
            header.setAttachment(decodeAttachments(frame, size));
        }

        NettyMessage message = new NettyMessage();
        message.setHeader(header);
        // What remains is the body. Exactly 4 bytes is the encoder's zero "no body" marker.
        if (frame.readableBytes() > 4) {
            message.setBody(marshallingDecoder.decode(frame));
        }
        return message;
    }

    /** Reads {@code count} entries of (int key length, UTF-8 key, marshalled value). */
    private Map<String, Object> decodeAttachments(ByteBuf frame, int count) throws Exception {
        // No size hint: count comes from the peer and must not drive a large up-front allocation.
        Map<String, Object> attachments = new HashMap<String, Object>();
        for (int i = 0; i < count; i++) {
            int keySize = frame.readInt();
            if (keySize < 0 || keySize > frame.readableBytes()) {
                throw new IOException("Invalid attachment key length " + keySize
                        + ", readable bytes: " + frame.readableBytes());
            }
            byte[] keyBytes = new byte[keySize];
            frame.readBytes(keyBytes);
            String key = new String(keyBytes, java.nio.charset.StandardCharsets.UTF_8);
            attachments.put(key, marshallingDecoder.decode(frame));
        }
        return attachments;
    }
}
握手消息請求的發送以及處理
圖示:
圖4.png
握手請求:
LoginAuthReqHandler
/**
 * Client-side login handler.
 * <p>
 * It sends a LOGIN_REQ as soon as the channel becomes active and then checks the server's
 * LOGIN_RESP. A non-SUCCESS response closes the channel; a successful response is passed
 * to the next handler.
 */
public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {
    private static final Logger LOGGER = LoggerFactory.getLogger(LoginAuthReqHandler.class);

    /** Sends the login request once the connection is established. */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        LOGGER.info("通道激活,握手請求認證..................");
        ctx.writeAndFlush(buildLoginReq());
    }

    /**
     * Handles LOGIN_RESP as follows:
     * <ul>
     *   <li>The channel is closed unless the body is the SUCCESS result byte.</li>
     *   <li>On success, the message is forwarded to the next handler.</li>
     *   <li>All other messages are forwarded unchanged.</li>
     * </ul>
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        NettyMessage message = (NettyMessage) msg;
        if (message.getHeader() != null && message.getHeader().getType() == MessageType.LOGIN_RESP.value()) {
            Object body = message.getBody();
            // A missing or non-byte body counts as a failed login instead of throwing NPE/ClassCastException.
            boolean success = body instanceof Byte
                    && ((Byte) body).byteValue() == ResultType.SUCCESS.value();
            if (!success) {
                ctx.close();
            } else {
                LOGGER.info("Login is OK : {}", message);
                ctx.fireChannelRead(msg);
            }
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    /** Builds a body-less message whose header type is LOGIN_REQ. */
    private NettyMessage buildLoginReq() {
        NettyMessage message = new NettyMessage();
        Header header = new Header();
        header.setType(MessageType.LOGIN_REQ.value());
        message.setHeader(header);
        return message;
    }

    /** Propagates exceptions to the next handler. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.fireExceptionCaught(cause);
    }
}
服務端處理:
LoginAuthRespHandler
/**
 * Server-side login handler.
 * <p>
 * It answers each LOGIN_REQ with a LOGIN_RESP. The body is SUCCESS when the peer IP is in the
 * whitelist and the remote address has not already logged in; otherwise it is FAIL.
 * Other messages are forwarded to the next handler.
 */
public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {
    private static final Logger LOGGER = LoggerFactory.getLogger(LoginAuthRespHandler.class);

    /**
     * IP whitelist: only peers whose address matches one of these strings may log in.
     * For security, link establishment is gated on an IP allow-list.
     */
    private static final String[] WHITE_LIST = { "127.0.0.1", "192.168.56.1" };

    /**
     * Remote addresses (in {@code remoteAddress().toString()} form) that have logged in
     * successfully; used to reject duplicate logins.
     * NOTE(review): this is an instance field, so it only spans channels if the handler instance is shared.
     */
    private Map<String, Boolean> nodeCheck = new ConcurrentHashMap<String, Boolean>();

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        NettyMessage message = (NettyMessage) msg;
        // Only login requests are handled here.
        if (message.getHeader() == null || message.getHeader().getType() != MessageType.LOGIN_REQ.value()) {
            ctx.fireChannelRead(msg);
            return;
        }
        String nodeIndex = ctx.channel().remoteAddress().toString();
        NettyMessage loginResp;
        if (nodeCheck.containsKey(nodeIndex)) {
            LOGGER.error("重復登錄,拒絕請求!");
            loginResp = buildResponse(ResultType.FAIL);
        } else {
            InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
            String ip = address.getAddress().getHostAddress();
            boolean isOK = isWhitelisted(ip);
            loginResp = buildResponse(isOK ? ResultType.SUCCESS : ResultType.FAIL);
            if (isOK) {
                nodeCheck.put(nodeIndex, true);
            }
        }
        LOGGER.info("The login response is : {} body [{}]", loginResp, loginResp.getBody());
        ctx.writeAndFlush(loginResp);
    }

    /**
     * Forgets the peer's login record when the connection closes. Without this, records would
     * only be removed on exceptions, and the map would keep stale entries after normal disconnects.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        forgetNode(ctx);
        ctx.fireChannelInactive();
    }

    /** @return true if {@code ip} exactly matches a whitelist entry */
    private static boolean isWhitelisted(String ip) {
        for (String allowed : WHITE_LIST) {
            if (allowed.equals(ip)) {
                return true;
            }
        }
        return false;
    }

    /** Removes this channel's login record, if its remote address is known. */
    private void forgetNode(ChannelHandlerContext ctx) {
        Object remote = ctx.channel().remoteAddress();
        if (remote != null) {
            nodeCheck.remove(remote.toString());
        }
    }

    /**
     * Builds the login response: a LOGIN_RESP header, no attachments (a new Header's attachment
     * map is empty), and the result's value as the body.
     * NOTE(review): the wire values of LOGIN_RESP and the SUCCESS/FAIL bytes are defined in
     * MessageType/ResultType (not shown); confirm them there.
     */
    private NettyMessage buildResponse(ResultType result) {
        NettyMessage message = new NettyMessage();
        Header header = new Header();
        header.setType(MessageType.LOGIN_RESP.value());
        message.setHeader(header);
        message.setBody(result.value());
        return message;
    }

    /** Logs the failure, drops the peer's login record, closes the channel and propagates the exception. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("Login handler caught an exception, closing channel", cause);
        forgetNode(ctx);
        ctx.close();
        ctx.fireExceptionCaught(cause);
    }
}
心跳檢測
圖示:
圖5.png
HeartBeatReqHandler
客戶端發送:
/**
 * Client-side heartbeat handler.
 * <ul>
 *   <li>When a LOGIN_RESP reaches this handler, it schedules a HEARTBEAT_REQ every 5 seconds
 *       on the channel's executor.</li>
 *   <li>HEARTBEAT_RESP messages are logged and consumed.</li>
 *   <li>Everything else is forwarded to the next handler.</li>
 * </ul>
 * The timer is cancelled when the channel closes or an exception occurs.
 */
public class HeartBeatReqHandler extends ChannelInboundHandlerAdapter {
    private static final Logger LOGGER = LoggerFactory.getLogger(HeartBeatReqHandler.class);

    /** Interval between heartbeat requests, in milliseconds. */
    private static final long HEARTBEAT_INTERVAL_MILLIS = 5000;

    /** The running heartbeat schedule, or null when none is active. */
    private volatile ScheduledFuture<?> heartBeat;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        NettyMessage message = (NettyMessage) msg;
        Header header = message.getHeader();
        if (header != null && header.getType() == MessageType.LOGIN_RESP.value()) {
            // Login succeeded: start sending heartbeats. Cancel any earlier schedule first,
            // so that a repeated LOGIN_RESP does not start a second timer.
            cancelHeartBeat();
            heartBeat = ctx.executor().scheduleAtFixedRate(new HeartBeatReqHandler.HeartBeatTask(ctx), 0,
                    HEARTBEAT_INTERVAL_MILLIS, TimeUnit.MILLISECONDS);
        } else if (header != null && header.getType() == MessageType.HEARTBEAT_RESP.value()) {
            LOGGER.info("Client receive server heart beat message : ---> {}", message);
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    /** Stops the timer when the connection closes; otherwise it would keep writing to a dead channel. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        cancelHeartBeat();
        ctx.fireChannelInactive();
    }

    /** Stops the heartbeat timer while the link is broken and propagates the exception. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("Heart beat handler caught an exception", cause);
        cancelHeartBeat();
        ctx.fireExceptionCaught(cause);
    }

    /** Cancels the active heartbeat schedule, if any. */
    private void cancelHeartBeat() {
        ScheduledFuture<?> scheduled = heartBeat;
        if (scheduled != null) {
            scheduled.cancel(true);
            heartBeat = null;
        }
    }

    /** Periodic task that writes one HEARTBEAT_REQ to the channel per run. */
    private static class HeartBeatTask implements Runnable {
        private final ChannelHandlerContext ctx;

        public HeartBeatTask(final ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }

        @Override
        public void run() {
            NettyMessage heatBeat = buildHeatBeat();
            LOGGER.info("Client send heart beat messsage to server : ---> {}", heatBeat);
            ctx.writeAndFlush(heatBeat);
        }

        /** Builds a body-less message whose header type is HEARTBEAT_REQ. */
        private NettyMessage buildHeatBeat() {
            NettyMessage message = new NettyMessage();
            Header header = new Header();
            header.setType(MessageType.HEARTBEAT_REQ.value());
            message.setHeader(header);
            return message;
        }
    }
}
服務端處理:
HeartBeatRespHandler
/**
 * Server-side heartbeat handler: replies to every HEARTBEAT_REQ with a HEARTBEAT_RESP.
 * All other messages are forwarded to the next handler.
 */
public class HeartBeatRespHandler extends ChannelInboundHandlerAdapter {
    private static final Logger LOGGER = LoggerFactory.getLogger(HeartBeatRespHandler.class);

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        NettyMessage request = (NettyMessage) msg;
        boolean isHeartBeatRequest = request.getHeader() != null
                && request.getHeader().getType() == MessageType.HEARTBEAT_REQ.value();
        if (!isHeartBeatRequest) {
            ctx.fireChannelRead(msg);
            return;
        }
        LOGGER.info("Receive client heart beat message : ---> {} ", request);
        NettyMessage reply = buildHeatBeat();
        LOGGER.info("Send heart beat response message to client : ---> {}", reply);
        ctx.writeAndFlush(reply);
    }

    /** Builds a body-less message whose header type is HEARTBEAT_RESP. */
    private NettyMessage buildHeatBeat() {
        Header responseHeader = new Header();
        responseHeader.setType(MessageType.HEARTBEAT_RESP.value());
        NettyMessage reply = new NettyMessage();
        reply.setHeader(responseHeader);
        return reply;
    }
}