diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
index 92ebf006d2..900dc62247 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java
@@ -69,6 +69,7 @@ abstract class AbstractReadContext
 
   abstract static class Builder<B extends Builder<?, T>, T extends AbstractReadContext> {
     private SessionImpl session;
+    private boolean cancelQueryWhenClientIsClosed;
     private SpannerRpc rpc;
     private ISpan span;
     private TraceWrapper tracer;
@@ -91,6 +92,11 @@ B setSession(SessionImpl session) {
       return self();
     }
 
+    B setCancelQueryWhenClientIsClosed(boolean cancelQueryWhenClientIsClosed) {
+      this.cancelQueryWhenClientIsClosed = cancelQueryWhenClientIsClosed;
+      return self();
+    }
+
     B setRpc(SpannerRpc rpc) {
       this.rpc = rpc;
       return self();
@@ -440,6 +446,7 @@ void initTransaction() {
 
   final Object lock = new Object();
   final SessionImpl session;
+  final boolean cancelQueryWhenClientIsClosed;
   final SpannerRpc rpc;
   final ExecutorProvider executorProvider;
   ISpan span;
@@ -469,6 +476,7 @@ void initTransaction() {
 
   AbstractReadContext(Builder<?, ?> builder) {
     this.session = builder.session;
+    this.cancelQueryWhenClientIsClosed = builder.cancelQueryWhenClientIsClosed;
     this.rpc = builder.rpc;
     this.defaultPrefetchChunks = builder.defaultPrefetchChunks;
     this.defaultQueryOptions = builder.defaultQueryOptions;
@@ -749,7 +757,8 @@ ResultSet executeQueryInternalWithOptions(
             rpc.getExecuteQueryRetryableCodes()) {
           @Override
           CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken) {
-            GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks);
+            GrpcStreamIterator stream =
+                new GrpcStreamIterator(statement, prefetchChunks, cancelQueryWhenClientIsClosed);
             if (partitionToken != null) {
               request.setPartitionToken(partitionToken);
             }
@@ -922,7 +931,8 @@ ResultSet readInternalWithOptions(
             rpc.getReadRetryableCodes()) {
           @Override
           CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken) {
-            GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks);
+            GrpcStreamIterator stream =
+                new GrpcStreamIterator(prefetchChunks, cancelQueryWhenClientIsClosed);
             TransactionSelector selector = null;
             if (resumeToken != null) {
               builder.setResumeToken(resumeToken);
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java
index 22fb9f710c..3d886dd383 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java
@@ -54,6 +54,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) {
     return new BatchReadOnlyTransactionImpl(
         MultiUseReadOnlyTransaction.newBuilder()
             .setSession(session)
+            .setCancelQueryWhenClientIsClosed(true)
             .setRpc(sessionClient.getSpanner().getRpc())
             .setTimestampBound(bound)
             .setDefaultQueryOptions(
@@ -75,6 +76,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc
     return new BatchReadOnlyTransactionImpl(
         MultiUseReadOnlyTransaction.newBuilder()
             .setSession(session)
+            .setCancelQueryWhenClientIsClosed(true)
             .setRpc(sessionClient.getSpanner().getRpc())
             .setTransactionId(batchTransactionId.getTransactionId())
             .setTimestamp(batchTransactionId.getTimestamp())
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java
index dde6b69c46..af6b568350 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java
@@ -38,7 +38,7 @@ class GrpcStreamIterator extends AbstractIterator<PartialResultSet>
   private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName());
   private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build();
 
-  private final ConsumerImpl consumer = new ConsumerImpl();
+  private final ConsumerImpl consumer;
   private final BlockingQueue<PartialResultSet> stream;
   private final Statement statement;
 
@@ -49,13 +49,15 @@ class GrpcStreamIterator extends AbstractIterator<PartialResultSet>
   private SpannerException error;
 
   @VisibleForTesting
-  GrpcStreamIterator(int prefetchChunks) {
-    this(null, prefetchChunks);
+  GrpcStreamIterator(int prefetchChunks, boolean cancelQueryWhenClientIsClosed) {
+    this(null, prefetchChunks, cancelQueryWhenClientIsClosed);
   }
 
   @VisibleForTesting
-  GrpcStreamIterator(Statement statement, int prefetchChunks) {
+  GrpcStreamIterator(
+      Statement statement, int prefetchChunks, boolean cancelQueryWhenClientIsClosed) {
     this.statement = statement;
+    this.consumer = new ConsumerImpl(cancelQueryWhenClientIsClosed);
     // One extra to allow for END_OF_STREAM message.
     this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1);
   }
@@ -136,6 +138,12 @@ private void addToStream(PartialResultSet results) {
   }
 
   private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer {
+    private final boolean cancelQueryWhenClientIsClosed;
+
+    ConsumerImpl(boolean cancelQueryWhenClientIsClosed) {
+      this.cancelQueryWhenClientIsClosed = cancelQueryWhenClientIsClosed;
+    }
+
     @Override
     public void onPartialResultSet(PartialResultSet results) {
       addToStream(results);
@@ -168,5 +176,10 @@ public void onError(SpannerException e) {
       error = e;
       addToStream(END_OF_STREAM);
     }
+
+    @Override
+    public boolean cancelQueryWhenClientIsClosed() {
+      return this.cancelQueryWhenClientIsClosed;
+    }
   }
 }
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java
new file mode 100644
index 0000000000..367d75a13c
--- /dev/null
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import com.google.api.gax.rpc.UnavailableException;
+import com.google.common.base.Predicate;
+import io.grpc.Status.Code;
+import io.grpc.StatusRuntimeException;
+
+/**
+ * Predicate that checks whether an exception is a ChannelShutdownException. This exception is
+ * thrown by gRPC if the underlying gRPC stub has been shut down and uses the UNAVAILABLE error
+ * code. This means that it would normally be retried by the Spanner client, but this specific
+ * UNAVAILABLE error should not be retried, as it would otherwise directly return the same error.
+ */
+class IsChannelShutdownException implements Predicate<Throwable> {
+
+  @Override
+  public boolean apply(Throwable input) {
+    Throwable cause = input;
+    do {
+      if (isUnavailableError(cause)
+          && (cause.getMessage().contains("Channel shutdown invoked")
+              || cause.getMessage().contains("Channel shutdownNow invoked"))) {
+        return true;
+      }
+    } while ((cause = cause.getCause()) != null);
+    return false;
+  }
+
+  private boolean isUnavailableError(Throwable cause) {
+    return (cause instanceof UnavailableException)
+        || (cause instanceof StatusRuntimeException
+            && ((StatusRuntimeException) cause).getStatus().getCode() == Code.UNAVAILABLE);
+  }
+}
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java
index 2c52192d21..068e2e2492 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java
@@ -322,7 +322,9 @@ private static boolean isRetryable(ErrorCode code, @Nullable Throwable cause) {
       case UNAVAILABLE:
         // SSLHandshakeException is (probably) not retryable, as it is an indication that the server
         // certificate was not accepted by the client.
-        return !hasCauseMatching(cause, Matchers.isSSLHandshakeException);
+        // Channel shutdown is also not a retryable exception.
+        return !(hasCauseMatching(cause, Matchers.isSSLHandshakeException)
+            || hasCauseMatching(cause, Matchers.IS_CHANNEL_SHUTDOWN_EXCEPTION));
       case RESOURCE_EXHAUSTED:
         return SpannerException.extractRetryDelay(cause) > 0;
       default:
@@ -345,5 +347,8 @@ private static class Matchers {
 
     static final Predicate<Throwable> isRetryableInternalError = new IsRetryableInternalError();
     static final Predicate<Throwable> isSSLHandshakeException = new IsSslHandshakeException();
+
+    static final Predicate<Throwable> IS_CHANNEL_SHUTDOWN_EXCEPTION =
+        new IsChannelShutdownException();
   }
 }
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
index 00ae72f169..e1e15b851b 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java
@@ -201,6 +201,7 @@
 import java.util.concurrent.Callable;
 import java.util.concurrent.CancellationException;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
@@ -262,6 +263,9 @@ public class GapicSpannerRpc implements SpannerRpc {
 
   private final ScheduledExecutorService spannerWatchdog;
 
+  private final ConcurrentLinkedDeque<SpannerResponseObserver> responseObservers =
+      new ConcurrentLinkedDeque<>();
+
   private final boolean throttleAdministrativeRequests;
   private final RetrySettings retryAdministrativeRequestsSettings;
   private static final double ADMINISTRATIVE_REQUESTS_RATE_LIMIT = 1.0D;
@@ -2004,9 +2008,29 @@ <ReqT, RespT> GrpcCallContext newCallContext(
     return (GrpcCallContext) context.merge(apiCallContextFromContext);
   }
 
+  void registerResponseObserver(SpannerResponseObserver responseObserver) {
+    responseObservers.add(responseObserver);
+  }
+
+  void unregisterResponseObserver(SpannerResponseObserver responseObserver) {
+    responseObservers.remove(responseObserver);
+  }
+
+  void closeResponseObservers() {
+    responseObservers.forEach(SpannerResponseObserver::close);
+    responseObservers.clear();
+  }
+
+  @InternalApi
+  @VisibleForTesting
+  public int getNumActiveResponseObservers() {
+    return responseObservers.size();
+  }
+
   @Override
   public void shutdown() {
     this.rpcIsClosed = true;
+    closeResponseObservers();
     if (this.spannerStub != null) {
       this.spannerStub.close();
       this.partitionedDmlStub.close();
@@ -2028,6 +2052,7 @@ public void shutdown() {
 
   public void shutdownNow() {
     this.rpcIsClosed = true;
+    closeResponseObservers();
     this.spannerStub.close();
     this.partitionedDmlStub.close();
     this.instanceAdminStub.close();
@@ -2085,7 +2110,7 @@ public void cancel(@Nullable String message) {
    * A {@code ResponseObserver} that exposes the {@code StreamController} and delegates callbacks to
    * the {@link ResultStreamConsumer}.
    */
-  private static class SpannerResponseObserver implements ResponseObserver<PartialResultSet> {
+  private class SpannerResponseObserver implements ResponseObserver<PartialResultSet> {
 
     private StreamController controller;
     private final ResultStreamConsumer consumer;
@@ -2094,13 +2119,21 @@ public SpannerResponseObserver(ResultStreamConsumer consumer) {
       this.consumer = consumer;
     }
 
+    void close() {
+      if (this.controller != null) {
+        this.controller.cancel();
+      }
+    }
+
     @Override
     public void onStart(StreamController controller) {
-
       // Disable the auto flow control to allow client library
       // set the number of messages it prefers to request
       controller.disableAutoInboundFlowControl();
       this.controller = controller;
+      if (this.consumer.cancelQueryWhenClientIsClosed()) {
+        registerResponseObserver(this);
+      }
     }
 
     @Override
@@ -2110,11 +2143,19 @@ public void onResponse(PartialResultSet response) {
 
     @Override
     public void onError(Throwable t) {
+      // Unregister the response observer when the query has completed with an error.
+      if (this.consumer.cancelQueryWhenClientIsClosed()) {
+        unregisterResponseObserver(this);
+      }
       consumer.onError(newSpannerException(t));
     }
 
     @Override
     public void onComplete() {
+      // Unregister the response observer when the query has completed normally.
+      if (this.consumer.cancelQueryWhenClientIsClosed()) {
+        unregisterResponseObserver(this);
+      }
       consumer.onCompleted();
     }
 
diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java
index f07a28fb91..083cc11d6f 100644
--- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java
+++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java
@@ -152,6 +152,15 @@ interface ResultStreamConsumer {
     void onCompleted();
 
     void onError(SpannerException e);
+
+    /**
+     * Returns true if the stream should be cancelled when the Spanner client is closed. This
+     * returns true for {@link com.google.cloud.spanner.BatchReadOnlyTransaction}, as these use a
+     * non-pooled session. Pooled sessions are deleted when the Spanner client is closed, and this
+     * automatically also cancels any query that uses the session, which means that we don't need to
+     * explicitly cancel those queries when the Spanner client is closed.
+     */
+    boolean cancelQueryWhenClientIsClosed();
   }
 
   /** Handle for cancellation of a streaming read or query call. */
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java
new file mode 100644
index 0000000000..67b14f60a4
--- /dev/null
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.google.cloud.spanner;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assume.assumeFalse;
+
+import com.google.cloud.NoCredentials;
+import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult;
+import com.google.cloud.spanner.connection.AbstractMockServerTest;
+import com.google.cloud.spanner.spi.v1.GapicSpannerRpc;
+import com.google.spanner.v1.DeleteSessionRequest;
+import com.google.spanner.v1.ExecuteSqlRequest;
+import io.grpc.ManagedChannelBuilder;
+import io.grpc.Status;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.threeten.bp.Duration;
+
+@RunWith(JUnit4.class)
+public class CloseSpannerWithOpenResultSetTest extends AbstractMockServerTest {
+
+  Spanner createSpanner() {
+    return SpannerOptions.newBuilder()
+        .setProjectId("p")
+        .setHost(String.format("http://localhost:%d", getPort()))
+        .setChannelConfigurator(ManagedChannelBuilder::usePlaintext)
+        .setCredentials(NoCredentials.getInstance())
+        .setSessionPoolOption(
+            SessionPoolOptions.newBuilder().setWaitForMinSessions(Duration.ofSeconds(5L)).build())
+        .build()
+        .getService();
+  }
+
+  @After
+  public void cleanup() {
+    mockSpanner.unfreeze();
+    mockSpanner.clearRequests();
+  }
+
+  @Test
+  public void testBatchClient_closedSpannerWithOpenResultSet_streamsAreCancelled() {
+    Spanner spanner = createSpanner();
+    assumeFalse(spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSession());
+
+    BatchClient client = spanner.getBatchClient(DatabaseId.of("p", "i", "d"));
+    try (BatchReadOnlyTransaction transaction =
+            client.batchReadOnlyTransaction(TimestampBound.strong());
+        ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) {
+      mockSpanner.freezeAfterReturningNumRows(1);
+      assertTrue(resultSet.next());
+      ((SpannerImpl) spanner).close(1, TimeUnit.MILLISECONDS);
+      // This should return an error as the stream is cancelled.
+      SpannerException exception = assertThrows(SpannerException.class, resultSet::next);
+      assertEquals(ErrorCode.CANCELLED, exception.getErrorCode());
+    }
+  }
+
+  @Test
+  public void testNormalDatabaseClient_closedSpannerWithOpenResultSet_sessionsAreDeleted()
+      throws Exception {
+    Spanner spanner = createSpanner();
+    assumeFalse(spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSession());
+
+    DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
+    try (ReadOnlyTransaction transaction = client.readOnlyTransaction(TimestampBound.strong());
+        ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) {
+      mockSpanner.freezeAfterReturningNumRows(1);
+      assertTrue(resultSet.next());
+      List<ExecuteSqlRequest> executeSqlRequests =
+          mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream()
+              .filter(request -> request.getSql().equals(SELECT_RANDOM_STATEMENT.getSql()))
+              .collect(Collectors.toList());
+      assertEquals(1, executeSqlRequests.size());
+      ExecutorService service = Executors.newSingleThreadExecutor();
+      service.submit(spanner::close);
+      // Verify that the session that is used by this transaction is deleted.
+      // That will automatically cancel the query.
+      mockSpanner.waitForRequestsToContain(
+          request ->
+              request instanceof DeleteSessionRequest
+                  && ((DeleteSessionRequest) request)
+                      .getName()
+                      .equals(executeSqlRequests.get(0).getSession()),
+          /*timeoutMillis=*/ 1000L);
+      service.shutdownNow();
+    }
+  }
+
+  @Test
+  public void testStreamsAreCleanedUp() throws Exception {
+    String invalidSql = "select * from foo";
+    Statement invalidStatement = Statement.of(invalidSql);
+    mockSpanner.putStatementResult(
+        StatementResult.exception(
+            invalidStatement,
+            Status.NOT_FOUND.withDescription("Table not found: foo").asRuntimeException()));
+    int numThreads = 16;
+    int numQueries = 32;
+    try (Spanner spanner = createSpanner()) {
+      BatchClient client = spanner.getBatchClient(DatabaseId.of("p", "i", "d"));
+      ExecutorService service = Executors.newFixedThreadPool(numThreads);
+      List<Future<?>> futures = new ArrayList<>(numQueries);
+      for (int n = 0; n < numQueries; n++) {
+        futures.add(
+            service.submit(
+                () -> {
+                  try (BatchReadOnlyTransaction transaction =
+                      client.batchReadOnlyTransaction(TimestampBound.strong())) {
+                    if (ThreadLocalRandom.current().nextInt(10) < 2) {
+                      try (ResultSet resultSet = transaction.executeQuery(invalidStatement)) {
+                        SpannerException exception =
+                            assertThrows(SpannerException.class, resultSet::next);
+                        assertEquals(ErrorCode.NOT_FOUND, exception.getErrorCode());
+                      }
+                    } else {
+                      try (ResultSet resultSet =
+                          transaction.executeQuery(SELECT_RANDOM_STATEMENT)) {
+                        while (resultSet.next()) {
+                          assertNotNull(resultSet.getCurrentRowAsStruct());
+                        }
+                      }
+                    }
+                  }
+                }));
+      }
+      service.shutdown();
+      for (Future<?> fut : futures) {
+        fut.get();
+      }
+      assertTrue(service.awaitTermination(1L, TimeUnit.MINUTES));
+      // Verify that all response observers have been unregistered.
+      assertEquals(
+          0, ((GapicSpannerRpc) ((SpannerImpl) spanner).getRpc()).getNumActiveResponseObservers());
+    }
+  }
+}
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java
index 2051e006d8..62336163ea 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java
@@ -81,7 +81,7 @@ public void onDone(boolean withBeginTransaction) {}
 
   @Before
   public void setUp() {
-    stream = new GrpcStreamIterator(10);
+    stream = new GrpcStreamIterator(10, /*cancelQueryWhenClientIsClosed=*/ false);
     stream.setCall(
         new SpannerRpc.StreamingCall() {
           @Override
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
index 5266ecad7c..9f0a2822d8 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java
@@ -578,6 +578,7 @@ private static void checkStreamException(
   private final Object lock = new Object();
   private Deque<AbstractMessage> requests = new ConcurrentLinkedDeque<>();
   private volatile CountDownLatch freezeLock = new CountDownLatch(0);
+  private final AtomicInteger freezeAfterReturningNumRows = new AtomicInteger();
   private Queue<Exception> exceptions = new ConcurrentLinkedQueue<>();
   private boolean stickyGlobalExceptions = false;
   private ConcurrentMap<Statement, StatementResult> statementResults = new ConcurrentHashMap<>();
@@ -784,6 +785,10 @@ public void unfreeze() {
     freezeLock.countDown();
   }
 
+  public void freezeAfterReturningNumRows(int numRows) {
+    freezeAfterReturningNumRows.set(numRows);
+  }
+
   public void setMaxSessionsInOneBatch(int max) {
     this.maxNumSessionsInOneBatch = max;
   }
@@ -1678,7 +1683,8 @@ private void returnPartialResultSet(
       ByteString transactionId,
       TransactionSelector transactionSelector,
       StreamObserver<PartialResultSet> responseObserver,
-      SimulatedExecutionTime executionTime) {
+      SimulatedExecutionTime executionTime)
+      throws Exception {
     ResultSetMetadata metadata = resultSet.getMetadata();
     if (transactionId == null) {
       Transaction transaction = getTemporaryTransactionOrNull(transactionSelector);
@@ -1700,6 +1706,12 @@ private void returnPartialResultSet(
       SimulatedExecutionTime.checkStreamException(
           index, executionTime.exceptions, executionTime.streamIndices);
       responseObserver.onNext(iterator.next());
+      if (freezeAfterReturningNumRows.get() > 0) {
+        if (freezeAfterReturningNumRows.decrementAndGet() == 0) {
+          freeze();
+          freezeLock.await();
+        }
+      }
       index++;
     }
     responseObserver.onCompleted();
diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java
index 8d97d9d894..c973b7e471 100644
--- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java
+++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java
@@ -114,7 +114,7 @@ private static class TestCaseRunner {
     }
 
     private void run() throws Exception {
-      stream = new GrpcStreamIterator(10);
+      stream = new GrpcStreamIterator(10, /*cancelQueryWhenClientIsClosed=*/ false);
       stream.setCall(
           new SpannerRpc.StreamingCall() {
             @Override