diff --git a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgPermissionFilter.java b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgPermissionFilter.java index 01fb21885e..dde048a746 100644 --- a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgPermissionFilter.java +++ b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgPermissionFilter.java @@ -38,16 +38,17 @@ package sonia.scm.web; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; import com.google.inject.Singleton; - import sonia.scm.config.ScmConfiguration; +import sonia.scm.repository.HgRepositoryHandler; import sonia.scm.repository.RepositoryProvider; import sonia.scm.web.filter.ProviderPermissionFilter; -//~--- JDK imports ------------------------------------------------------------ - -import java.util.Set; - +import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Set; /** * Permission filter for mercurial repositories. @@ -60,6 +61,8 @@ public class HgPermissionFilter extends ProviderPermissionFilter private static final Set READ_METHODS = ImmutableSet.of("GET", "HEAD", "OPTIONS", "TRACE"); + private final HgRepositoryHandler repositoryHandler; + /** * Constructs a new instance. * @@ -67,17 +70,36 @@ public class HgPermissionFilter extends ProviderPermissionFilter * @param repositoryProvider repository provider */ @Inject - public HgPermissionFilter(ScmConfiguration configuration, - RepositoryProvider repositoryProvider) + public HgPermissionFilter(ScmConfiguration configuration, RepositoryProvider repositoryProvider, HgRepositoryHandler repositoryHandler) { super(configuration, repositoryProvider); + this.repositoryHandler = repositoryHandler; } //~--- get methods ---------------------------------------------------------- + + @Override + protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException { + HgServletRequest hgRequest = new HgServletRequest(request); + super.doFilter(hgRequest, response, chain); + // TODO closing stream in case of fire? + } + @Override protected boolean isWriteRequest(HttpServletRequest request) { + if (repositoryHandler.getConfig().isEnableHttpPostArgs()) { + return isHttpPostArgsWriteRequest(request); + } + return isDefaultWriteRequest(request); + } + + private boolean isHttpPostArgsWriteRequest(HttpServletRequest request) { + return WireProtocol.isWriteRequest(request); + } + + private boolean isDefaultWriteRequest(HttpServletRequest request) { if (READ_METHODS.contains(request.getMethod())) { return WireProtocol.isWriteRequest(request); } diff --git a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletInputStream.java b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletInputStream.java new file mode 100644 index 0000000000..b0b2f8ef0d --- /dev/null +++ b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletInputStream.java @@ -0,0 +1,55 @@ +package sonia.scm.web; + +import com.google.common.base.Preconditions; + +import javax.servlet.ServletInputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; + +/** + * HgServletInputStream is a wrapper around the original {@link ServletInputStream} and provides some extra + * functionality to support the mercurial client. + */ +public class HgServletInputStream extends ServletInputStream { + + private final ServletInputStream original; + private ByteArrayInputStream captured; + + HgServletInputStream(ServletInputStream original) { + this.original = original; + } + + /** + * Reads the given amount of bytes from the stream and captures them, if the {@link #read()} methods is called the + * captured bytes are returned before the rest of the stream. + * + * @param size amount of bytes to read + * + * @return byte array + * + * @throws IOException if the method is called twice + */ + public byte[] readAndCapture(int size) throws IOException { + Preconditions.checkState(captured == null, "readAndCapture can only be called once per request"); + + // TODO should we enforce a limit? to prevent OOM? + byte[] bytes = new byte[size]; + original.read(bytes); + captured = new ByteArrayInputStream(bytes); + + return bytes; + } + + @Override + public int read() throws IOException { + if (captured != null && captured.available() > 0) { + return captured.read(); + } + return original.read(); + } + + @Override + public void close() throws IOException { + original.close(); + } +} diff --git a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletRequest.java b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletRequest.java new file mode 100644 index 0000000000..80251c140a --- /dev/null +++ b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/HgServletRequest.java @@ -0,0 +1,31 @@ +package sonia.scm.web; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import java.io.IOException; + +/** + * {@link HttpServletRequestWrapper} which adds some functionality in order to support the mercurial client. + */ +public final class HgServletRequest extends HttpServletRequestWrapper { + + private HgServletInputStream hgServletInputStream; + + /** + * Constructs a request object wrapping the given request. + * + * @param request + * @throws IllegalArgumentException if the request is null + */ + public HgServletRequest(HttpServletRequest request) { + super(request); + } + + @Override + public HgServletInputStream getInputStream() throws IOException { + if (hgServletInputStream == null) { + hgServletInputStream = new HgServletInputStream(super.getInputStream()); + } + return hgServletInputStream; + } +} diff --git a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/WireProtocol.java b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/WireProtocol.java index bab3083445..8a411ead64 100644 --- a/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/WireProtocol.java +++ b/scm-plugins/scm-hg-plugin/src/main/java/sonia/scm/web/WireProtocol.java @@ -33,14 +33,17 @@ package sonia.scm.web; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.base.Throwables; import com.google.common.collect.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import sonia.scm.util.HttpUtil; import javax.servlet.http.HttpServletRequest; +import java.io.IOException; import java.util.*; /** @@ -73,7 +76,10 @@ public final class WireProtocol { * - no command was specified with the request (is required for the hgweb ui) * - the command in the query string was found in the list of read request * - if query string contains the batch command, then all commands specified in X-HgArg headers must be - * in the list of read request + * in the list of read requests + * - in case of enabled HttpPostArgs protocol and query string container the batch command, the header X-HgArgs-Post + * is read and the commands which are specified in the body from 0 to the value of X-HgArgs-Post must be in the list + * of read requests * * @param request http request * @@ -94,16 +100,40 @@ public final class WireProtocol { @VisibleForTesting static List commandsOf(HttpServletRequest request) { List listOfCmds = Lists.newArrayList(); + String cmd = getCommandFromQueryString(request); if (cmd != null) { listOfCmds.add(cmd); if (isBatchCommand(cmd)) { parseHgArgHeaders(request, listOfCmds); + handleHttpPostArgs(request, listOfCmds); } } return Collections.unmodifiableList(listOfCmds); } + private static void handleHttpPostArgs(HttpServletRequest request, List listOfCmds) { + int hgArgsPostSize = request.getIntHeader("X-HgArgs-Post"); + if (hgArgsPostSize > 0) { + + if (request instanceof HgServletRequest) { + HgServletRequest hgRequest = (HgServletRequest) request; + + try { + byte[] bytes = hgRequest.getInputStream().readAndCapture(hgArgsPostSize); + String hgArgs = new String(bytes, Charsets.US_ASCII); + String decoded = decodeValue(hgArgs); + parseHgCommandHeader(listOfCmds, decoded); + } catch (IOException ex) { + throw Throwables.propagate(ex); + } + } else { + throw new IllegalArgumentException("could not process the httppostargs protocol without HgServletRequest"); + } + + } + } + private static void parseHgArgHeaders(HttpServletRequest request, List listOfCmds) { Enumeration headerNames = request.getHeaderNames(); while (headerNames.hasMoreElements()) { @@ -115,9 +145,13 @@ public final class WireProtocol { private static void parseHgArgHeader(HttpServletRequest request, List listOfCmds, String header) { if (isHgArgHeader(header)) { String value = getHeaderDecoded(request, header); - if (isHgArgCommandHeader(value)) { - parseHgCommandHeader(listOfCmds, value); - } + parseHgArgValue(listOfCmds, value); + } + } + + private static void parseHgArgValue(List listOfCmds, String value) { + if (isHgArgCommandHeader(value)) { + parseHgCommandHeader(listOfCmds, value); } } @@ -143,7 +177,11 @@ public final class WireProtocol { } private static String getHeaderDecoded(HttpServletRequest request, String header) { - return HttpUtil.decode(Strings.nullToEmpty(request.getHeader(header))); + return decodeValue(request.getHeader(header)); + } + + private static String decodeValue(String value) { + return HttpUtil.decode(Strings.nullToEmpty(value)); } private static boolean isHgArgHeader(String header) { diff --git a/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgPermissionFilterTest.java b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgPermissionFilterTest.java index 8319134078..f8aefb95d1 100644 --- a/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgPermissionFilterTest.java +++ b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgPermissionFilterTest.java @@ -31,21 +31,27 @@ package sonia.scm.web; -import javax.servlet.http.HttpServletRequest; +import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.*; - import org.junit.runner.RunWith; import org.mockito.InjectMocks; import org.mockito.Mock; -import static org.mockito.Mockito.*; -import static sonia.scm.web.WireProtocolRequestMockFactory.CMDS_HEADS_KNOWN_NODES; -import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.*; - import org.mockito.runners.MockitoJUnitRunner; import sonia.scm.config.ScmConfiguration; +import sonia.scm.repository.HgConfig; +import sonia.scm.repository.HgRepositoryHandler; import sonia.scm.repository.RepositoryProvider; +import javax.servlet.http.HttpServletRequest; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static sonia.scm.web.WireProtocolRequestMockFactory.CMDS_HEADS_KNOWN_NODES; +import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.BOOKMARKS; +import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.PHASES; + /** * Unit tests for {@link HgPermissionFilter}. * @@ -60,11 +66,19 @@ public class HgPermissionFilterTest { @Mock private RepositoryProvider repositoryProvider; + @Mock + private HgRepositoryHandler hgRepositoryHandler; + private WireProtocolRequestMockFactory wireProtocol = new WireProtocolRequestMockFactory("/scm/hg/repo"); @InjectMocks private HgPermissionFilter filter; + @Before + public void setUp() { + when(hgRepositoryHandler.getConfig()).thenReturn(new HgConfig()); + } + /** * Tests {@link HgPermissionFilter#isWriteRequest(HttpServletRequest)}. */ @@ -83,9 +97,27 @@ public class HgPermissionFilterTest { assertTrue(isWriteRequest("KA")); } + /** + * Tests {@link HgPermissionFilter#isWriteRequest(HttpServletRequest)} with enabled httppostargs option. + */ + @Test + public void testIsWriteRequestWithEnabledHttpPostArgs() { + HgConfig config = new HgConfig(); + config.setEnableHttpPostArgs(true); + when(hgRepositoryHandler.getConfig()).thenReturn(config); + + assertFalse(isWriteRequest("POST")); + assertFalse(isWriteRequest("POST", "heads")); + assertTrue(isWriteRequest("POST", "unbundle")); + } + private boolean isWriteRequest(String method) { + return isWriteRequest(method, "capabilities"); + } + + private boolean isWriteRequest(String method, String command) { HttpServletRequest request = mock(HttpServletRequest.class); - when(request.getQueryString()).thenReturn("cmd=capabilities"); + when(request.getQueryString()).thenReturn("cmd=" + command); when(request.getMethod()).thenReturn(method); return filter.isWriteRequest(request); } diff --git a/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgServletInputStreamTest.java b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgServletInputStreamTest.java new file mode 100644 index 0000000000..51b0a050fc --- /dev/null +++ b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/HgServletInputStreamTest.java @@ -0,0 +1,50 @@ +package sonia.scm.web; + +import com.google.common.base.Charsets; +import com.google.common.io.ByteStreams; +import org.junit.Test; + +import javax.servlet.ServletInputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class HgServletInputStreamTest { + + @Test + public void testReadAndCapture() throws IOException { + SampleServletInputStream original = new SampleServletInputStream("trillian.mcmillian@hitchhiker.com"); + HgServletInputStream hgServletInputStream = new HgServletInputStream(original); + + byte[] prefix = hgServletInputStream.readAndCapture(8); + assertEquals("trillian", new String(prefix, Charsets.US_ASCII)); + + byte[] wholeBytes = ByteStreams.toByteArray(hgServletInputStream); + assertEquals("trillian.mcmillian@hitchhiker.com", new String(wholeBytes, Charsets.US_ASCII)); + } + + @Test(expected = IllegalStateException.class) + public void testReadAndCaptureCalledTwice() throws IOException { + SampleServletInputStream original = new SampleServletInputStream("trillian.mcmillian@hitchhiker.com"); + HgServletInputStream hgServletInputStream = new HgServletInputStream(original); + + hgServletInputStream.readAndCapture(1); + hgServletInputStream.readAndCapture(1); + } + + private static class SampleServletInputStream extends ServletInputStream { + + private ByteArrayInputStream input; + + private SampleServletInputStream(String data) { + input = new ByteArrayInputStream(data.getBytes()); + } + + @Override + public int read() { + return input.read(); + } + } + +} diff --git a/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/WireProtocolTest.java b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/WireProtocolTest.java index 860b6a6392..519dadfd6c 100644 --- a/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/WireProtocolTest.java +++ b/scm-plugins/scm-hg-plugin/src/test/java/sonia/scm/web/WireProtocolTest.java @@ -32,14 +32,17 @@ package sonia.scm.web; +import com.google.common.base.Charsets; import com.google.common.collect.Lists; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.runners.MockitoJUnitRunner; +import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; - +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.util.Collections; import java.util.List; @@ -93,6 +96,18 @@ public class WireProtocolTest { expectQueryCommand("unbundle", "prefix=stu==ff&cmd=unbundle"); } + @Test + public void testGetCommandsOfWithHgArgsPost() throws IOException { + when(request.getMethod()).thenReturn("POST"); + when(request.getQueryString()).thenReturn("cmd=batch"); + when(request.getIntHeader("X-HgArgs-Post")).thenReturn(29); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Lists.newArrayList("X-HgArgs-Post"))); + when(request.getInputStream()).thenReturn(new BufferedServletInputStream("cmds=lheads+%3Bknown+nodes%3D")); + + List commands = WireProtocol.commandsOf(new HgServletRequest(request)); + assertThat(commands, contains("batch", "lheads", "known")); + } + @Test public void testGetCommandsOfWithBatch() { prepareBatch("cmds=heads ;known nodes,ef5993bb4abb32a0565c347844c6d939fc4f4b98"); @@ -159,4 +174,19 @@ public class WireProtocolTest { assertTrue(commands.contains(expected)); } + private static class BufferedServletInputStream extends ServletInputStream { + + private ByteArrayInputStream input; + + BufferedServletInputStream(String content) { + this.input = new ByteArrayInputStream(content.getBytes(Charsets.US_ASCII)); + } + + @Override + public int read() { + return input.read(); + } + + } + }