15
15
*/
16
16
package org .springframework .data .jdbc .core ;
17
17
18
- import java .util .ArrayList ;
19
- import java .util .Arrays ;
20
- import java .util .Collections ;
21
- import java .util .HashMap ;
22
- import java .util .HashSet ;
23
- import java .util .LinkedHashMap ;
24
- import java .util .List ;
25
- import java .util .Map ;
26
- import java .util .Optional ;
27
- import java .util .Set ;
18
+ import java .util .*;
28
19
import java .util .function .BiConsumer ;
29
20
import java .util .stream .Collectors ;
30
21
@@ -241,7 +232,7 @@ private Object getIdFrom(DbAction.WithEntity<?> idOwningAction) {
241
232
RelationalPersistentEntity <?> persistentEntity = getRequiredPersistentEntity (idOwningAction .getEntityType ());
242
233
Object identifier = persistentEntity .getIdentifierAccessor (idOwningAction .getEntity ()).getIdentifier ();
243
234
244
- Assert .state (identifier != null ,() -> "Couldn't obtain a required id value for " + persistentEntity );
235
+ Assert .state (identifier != null , () -> "Couldn't obtain a required id value for " + persistentEntity );
245
236
246
237
return identifier ;
247
238
}
@@ -268,12 +259,22 @@ <T> List<T> populateIdsIfNecessary() {
268
259
}
269
260
270
261
// the id property was immutable, so we have to propagate changes up the tree
271
- if (newEntity != action . getEntity () && action instanceof DbAction .Insert <?> insert ) {
262
+ if (action instanceof DbAction .Insert <?> insert ) {
272
263
273
264
Pair <?, ?> qualifier = insert .getQualifier ();
265
+ Object qualifierValue = qualifier == null ? null : qualifier .getSecond ();
274
266
275
- cascadingValues .stage (insert .getDependingOn (), insert .getPropertyPath (),
276
- qualifier == null ? null : qualifier .getSecond (), newEntity );
267
+ if (newEntity != action .getEntity ()) {
268
+
269
+ cascadingValues .stage (insert .getDependingOn (), insert .getPropertyPath (),
270
+ qualifierValue , newEntity );
271
+
272
+ } else if (insert .getPropertyPath ().getLeafProperty ().isCollectionLike ()) {
273
+
274
+ cascadingValues .gather (insert .getDependingOn (), insert .getPropertyPath (),
275
+ qualifierValue , newEntity );
276
+
277
+ }
277
278
}
278
279
}
279
280
@@ -360,7 +361,7 @@ private static class StagedValues {
360
361
static final List <MultiValueAggregator > aggregators = Arrays .asList (SetAggregator .INSTANCE , MapAggregator .INSTANCE ,
361
362
ListAggregator .INSTANCE , SingleElementAggregator .INSTANCE );
362
363
363
- Map <DbAction , Map <PersistentPropertyPath , Object >> values = new HashMap <>();
364
+ Map <DbAction , Map <PersistentPropertyPath , StagedValue >> values = new HashMap <>();
364
365
365
366
/**
366
367
* Adds a value that needs to be set in an entity higher up in the tree of entities in the aggregate. If the
@@ -375,18 +376,26 @@ private static class StagedValues {
375
376
*/
376
377
@ SuppressWarnings ("unchecked" )
377
378
<T > void stage (DbAction <?> action , PersistentPropertyPath path , @ Nullable Object qualifier , Object value ) {
379
+ gather (action , path , qualifier , value );
380
+ values .get (action ).get (path ).isStaged = true ;
381
+ }
382
+
383
+ <T > void gather (DbAction <?> action , PersistentPropertyPath path , @ Nullable Object qualifier , Object value ) {
378
384
379
385
MultiValueAggregator <T > aggregator = getAggregatorFor (path );
380
386
381
- Map <PersistentPropertyPath , Object > valuesForPath = this .values .computeIfAbsent (action ,
387
+ Map <PersistentPropertyPath , StagedValue > valuesForPath = this .values .computeIfAbsent (action ,
382
388
dbAction -> new HashMap <>());
383
389
384
- T currentValue = (T ) valuesForPath .computeIfAbsent (path ,
385
- persistentPropertyPath -> aggregator .createEmptyInstance ());
390
+ StagedValue stagedValue = valuesForPath .computeIfAbsent (path ,
391
+ persistentPropertyPath -> new StagedValue (aggregator .createEmptyInstance ()));
392
+ T currentValue = (T ) stagedValue .value ;
386
393
387
394
Object newValue = aggregator .add (currentValue , qualifier , value );
388
395
389
- valuesForPath .put (path , newValue );
396
+ stagedValue .value = newValue ;
397
+
398
+ valuesForPath .put (path , stagedValue );
390
399
}
391
400
392
401
private MultiValueAggregator getAggregatorFor (PersistentPropertyPath path ) {
@@ -408,7 +417,21 @@ private MultiValueAggregator getAggregatorFor(PersistentPropertyPath path) {
408
417
* property.
409
418
*/
410
419
void forEachPath (DbAction <?> dbAction , BiConsumer <PersistentPropertyPath , Object > action ) {
411
- values .getOrDefault (dbAction , Collections .emptyMap ()).forEach (action );
420
+ values .getOrDefault (dbAction , Collections .emptyMap ()).forEach ((persistentPropertyPath , stagedValue ) -> {
421
+ if (stagedValue .isStaged ) {
422
+ action .accept (persistentPropertyPath , stagedValue .value );
423
+ }
424
+ });
425
+ }
426
+
427
+ }
428
+
429
+ private static class StagedValue {
430
+ Object value ;
431
+ boolean isStaged ;
432
+
433
+ public StagedValue (Object value ) {
434
+ this .value = value ;
412
435
}
413
436
}
414
437
0 commit comments