diff --git a/src/java/org/jivesoftware/multiplexer/net/ConnectionHandler.java b/src/java/org/jivesoftware/multiplexer/net/ConnectionHandler.java index 20f2ea6..9d2e927 100644 --- a/src/java/org/jivesoftware/multiplexer/net/ConnectionHandler.java +++ b/src/java/org/jivesoftware/multiplexer/net/ConnectionHandler.java @@ -14,6 +14,7 @@ import org.apache.mina.common.IdleStatus; import org.apache.mina.common.IoHandlerAdapter; import org.apache.mina.common.IoSession; +import org.apache.mina.filter.codec.ProtocolDecoderException; import org.jivesoftware.multiplexer.Connection; import org.jivesoftware.multiplexer.ConnectionManager; import org.jivesoftware.multiplexer.PacketRouter; @@ -98,6 +99,10 @@ // TODO Verify if there were packets pending to be sent and decide what to do with them Log.debug(cause); } + else if (cause instanceof ProtocolDecoderException) { + Log.warn("Closing session due to exception: " + session, cause); + session.close(); + } else { Log.error(cause); } diff --git a/src/java/org/jivesoftware/multiplexer/net/XMLLightweightParser.java b/src/java/org/jivesoftware/multiplexer/net/XMLLightweightParser.java index 022cb59..e07916f 100644 --- a/src/java/org/jivesoftware/multiplexer/net/XMLLightweightParser.java +++ b/src/java/org/jivesoftware/multiplexer/net/XMLLightweightParser.java @@ -73,6 +73,7 @@ protected StringBuilder head = new StringBuilder(5); // List with all finished messages found. protected List msgs = new ArrayList(); + private int depth = 0; protected boolean insideChildrenTag = false; @@ -141,8 +142,12 @@ int readByte = byteBuffer.remaining(); invalidateBuffer(); + // Check that the buffer is not bigger than 1 Megabyte. For security reasons + // we will abort parsing when 1 Mega of queued chars was found. + if (buffer.length() > 1048576) { + throw new Exception("Stopped parsing never ending stanza"); + } CharBuffer charBuffer = encoder.decode(byteBuffer.buf()); - //charBuffer.flip(); char[] buf = charBuffer.array(); buffer.append(buf); @@ -153,10 +158,10 @@ ch = buf[i]; if (status == XMLLightweightParser.TAIL) { // Looking for the close tag - if (ch == head.charAt(tailCount)) { + if (depth < 1 && ch == head.charAt(tailCount)) { tailCount++; if (tailCount == head.length()) { - // Close tag found! + // Close stanza found! // Calculate the correct start,end position of the message into the buffer int end = buffer.length() - readByte + (i + 1); String msg = buffer.substring(startLastMsg, end); @@ -182,9 +187,16 @@ } if (ch == '/') { status = XMLLightweightParser.TAIL; + depth--; + } + else { + depth++; } } else if (status == XMLLightweightParser.VERIFY_CLOSE_TAG) { if (ch == '>') { + depth--; + } + if (depth < 1) { // Found a tag in the form int end = buffer.length() - readByte + (i + 1); String msg = buffer.substring(startLastMsg, end); @@ -241,7 +253,7 @@ } else if (ch == '<') { status = XMLLightweightParser.PRETAIL; insideChildrenTag = true; - } else if (ch == '/' && insideRootTag && !insideChildrenTag) { + } else if (ch == '/') { status = XMLLightweightParser.VERIFY_CLOSE_TAG; } } else if (status == XMLLightweightParser.HEAD) { @@ -253,14 +265,16 @@ insideChildrenTag = false; continue; } - else if (ch == '/') { + else if (ch == '/' && head.length() > 0) { status = XMLLightweightParser.VERIFY_CLOSE_TAG; + depth--; } head.append(ch); } else if (status == XMLLightweightParser.INIT) { if (ch == '<') { status = XMLLightweightParser.HEAD; + depth = 1; } else { startLastMsg++; diff --git a/test/org/jivesoftware/multiplexer/net/XMLLightweightParserTest.java b/test/org/jivesoftware/multiplexer/net/XMLLightweightParserTest.java index 0204b11..c623640 100644 --- a/test/org/jivesoftware/multiplexer/net/XMLLightweightParserTest.java +++ b/test/org/jivesoftware/multiplexer/net/XMLLightweightParserTest.java @@ -177,6 +177,53 @@ assertEquals(stanza, doc.asXML()); } + public void testNestedElements() throws Exception { + String msg1 = "1"; + //String msg1 = "will update it...12jid1will update it..."; + in.putString(msg1, Charset.forName(CHARSET).newEncoder()); + in.flip(); + // Fill parser with byte buffer content and parse it + parser.read(in); + // Make verifications + assertTrue("Stream header is not being correctly parsed", parser.areThereMsgs()); + String[] values = parser.getMsgs(); + assertEquals("Wrong number of parsed stanzas", 1, values.length); + assertEquals("Wrong stanza was parsed", msg1, values[0]); + } + + public void testIncompleteStanza() throws Exception { + String msg1 = "12"; + in.putString(msg1, Charset.forName(CHARSET).newEncoder()); + in.flip(); + // Fill parser with byte buffer content and parse it + parser.read(in); + // Make verifications + assertFalse("Found messages in incomplete stanza", parser.areThereMsgs()); + } + + public void testCompletedStanza() throws Exception { + String msg1 = "12"; + in.putString(msg1, Charset.forName(CHARSET).newEncoder()); + in.flip(); + // Fill parser with byte buffer content and parse it + parser.read(in); + // Make verifications + assertFalse("Found messages in incomplete stanza", parser.areThereMsgs()); + + String msg2 = ""; + ByteBuffer in2 = ByteBuffer.allocate(4096); + in2.setAutoExpand(true); + in2.putString(msg2, Charset.forName(CHARSET).newEncoder()); + in2.flip(); + // Fill parser with byte buffer content and parse it + parser.read(in2); + in2.release(); + assertTrue("Stream header is not being correctly parsed", parser.areThereMsgs()); + String[] values = parser.getMsgs(); + assertEquals("Wrong number of parsed stanzas", 1, values.length); + assertEquals("Wrong stanza was parsed", msg1 + msg2, values[0]); + } + protected void setUp() throws Exception { super.setUp(); // Create parser