Skip to content

Commit 0fcf496

Browse files
committed
Fix id setting for partial updates of collections of immutables.
We gather immutable entities of which the id has changed, in order to set them as values in the parent entity. We now also gather unchanged entities. So they get set with the changed one in the parent. Closes #1907
1 parent 4ebf325 commit 0fcf496

File tree

3 files changed

+72
-75
lines changed

3 files changed

+72
-75
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java

+43-20
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,7 @@
1515
*/
1616
package org.springframework.data.jdbc.core;
1717

18-
import java.util.ArrayList;
19-
import java.util.Arrays;
20-
import java.util.Collections;
21-
import java.util.HashMap;
22-
import java.util.HashSet;
23-
import java.util.LinkedHashMap;
24-
import java.util.List;
25-
import java.util.Map;
26-
import java.util.Optional;
27-
import java.util.Set;
18+
import java.util.*;
2819
import java.util.function.BiConsumer;
2920
import java.util.stream.Collectors;
3021

@@ -241,7 +232,7 @@ private Object getIdFrom(DbAction.WithEntity<?> idOwningAction) {
241232
RelationalPersistentEntity<?> persistentEntity = getRequiredPersistentEntity(idOwningAction.getEntityType());
242233
Object identifier = persistentEntity.getIdentifierAccessor(idOwningAction.getEntity()).getIdentifier();
243234

244-
Assert.state(identifier != null,() -> "Couldn't obtain a required id value for " + persistentEntity);
235+
Assert.state(identifier != null, () -> "Couldn't obtain a required id value for " + persistentEntity);
245236

246237
return identifier;
247238
}
@@ -268,12 +259,22 @@ <T> List<T> populateIdsIfNecessary() {
268259
}
269260

270261
// the id property was immutable, so we have to propagate changes up the tree
271-
if (newEntity != action.getEntity() && action instanceof DbAction.Insert<?> insert) {
262+
if (action instanceof DbAction.Insert<?> insert) {
272263

273264
Pair<?, ?> qualifier = insert.getQualifier();
265+
Object qualifierValue = qualifier == null ? null : qualifier.getSecond();
274266

275-
cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(),
276-
qualifier == null ? null : qualifier.getSecond(), newEntity);
267+
if (newEntity != action.getEntity()) {
268+
269+
cascadingValues.stage(insert.getDependingOn(), insert.getPropertyPath(),
270+
qualifierValue, newEntity);
271+
272+
} else if (insert.getPropertyPath().getLeafProperty().isCollectionLike()) {
273+
274+
cascadingValues.gather(insert.getDependingOn(), insert.getPropertyPath(),
275+
qualifierValue, newEntity);
276+
277+
}
277278
}
278279
}
279280

@@ -360,7 +361,7 @@ private static class StagedValues {
360361
static final List<MultiValueAggregator> aggregators = Arrays.asList(SetAggregator.INSTANCE, MapAggregator.INSTANCE,
361362
ListAggregator.INSTANCE, SingleElementAggregator.INSTANCE);
362363

363-
Map<DbAction, Map<PersistentPropertyPath, Object>> values = new HashMap<>();
364+
Map<DbAction, Map<PersistentPropertyPath, StagedValue>> values = new HashMap<>();
364365

365366
/**
366367
* Adds a value that needs to be set in an entity higher up in the tree of entities in the aggregate. If the
@@ -375,18 +376,26 @@ private static class StagedValues {
375376
*/
376377
@SuppressWarnings("unchecked")
377378
<T> void stage(DbAction<?> action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) {
379+
gather(action, path, qualifier, value);
380+
values.get(action).get(path).isStaged = true;
381+
}
382+
383+
<T> void gather(DbAction<?> action, PersistentPropertyPath path, @Nullable Object qualifier, Object value) {
378384

379385
MultiValueAggregator<T> aggregator = getAggregatorFor(path);
380386

381-
Map<PersistentPropertyPath, Object> valuesForPath = this.values.computeIfAbsent(action,
387+
Map<PersistentPropertyPath, StagedValue> valuesForPath = this.values.computeIfAbsent(action,
382388
dbAction -> new HashMap<>());
383389

384-
T currentValue = (T) valuesForPath.computeIfAbsent(path,
385-
persistentPropertyPath -> aggregator.createEmptyInstance());
390+
StagedValue stagedValue = valuesForPath.computeIfAbsent(path,
391+
persistentPropertyPath -> new StagedValue(aggregator.createEmptyInstance()));
392+
T currentValue = (T) stagedValue.value;
386393

387394
Object newValue = aggregator.add(currentValue, qualifier, value);
388395

389-
valuesForPath.put(path, newValue);
396+
stagedValue.value = newValue;
397+
398+
valuesForPath.put(path, stagedValue);
390399
}
391400

392401
private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) {
@@ -408,7 +417,21 @@ private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) {
408417
* property.
409418
*/
410419
void forEachPath(DbAction<?> dbAction, BiConsumer<PersistentPropertyPath, Object> action) {
411-
values.getOrDefault(dbAction, Collections.emptyMap()).forEach(action);
420+
values.getOrDefault(dbAction, Collections.emptyMap()).forEach((persistentPropertyPath, stagedValue) -> {
421+
if (stagedValue.isStaged) {
422+
action.accept(persistentPropertyPath, stagedValue.value);
423+
}
424+
});
425+
}
426+
427+
}
428+
429+
private static class StagedValue {
430+
Object value;
431+
boolean isStaged;
432+
433+
public StagedValue(Object value) {
434+
this.value = value;
412435
}
413436
}
414437

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/JdbcRepositoryWithListsIntegrationTests.java

+22-49
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.springframework.context.annotation.Configuration;
3333
import org.springframework.context.annotation.Import;
3434
import org.springframework.data.annotation.Id;
35+
import org.springframework.data.annotation.PersistenceCreator;
3536
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory;
3637
import org.springframework.data.jdbc.testing.EnabledOnFeature;
3738
import org.springframework.data.jdbc.testing.IntegrationTest;
@@ -55,8 +56,7 @@ public class JdbcRepositoryWithListsIntegrationTests {
5556

5657
private static DummyEntity createDummyEntity() {
5758

58-
DummyEntity entity = new DummyEntity();
59-
entity.setName("Entity Name");
59+
DummyEntity entity = new DummyEntity(null, "Entity Name", new ArrayList<>());
6060
return entity;
6161
}
6262

@@ -94,7 +94,7 @@ public void saveAndLoadNonEmptyList() {
9494
assertThat(reloaded.content) //
9595
.isNotNull() //
9696
.extracting(e -> e.id) //
97-
.containsExactlyInAnyOrder(element1.id, element2.id);
97+
.containsExactlyInAnyOrder(entity.content.get(0).id, entity.content.get(1).id);
9898
}
9999

100100
@Test // GH-1159
@@ -147,24 +147,25 @@ public void findAllLoadsList() {
147147
@EnabledOnFeature(SUPPORTS_GENERATED_IDS_IN_REFERENCED_ENTITIES)
148148
public void updateList() {
149149

150-
Element element1 = createElement("one");
151-
Element element2 = createElement("two");
152-
Element element3 = createElement("three");
150+
Element element1 = new Element("one");
151+
Element element2 = new Element("two");
152+
Element element3 = new Element("three");
153153

154154
DummyEntity entity = createDummyEntity();
155155
entity.content.add(element1);
156156
entity.content.add(element2);
157157

158158
entity = repository.save(entity);
159159

160-
entity.content.remove(element1);
161-
element2.content = "two changed";
160+
entity.content.remove(0);
161+
entity.content.set(0, new Element(entity.content.get(0).id, "two changed"));
162162
entity.content.add(element3);
163163

164164
entity = repository.save(entity);
165165

166166
assertThat(entity.id).isNotNull();
167167
assertThat(entity.content).allMatch(v -> v.id != null);
168+
assertThat(entity.content).hasSize(2);
168169

169170
DummyEntity reloaded = repository.findById(entity.id).orElseThrow(AssertionFailedError::new);
170171

@@ -175,8 +176,8 @@ public void updateList() {
175176
assertThat(reloaded.content) //
176177
.extracting(e -> e.id, e -> e.content) //
177178
.containsExactly( //
178-
tuple(element2.id, "two changed"), //
179-
tuple(element3.id, "three") //
179+
tuple(entity.content.get(0).id, "two changed"), //
180+
tuple(entity.content.get(1).id, "three") //
180181
);
181182

182183
Long count = template.queryForObject("SELECT count(1) FROM Element", new HashMap<>(), Long.class);
@@ -186,8 +187,8 @@ public void updateList() {
186187
@Test // DATAJDBC-130
187188
public void deletingWithList() {
188189

189-
Element element1 = createElement("one");
190-
Element element2 = createElement("two");
190+
Element element1 = new Element("one");
191+
Element element2 = new Element("two");
191192

192193
DummyEntity entity = createDummyEntity();
193194
entity.content.add(element1);
@@ -203,13 +204,6 @@ public void deletingWithList() {
203204
assertThat(count).isEqualTo(0);
204205
}
205206

206-
private Element createElement(String content) {
207-
208-
Element element = new Element();
209-
element.content = content;
210-
return element;
211-
}
212-
213207
interface DummyEntityRepository extends CrudRepository<DummyEntity, Long> {}
214208

215209
interface RootRepository extends CrudRepository<Root, Long> {}
@@ -229,43 +223,22 @@ RootRepository rootRepository(JdbcRepositoryFactory factory) {
229223
}
230224
}
231225

232-
static class DummyEntity {
226+
record DummyEntity(@Id Long id, String name, List<Element> content) {
227+
}
233228

234-
String name;
235-
List<Element> content = new ArrayList<>();
236-
@Id private Long id;
229+
record Element(@Id Long id, String content) {
237230

238-
public String getName() {
239-
return this.name;
240-
}
231+
@PersistenceCreator
232+
Element {}
241233

242-
public List<Element> getContent() {
243-
return this.content;
234+
Element() {
235+
this(null, null);
244236
}
245237

246-
public Long getId() {
247-
return this.id;
238+
Element(String content) {
239+
this(null, content);
248240
}
249241

250-
public void setName(String name) {
251-
this.name = name;
252-
}
253-
254-
public void setContent(List<Element> content) {
255-
this.content = content;
256-
}
257-
258-
public void setId(Long id) {
259-
this.id = id;
260-
}
261-
}
262-
263-
static class Element {
264-
265-
String content;
266-
@Id private Long id;
267-
268-
public Element() {}
269242
}
270243

271244
static class Root {

spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/DbAction.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
*/
1616
package org.springframework.data.relational.core.conversion;
1717

18+
import java.util.Comparator;
1819
import java.util.Iterator;
1920
import java.util.List;
2021
import java.util.Map;
22+
import java.util.Optional;
23+
import java.util.Set;
2124
import java.util.function.Function;
2225

2326
import org.springframework.data.mapping.PersistentPropertyPath;
@@ -479,15 +482,13 @@ interface WithDependingOn<T> extends WithPropertyPath<T>, WithEntity<T> {
479482
default Pair<PersistentPropertyPath<RelationalPersistentProperty>, Object> getQualifier() {
480483

481484
Map<PersistentPropertyPath<RelationalPersistentProperty>, Object> qualifiers = getQualifiers();
482-
if (qualifiers.isEmpty())
485+
if (qualifiers.isEmpty()) {
483486
return null;
484-
485-
if (qualifiers.size() > 1) {
486-
throw new IllegalStateException("Can't handle more then one qualifier");
487487
}
488488

489-
Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> entry = qualifiers.entrySet().iterator()
490-
.next();
489+
Set<Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object>> entries = qualifiers.entrySet();
490+
Map.Entry<PersistentPropertyPath<RelationalPersistentProperty>, Object> entry = entries.stream().sorted(Comparator.comparing(e -> -e.getKey().getLength())).findFirst().get();
491+
491492
if (entry.getValue() == null) {
492493
return null;
493494
}

0 commit comments

Comments
 (0)