Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Commit b1b2aba

Browse files
committed
feat: improve check origin functionality
1 parent 0ce4732 commit b1b2aba

File tree

5 files changed

+92
-27
lines changed

5 files changed

+92
-27
lines changed

.github/workflows/pull-request.yml

+17-17
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ jobs:
4444
with:
4545
distribution: 'temurin'
4646
java-version: ${{ matrix.java }}
47-
# - name: Cache Gradle
48-
# uses: actions/cache@v3
49-
# env:
50-
# java-version: ${{ matrix.java }}
51-
# with:
52-
# path: |
53-
# ~/.gradle/caches
54-
# ~/.gradle/wrapper
55-
# key: ${{ runner.os }}-${{ env.java-version }}-gradle-${{ hashFiles('**/*.gradle*') }}
56-
# restore-keys: ${{ runner.os }}-${{ env.java-version }}-gradle-
47+
- name: Cache Gradle
48+
uses: actions/cache@v3
49+
env:
50+
java-version: ${{ matrix.java }}
51+
with:
52+
path: |
53+
~/.gradle/caches
54+
~/.gradle/wrapper
55+
key: ${{ runner.os }}-${{ env.java-version }}-gradle-${{ hashFiles('**/*.gradle*') }}
56+
restore-keys: ${{ runner.os }}-${{ env.java-version }}-gradle-
5757
- name: Make gradlew executable (non-Windows only)
5858
if: matrix.os != 'windows-latest'
5959
run: chmod +x ./gradlew
@@ -94,13 +94,13 @@ jobs:
9494
path: ~/.sonar/cache
9595
key: ${{ runner.os }}-sonar
9696
restore-keys: ${{ runner.os }}-sonar
97-
# - name: Cache Gradle packages
98-
# if: env.SONAR_TOKEN != null
99-
# uses: actions/cache@v3
100-
# with:
101-
# path: ~/.gradle/caches
102-
# key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }}
103-
# restore-keys: ${{ runner.os }}-gradle
97+
- name: Cache Gradle packages
98+
if: env.SONAR_TOKEN != null
99+
uses: actions/cache@v3
100+
with:
101+
path: ~/.gradle/caches
102+
key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }}
103+
restore-keys: ${{ runner.os }}-gradle
104104
- name: Build and analyze
105105
if: env.SONAR_TOKEN != null
106106
env:

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWebsocketAutoConfiguration.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ public GraphQLWebsocketServlet graphQLWebsocketServlet(
6666
graphQLInvoker,
6767
invocationInputFactory,
6868
graphQLObjectMapper,
69-
listeners,
70-
websocketProperties.getAllowedOrigins());
69+
listeners);
7170
}
7271

7372
private Optional<SubscriptionConnectionListener> keepAliveListener() {

graphql-spring-boot-autoconfigure/src/main/java/graphql/kickstart/autoconfigure/web/servlet/GraphQLWsServerEndpointRegistration.java

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package graphql.kickstart.autoconfigure.web.servlet;
22

33
import graphql.kickstart.servlet.GraphQLWebsocketServlet;
4+
import java.util.ArrayList;
5+
import java.util.List;
46
import jakarta.websocket.HandshakeResponse;
57
import jakarta.websocket.server.HandshakeRequest;
68
import jakarta.websocket.server.ServerEndpointConfig;
@@ -13,19 +15,38 @@
1315
public class GraphQLWsServerEndpointRegistration extends ServerEndpointRegistration
1416
implements Lifecycle {
1517

18+
private static final String ALL = "*";
1619
private final GraphQLWebsocketServlet servlet;
1720
private final WsCsrfFilter csrfFilter;
1821

1922
public GraphQLWsServerEndpointRegistration(
20-
String path, GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter) {
23+
String path, GraphQLWebsocketServlet servlet, WsCsrfFilter csrfFilter, List<String> allowedOrigins) {
2124
super(path, servlet);
2225
this.servlet = servlet;
26+
if (allowedOrigins == null || allowedOrigins.isEmpty()) {
27+
this.allowedOrigins = List.of(ALL);
28+
} else {
29+
this.allowedOrigins = new ArrayList<>(allowedOrigins);
30+
}
2331
this.csrfFilter = csrfFilter;
2432
}
2533

2634
@Override
2735
public boolean checkOrigin(String originHeaderValue) {
28-
return servlet.checkOrigin(originHeaderValue);
36+
if (originHeaderValue == null || originHeaderValue.isBlank()) {
37+
return allowedOrigins.contains(ALL);
38+
}
39+
if (allowedOrigins.contains(ALL)) {
40+
return true;
41+
}
42+
String originToCheck = trimTrailingSlash(originHeaderValue);
43+
return allowedOrigins.stream()
44+
.map(this::trimTrailingSlash)
45+
.anyMatch(originToCheck::equalsIgnoreCase);
46+
}
47+
48+
private String trimTrailingSlash(String origin) {
49+
return (origin.endsWith("/") ? origin.substring(0, origin.length() - 1) : origin);
2950
}
3051

3152
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package graphql.kickstart.autoconfigure.web.servlet;
2+
3+
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
4+
5+
import graphql.kickstart.servlet.GraphQLWebsocketServlet;
6+
import java.util.List;
7+
import org.junit.jupiter.api.extension.ExtendWith;
8+
import org.junit.jupiter.params.ParameterizedTest;
9+
import org.junit.jupiter.params.provider.CsvSource;
10+
import org.mockito.Mock;
11+
import org.mockito.junit.jupiter.MockitoExtension;
12+
13+
@ExtendWith(MockitoExtension.class)
14+
class GraphQLWsServerEndpointRegistrationTest {
15+
16+
private static final String PATH = "/subscriptions";
17+
18+
@Mock private GraphQLWebsocketServlet servlet;
19+
20+
@ParameterizedTest
21+
@CsvSource(value = {"https://trusted.com", "NULL", "' '"}, nullValues = {"NULL"})
22+
void givenDefaultAllowedOrigins_whenCheckOrigin_thenReturnTrue(String origin) {
23+
var registration = createRegistration();
24+
var allowed = registration.checkOrigin("null".equals(origin) ? null : origin);
25+
assertThat(allowed).isTrue();
26+
}
27+
28+
private GraphQLWsServerEndpointRegistration createRegistration(String... allowedOrigins) {
29+
return new GraphQLWsServerEndpointRegistration(PATH, servlet, List.of(allowedOrigins));
30+
}
31+
32+
@ParameterizedTest(name = "{index} => allowedOrigin=''{0}'', originToCheck=''{1}''")
33+
@CsvSource(delimiterString = "|", textBlock = """
34+
* | https://trusted.com
35+
https://trusted.com | https://trusted.com
36+
https://trusted.com/ | https://trusted.com
37+
https://trusted.com/ | https://trusted.com/
38+
https://trusted.com | https://trusted.com/
39+
""")
40+
void givenAllowedOrigins_whenCheckOrigin_thenReturnTrue(String allowedOrigin, String originToCheck) {
41+
var registration = createRegistration(allowedOrigin);
42+
var allowed = registration.checkOrigin(originToCheck);
43+
assertThat(allowed).isTrue();
44+
}
45+
}

graphql-spring-boot-test/build.gradle

+6-6
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1818
*/
1919
dependencies {
20-
implementation("org.springframework:spring-web")
21-
implementation("org.springframework.boot:spring-boot-starter-test")
22-
implementation("com.fasterxml.jackson.core:jackson-databind")
23-
implementation("com.jayway.jsonpath:json-path")
20+
implementation "org.springframework:spring-web"
21+
implementation "org.springframework.boot:spring-boot-starter-test"
22+
implementation "com.fasterxml.jackson.core:jackson-databind"
23+
implementation "com.jayway.jsonpath:json-path"
2424
implementation "org.awaitility:awaitility:$LIB_AWAITILITY_VER"
25-
compileOnly("com.graphql-java-kickstart:graphql-java-servlet:$LIB_GRAPHQL_SERVLET_VER")
26-
testImplementation("org.springframework.boot:spring-boot-starter-web")
25+
compileOnly "com.graphql-java-kickstart:graphql-java-servlet:$LIB_GRAPHQL_SERVLET_VER"
26+
testImplementation "org.springframework.boot:spring-boot-starter-web"
2727
testImplementation "org.springframework.boot:spring-boot-starter-websocket"
2828
testImplementation project(":graphql-spring-boot-starter")
2929
testImplementation "io.reactivex.rxjava2:rxjava:2.2.21"

0 commit comments

Comments
 (0)