#970 initial support of mercurials httppostargs protocol

This commit is contained in:
Sebastian Sdorra
2018-03-30 11:20:22 +02:00
parent a34acd8ed4
commit b43e406b76
7 changed files with 279 additions and 21 deletions

View File

@@ -38,16 +38,17 @@ package sonia.scm.web;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.google.inject.Singleton; import com.google.inject.Singleton;
import sonia.scm.config.ScmConfiguration; import sonia.scm.config.ScmConfiguration;
import sonia.scm.repository.HgRepositoryHandler;
import sonia.scm.repository.RepositoryProvider; import sonia.scm.repository.RepositoryProvider;
import sonia.scm.web.filter.ProviderPermissionFilter; import sonia.scm.web.filter.ProviderPermissionFilter;
//~--- JDK imports ------------------------------------------------------------ import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import java.util.Set;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Set;
/** /**
* Permission filter for mercurial repositories. * Permission filter for mercurial repositories.
@@ -60,6 +61,8 @@ public class HgPermissionFilter extends ProviderPermissionFilter
private static final Set<String> READ_METHODS = ImmutableSet.of("GET", "HEAD", "OPTIONS", "TRACE"); private static final Set<String> READ_METHODS = ImmutableSet.of("GET", "HEAD", "OPTIONS", "TRACE");
private final HgRepositoryHandler repositoryHandler;
/** /**
* Constructs a new instance. * Constructs a new instance.
* *
@@ -67,17 +70,36 @@ public class HgPermissionFilter extends ProviderPermissionFilter
* @param repositoryProvider repository provider * @param repositoryProvider repository provider
*/ */
@Inject @Inject
public HgPermissionFilter(ScmConfiguration configuration, public HgPermissionFilter(ScmConfiguration configuration, RepositoryProvider repositoryProvider, HgRepositoryHandler repositoryHandler)
RepositoryProvider repositoryProvider)
{ {
super(configuration, repositoryProvider); super(configuration, repositoryProvider);
this.repositoryHandler = repositoryHandler;
} }
//~--- get methods ---------------------------------------------------------- //~--- get methods ----------------------------------------------------------
@Override
protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
HgServletRequest hgRequest = new HgServletRequest(request);
super.doFilter(hgRequest, response, chain);
// TODO closing stream in case of fire?
}
@Override @Override
protected boolean isWriteRequest(HttpServletRequest request) protected boolean isWriteRequest(HttpServletRequest request)
{ {
if (repositoryHandler.getConfig().isEnableHttpPostArgs()) {
return isHttpPostArgsWriteRequest(request);
}
return isDefaultWriteRequest(request);
}
private boolean isHttpPostArgsWriteRequest(HttpServletRequest request) {
return WireProtocol.isWriteRequest(request);
}
private boolean isDefaultWriteRequest(HttpServletRequest request) {
if (READ_METHODS.contains(request.getMethod())) { if (READ_METHODS.contains(request.getMethod())) {
return WireProtocol.isWriteRequest(request); return WireProtocol.isWriteRequest(request);
} }

View File

@@ -0,0 +1,55 @@
package sonia.scm.web;
import com.google.common.base.Preconditions;
import javax.servlet.ServletInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
/**
* HgServletInputStream is a wrapper around the original {@link ServletInputStream} and provides some extra
* functionality to support the mercurial client.
*/
public class HgServletInputStream extends ServletInputStream {
private final ServletInputStream original;
private ByteArrayInputStream captured;
HgServletInputStream(ServletInputStream original) {
this.original = original;
}
/**
* Reads the given amount of bytes from the stream and captures them, if the {@link #read()} methods is called the
* captured bytes are returned before the rest of the stream.
*
* @param size amount of bytes to read
*
* @return byte array
*
* @throws IOException if the method is called twice
*/
public byte[] readAndCapture(int size) throws IOException {
Preconditions.checkState(captured == null, "readAndCapture can only be called once per request");
// TODO should we enforce a limit? to prevent OOM?
byte[] bytes = new byte[size];
original.read(bytes);
captured = new ByteArrayInputStream(bytes);
return bytes;
}
@Override
public int read() throws IOException {
if (captured != null && captured.available() > 0) {
return captured.read();
}
return original.read();
}
@Override
public void close() throws IOException {
original.close();
}
}

View File

@@ -0,0 +1,31 @@
package sonia.scm.web;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.IOException;
/**
* {@link HttpServletRequestWrapper} which adds some functionality in order to support the mercurial client.
*/
public final class HgServletRequest extends HttpServletRequestWrapper {
private HgServletInputStream hgServletInputStream;
/**
* Constructs a request object wrapping the given request.
*
* @param request
* @throws IllegalArgumentException if the request is null
*/
public HgServletRequest(HttpServletRequest request) {
super(request);
}
@Override
public HgServletInputStream getInputStream() throws IOException {
if (hgServletInputStream == null) {
hgServletInputStream = new HgServletInputStream(super.getInputStream());
}
return hgServletInputStream;
}
}

View File

@@ -33,14 +33,17 @@
package sonia.scm.web; package sonia.scm.web;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.common.collect.*; import com.google.common.collect.*;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import sonia.scm.util.HttpUtil; import sonia.scm.util.HttpUtil;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.*; import java.util.*;
/** /**
@@ -73,7 +76,10 @@ public final class WireProtocol {
* - no command was specified with the request (is required for the hgweb ui) * - no command was specified with the request (is required for the hgweb ui)
* - the command in the query string was found in the list of read request * - the command in the query string was found in the list of read request
* - if query string contains the batch command, then all commands specified in X-HgArg headers must be * - if query string contains the batch command, then all commands specified in X-HgArg headers must be
* in the list of read request * in the list of read requests
* - in case of enabled HttpPostArgs protocol and query string container the batch command, the header X-HgArgs-Post
* is read and the commands which are specified in the body from 0 to the value of X-HgArgs-Post must be in the list
* of read requests
* *
* @param request http request * @param request http request
* *
@@ -94,16 +100,40 @@ public final class WireProtocol {
@VisibleForTesting @VisibleForTesting
static List<String> commandsOf(HttpServletRequest request) { static List<String> commandsOf(HttpServletRequest request) {
List<String> listOfCmds = Lists.newArrayList(); List<String> listOfCmds = Lists.newArrayList();
String cmd = getCommandFromQueryString(request); String cmd = getCommandFromQueryString(request);
if (cmd != null) { if (cmd != null) {
listOfCmds.add(cmd); listOfCmds.add(cmd);
if (isBatchCommand(cmd)) { if (isBatchCommand(cmd)) {
parseHgArgHeaders(request, listOfCmds); parseHgArgHeaders(request, listOfCmds);
handleHttpPostArgs(request, listOfCmds);
} }
} }
return Collections.unmodifiableList(listOfCmds); return Collections.unmodifiableList(listOfCmds);
} }
private static void handleHttpPostArgs(HttpServletRequest request, List<String> listOfCmds) {
int hgArgsPostSize = request.getIntHeader("X-HgArgs-Post");
if (hgArgsPostSize > 0) {
if (request instanceof HgServletRequest) {
HgServletRequest hgRequest = (HgServletRequest) request;
try {
byte[] bytes = hgRequest.getInputStream().readAndCapture(hgArgsPostSize);
String hgArgs = new String(bytes, Charsets.US_ASCII);
String decoded = decodeValue(hgArgs);
parseHgCommandHeader(listOfCmds, decoded);
} catch (IOException ex) {
throw Throwables.propagate(ex);
}
} else {
throw new IllegalArgumentException("could not process the httppostargs protocol without HgServletRequest");
}
}
}
private static void parseHgArgHeaders(HttpServletRequest request, List<String> listOfCmds) { private static void parseHgArgHeaders(HttpServletRequest request, List<String> listOfCmds) {
Enumeration headerNames = request.getHeaderNames(); Enumeration headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) { while (headerNames.hasMoreElements()) {
@@ -115,11 +145,15 @@ public final class WireProtocol {
private static void parseHgArgHeader(HttpServletRequest request, List<String> listOfCmds, String header) { private static void parseHgArgHeader(HttpServletRequest request, List<String> listOfCmds, String header) {
if (isHgArgHeader(header)) { if (isHgArgHeader(header)) {
String value = getHeaderDecoded(request, header); String value = getHeaderDecoded(request, header);
parseHgArgValue(listOfCmds, value);
}
}
private static void parseHgArgValue(List<String> listOfCmds, String value) {
if (isHgArgCommandHeader(value)) { if (isHgArgCommandHeader(value)) {
parseHgCommandHeader(listOfCmds, value); parseHgCommandHeader(listOfCmds, value);
} }
} }
}
private static void parseHgCommandHeader(List<String> listOfCmds, String value) { private static void parseHgCommandHeader(List<String> listOfCmds, String value) {
String[] cmds = value.substring(5).split(";"); String[] cmds = value.substring(5).split(";");
@@ -143,7 +177,11 @@ public final class WireProtocol {
} }
private static String getHeaderDecoded(HttpServletRequest request, String header) { private static String getHeaderDecoded(HttpServletRequest request, String header) {
return HttpUtil.decode(Strings.nullToEmpty(request.getHeader(header))); return decodeValue(request.getHeader(header));
}
private static String decodeValue(String value) {
return HttpUtil.decode(Strings.nullToEmpty(value));
} }
private static boolean isHgArgHeader(String header) { private static boolean isHgArgHeader(String header) {

View File

@@ -31,21 +31,27 @@
package sonia.scm.web; package sonia.scm.web;
import javax.servlet.http.HttpServletRequest; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.InjectMocks; import org.mockito.InjectMocks;
import org.mockito.Mock; import org.mockito.Mock;
import static org.mockito.Mockito.*;
import static sonia.scm.web.WireProtocolRequestMockFactory.CMDS_HEADS_KNOWN_NODES;
import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.*;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import sonia.scm.config.ScmConfiguration; import sonia.scm.config.ScmConfiguration;
import sonia.scm.repository.HgConfig;
import sonia.scm.repository.HgRepositoryHandler;
import sonia.scm.repository.RepositoryProvider; import sonia.scm.repository.RepositoryProvider;
import javax.servlet.http.HttpServletRequest;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static sonia.scm.web.WireProtocolRequestMockFactory.CMDS_HEADS_KNOWN_NODES;
import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.BOOKMARKS;
import static sonia.scm.web.WireProtocolRequestMockFactory.Namespace.PHASES;
/** /**
* Unit tests for {@link HgPermissionFilter}. * Unit tests for {@link HgPermissionFilter}.
* *
@@ -60,11 +66,19 @@ public class HgPermissionFilterTest {
@Mock @Mock
private RepositoryProvider repositoryProvider; private RepositoryProvider repositoryProvider;
@Mock
private HgRepositoryHandler hgRepositoryHandler;
private WireProtocolRequestMockFactory wireProtocol = new WireProtocolRequestMockFactory("/scm/hg/repo"); private WireProtocolRequestMockFactory wireProtocol = new WireProtocolRequestMockFactory("/scm/hg/repo");
@InjectMocks @InjectMocks
private HgPermissionFilter filter; private HgPermissionFilter filter;
@Before
public void setUp() {
when(hgRepositoryHandler.getConfig()).thenReturn(new HgConfig());
}
/** /**
* Tests {@link HgPermissionFilter#isWriteRequest(HttpServletRequest)}. * Tests {@link HgPermissionFilter#isWriteRequest(HttpServletRequest)}.
*/ */
@@ -83,9 +97,27 @@ public class HgPermissionFilterTest {
assertTrue(isWriteRequest("KA")); assertTrue(isWriteRequest("KA"));
} }
/**
* Tests {@link HgPermissionFilter#isWriteRequest(HttpServletRequest)} with enabled httppostargs option.
*/
@Test
public void testIsWriteRequestWithEnabledHttpPostArgs() {
HgConfig config = new HgConfig();
config.setEnableHttpPostArgs(true);
when(hgRepositoryHandler.getConfig()).thenReturn(config);
assertFalse(isWriteRequest("POST"));
assertFalse(isWriteRequest("POST", "heads"));
assertTrue(isWriteRequest("POST", "unbundle"));
}
private boolean isWriteRequest(String method) { private boolean isWriteRequest(String method) {
return isWriteRequest(method, "capabilities");
}
private boolean isWriteRequest(String method, String command) {
HttpServletRequest request = mock(HttpServletRequest.class); HttpServletRequest request = mock(HttpServletRequest.class);
when(request.getQueryString()).thenReturn("cmd=capabilities"); when(request.getQueryString()).thenReturn("cmd=" + command);
when(request.getMethod()).thenReturn(method); when(request.getMethod()).thenReturn(method);
return filter.isWriteRequest(request); return filter.isWriteRequest(request);
} }

View File

@@ -0,0 +1,50 @@
package sonia.scm.web;
import com.google.common.base.Charsets;
import com.google.common.io.ByteStreams;
import org.junit.Test;
import javax.servlet.ServletInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class HgServletInputStreamTest {
@Test
public void testReadAndCapture() throws IOException {
SampleServletInputStream original = new SampleServletInputStream("trillian.mcmillian@hitchhiker.com");
HgServletInputStream hgServletInputStream = new HgServletInputStream(original);
byte[] prefix = hgServletInputStream.readAndCapture(8);
assertEquals("trillian", new String(prefix, Charsets.US_ASCII));
byte[] wholeBytes = ByteStreams.toByteArray(hgServletInputStream);
assertEquals("trillian.mcmillian@hitchhiker.com", new String(wholeBytes, Charsets.US_ASCII));
}
@Test(expected = IllegalStateException.class)
public void testReadAndCaptureCalledTwice() throws IOException {
SampleServletInputStream original = new SampleServletInputStream("trillian.mcmillian@hitchhiker.com");
HgServletInputStream hgServletInputStream = new HgServletInputStream(original);
hgServletInputStream.readAndCapture(1);
hgServletInputStream.readAndCapture(1);
}
private static class SampleServletInputStream extends ServletInputStream {
private ByteArrayInputStream input;
private SampleServletInputStream(String data) {
input = new ByteArrayInputStream(data.getBytes());
}
@Override
public int read() {
return input.read();
}
}
}

View File

@@ -32,14 +32,17 @@
package sonia.scm.web; package sonia.scm.web;
import com.google.common.base.Charsets;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner; import org.mockito.runners.MockitoJUnitRunner;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@@ -93,6 +96,18 @@ public class WireProtocolTest {
expectQueryCommand("unbundle", "prefix=stu==ff&cmd=unbundle"); expectQueryCommand("unbundle", "prefix=stu==ff&cmd=unbundle");
} }
@Test
public void testGetCommandsOfWithHgArgsPost() throws IOException {
when(request.getMethod()).thenReturn("POST");
when(request.getQueryString()).thenReturn("cmd=batch");
when(request.getIntHeader("X-HgArgs-Post")).thenReturn(29);
when(request.getHeaderNames()).thenReturn(Collections.enumeration(Lists.newArrayList("X-HgArgs-Post")));
when(request.getInputStream()).thenReturn(new BufferedServletInputStream("cmds=lheads+%3Bknown+nodes%3D"));
List<String> commands = WireProtocol.commandsOf(new HgServletRequest(request));
assertThat(commands, contains("batch", "lheads", "known"));
}
@Test @Test
public void testGetCommandsOfWithBatch() { public void testGetCommandsOfWithBatch() {
prepareBatch("cmds=heads ;known nodes,ef5993bb4abb32a0565c347844c6d939fc4f4b98"); prepareBatch("cmds=heads ;known nodes,ef5993bb4abb32a0565c347844c6d939fc4f4b98");
@@ -159,4 +174,19 @@ public class WireProtocolTest {
assertTrue(commands.contains(expected)); assertTrue(commands.contains(expected));
} }
private static class BufferedServletInputStream extends ServletInputStream {
private ByteArrayInputStream input;
BufferedServletInputStream(String content) {
this.input = new ByteArrayInputStream(content.getBytes(Charsets.US_ASCII));
}
@Override
public int read() {
return input.read();
}
}
} }