Skip to content

Add Stream support to JdbcAggregateOperations #1963

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import org.springframework.dao.IncorrectUpdateSemanticsDataAccessException;
import org.springframework.data.domain.Example;
Expand All @@ -35,6 +36,7 @@
* @author Chirag Tailor
* @author Diego Krupitza
* @author Myeonghyeon Lee
* @author Sergey Korotaev
*/
public interface JdbcAggregateOperations {

Expand Down Expand Up @@ -165,6 +167,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAllById(Iterable<?> ids, Class<T> domainType);

/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);

/**
* Load all aggregates of a given type.
*
Expand All @@ -174,6 +187,15 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Class<T> domainType);

/**
* Load all aggregates of a given type to a {@link Stream}.
*
* @param domainType the type of the aggregate roots. Must not be {@code null}.
* @param <T> the type of the aggregate roots. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
*/
<T> Stream<T> streamAll(Class<T> domainType);

/**
* Load all aggregates of a given type, sorted.
*
Expand All @@ -185,6 +207,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Class<T> domainType, Sort sort);

/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);

/**
* Load a page of (potentially sorted) aggregates of a given type.
*
Expand Down Expand Up @@ -218,6 +251,17 @@ public interface JdbcAggregateOperations {
*/
<T> List<T> findAll(Query query, Class<T> domainType);

/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
<T> Stream<T> streamAll(Query query, Class<T> domainType);

/**
* Returns a {@link Page} of entities matching the given {@link Query}. In case no match could be found, an empty
* {@link Page} is returned.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.springframework.context.ApplicationContext;
Expand Down Expand Up @@ -68,6 +69,7 @@
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
*/
public class JdbcAggregateTemplate implements JdbcAggregateOperations {

Expand Down Expand Up @@ -283,6 +285,16 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
return triggerAfterConvert(all);
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {

Assert.notNull(domainType, "Domain type must not be null");

Stream<T> allStreamable = accessStrategy.streamAll(domainType, sort);

return allStreamable.map(this::triggerAfterConvert);
}

@Override
public <T> Page<T> findAll(Class<T> domainType, Pageable pageable) {

Expand All @@ -309,6 +321,11 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
return Streamable.of(all).toList();
}

@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return accessStrategy.streamAll(query, domainType).map(this::triggerAfterConvert);
}

@Override
public <T> Page<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

Expand All @@ -327,6 +344,12 @@ public <T> List<T> findAll(Class<T> domainType) {
return triggerAfterConvert(all);
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
Iterable<T> items = triggerAfterConvert(accessStrategy.findAll(domainType));
return StreamSupport.stream(items.spliterator(), false).map(this::triggerAfterConvert);
}

@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {

Expand All @@ -337,6 +360,17 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return triggerAfterConvert(allById);
}

@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {

Assert.notNull(ids, "Ids must not be null");
Assert.notNull(domainType, "Domain type must not be null");

Stream<T> allByIdStreamable = accessStrategy.streamAllByIds(ids, domainType);

return allByIdStreamable.map(this::triggerAfterConvert);
}

@Override
public <S> void delete(S aggregateRoot) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;

import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
Expand All @@ -42,6 +43,7 @@
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1
*/
public class CascadingDataAccessStrategy implements DataAccessStrategy {
Expand Down Expand Up @@ -132,11 +134,21 @@ public <T> Iterable<T> findAll(Class<T> domainType) {
return collect(das -> das.findAll(domainType));
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return collect(das -> das.streamAll(domainType));
}

@Override
public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.findAllById(ids, domainType));
}

@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {
return collect(das -> das.streamAllByIds(ids, domainType));
}

@Override
public Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path) {
Expand All @@ -153,6 +165,11 @@ public <T> Iterable<T> findAll(Class<T> domainType, Sort sort) {
return collect(das -> das.findAll(domainType, sort));
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return collect(das -> das.streamAll(domainType, sort));
}

@Override
public <T> Iterable<T> findAll(Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(domainType, pageable));
Expand All @@ -168,6 +185,11 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
return collect(das -> das.findAll(query, domainType));
}

@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {
return collect(das -> das.streamAll(query, domainType));
}

@Override
public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageable) {
return collect(das -> das.findAll(query, domainType, pageable));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.domain.Pageable;
Expand All @@ -41,6 +42,7 @@
* @author Myeonghyeon Lee
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
*/
public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationResolver {

Expand Down Expand Up @@ -252,6 +254,16 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAll(Class<T> domainType);

/**
* Loads all entities of the given type to a {@link Stream}.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @return Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType);

/**
* Loads all entities that match one of the ids passed as an argument. It is not guaranteed that the number of ids
* passed in matches the number of entities returned.
Expand All @@ -264,6 +276,18 @@ public interface DataAccessStrategy extends ReadingDataAccessStrategy, RelationR
@Override
<T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType);

/**
* Loads all entities that match one of the ids passed as an argument to a {@link Stream}.
* It is not guaranteed that the number of ids passed in matches the number of entities returned.
*
* @param ids the Ids of the entities to load. Must not be {@code null}.
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> type of entities to load.
* @return the loaded entities. Guaranteed to be not {@code null}.
*/
@Override
<T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType);

@Override
Iterable<Object> findAllByPath(Identifier identifier,
PersistentPropertyPath<? extends RelationalPersistentProperty> path);
Expand All @@ -280,6 +304,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
@Override
<T> Iterable<T> findAll(Class<T> domainType, Sort sort);

/**
* Loads all entities of the given type to a {@link Stream}, sorted.
*
* @param domainType the type of entities to load. Must not be {@code null}.
* @param <T> the type of entities to load.
* @param sort the sorting information. Must not be {@code null}.
* @return Guaranteed to be not {@code null}.
* @since 2.0
*/
@Override
<T> Stream<T> streamAll(Class<T> domainType, Sort sort);

/**
* Loads all entities of the given type, paged and sorted.
*
Expand Down Expand Up @@ -316,6 +352,18 @@ Iterable<Object> findAllByPath(Identifier identifier,
@Override
<T> Iterable<T> findAll(Query query, Class<T> domainType);

/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Stream}.
*
* @param query must not be {@literal null}.
* @param domainType the type of entities. Must not be {@code null}.
* @return a non-null list with all the matching results.
* @throws org.springframework.dao.IncorrectResultSizeDataAccessException if more than one match found.
* @since 3.0
*/
@Override
<T> Stream<T> streamAll(Query query, Class<T> domainType);

/**
* Execute a {@code SELECT} query and convert the resulting items to a {@link Iterable}. Applies the {@link Pageable}
* to the result.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.OptimisticLockingFailureException;
Expand Down Expand Up @@ -60,6 +61,7 @@
* @author Radim Tlusty
* @author Chirag Tailor
* @author Diego Krupitza
* @author Sergey Korotaev
* @since 1.1
*/
public class DefaultDataAccessStrategy implements DataAccessStrategy {
Expand Down Expand Up @@ -276,6 +278,11 @@ public <T> List<T> findAll(Class<T> domainType) {
return operations.query(sql(domainType).getFindAll(), getEntityRowMapper(domainType));
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}

@Override
public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {

Expand All @@ -288,6 +295,19 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
return operations.query(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
}

@Override
public <T> Stream<T> streamAllByIds(Iterable<?> ids, Class<T> domainType) {

if (!ids.iterator().hasNext()) {
return Stream.empty();
}

SqlParameterSource parameterSource = sqlParametersFactory.forQueryByIds(ids, domainType);
String findAllInListSql = sql(domainType).getFindAllInList();

return operations.queryForStream(findAllInListSql, parameterSource, getEntityRowMapper(domainType));
}

@Override
@SuppressWarnings("unchecked")
public List<Object> findAllByPath(Identifier identifier,
Expand Down Expand Up @@ -342,6 +362,11 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {
return operations.query(sql(domainType).getFindAll(sort), getEntityRowMapper(domainType));
}

@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
}

@Override
public <T> List<T> findAll(Class<T> domainType, Pageable pageable) {
return operations.query(sql(domainType).getFindAll(pageable), getEntityRowMapper(domainType));
Expand Down Expand Up @@ -369,6 +394,15 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
return operations.query(sqlQuery, parameterSource, getEntityRowMapper(domainType));
}

@Override
public <T> Stream<T> streamAll(Query query, Class<T> domainType) {

MapSqlParameterSource parameterSource = new MapSqlParameterSource();
String sqlQuery = sql(domainType).selectByQuery(query, parameterSource);

return operations.queryForStream(sqlQuery, parameterSource, getEntityRowMapper(domainType));
}

@Override
public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable) {

Expand Down
Loading