Skip to content

Evaluate parameter expression for Enumerable.Contains calls #2873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions src/NHibernate.Test/Async/Linq/ParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ public async Task UsingEntityParameterForCollectionAsync()
1));
}

[Test]
public async Task UsingParameterForCollectionWithWhereAsync()
{
var item = await (db.OrderLines.FirstAsync());
await (AssertTotalParametersAsync(
db.Orders.Where(o => o.OrderLines.Select(ol => ol.Id).Where(id => id == item.Id).Contains(item.Id)),
1));
}

[Test]
public async Task UsingProxyParameterForCollectionAsync()
{
Expand Down Expand Up @@ -907,29 +916,76 @@ public async Task UsingTwoParametersInDMLDeleteAsync()
2));
}

[Test(Description = "GH-2872")]
public async Task UsingListParameterWithWhereAsync()
{
var ids = await (db.Orders.OrderBy(x => x.OrderId).Take(2).Select(o => o.OrderId).ToListAsync());
await (AssertTotalParametersAsync(
db.Orders.Where(o => ids.Where(i => i == ids[0]).Contains(o.OrderId)),
1,
countResults: 1));
}

[Test(Description = "GH-2276")]
public async Task UsingArrayParameterWithWhereAndSelectAsync()
{
var ids = db.Orders.OrderBy(x => x.OrderId).Take(2).ToArray();
var orderLines = new[] {ids[0].OrderLines.First(), ids[1].OrderLines.First()};
await (AssertTotalParametersAsync(
db.Orders.Where(o => ids.Where(i => i == ids[0]).Contains(o) && orderLines.Select(ol => ol.Order).Where(i => i.OrderId == ids[0].OrderId).Contains(o)),
2,
countResults: 1));
}

[Test]
public void UsingArrayParameterWithNotEvaluatableWhereAsync()
{
var ids = db.Orders.OrderBy(x => x.OrderId).Take(2).Select(x => x.OrderId).ToArray();
//ids.Where(i => i == o.OrderId) is not supported part of query. So just check it throws some exception and not silently ignored
Assert.ThrowsAsync<HibernateException>(() => db.Orders.Where(o => ids.Where(i => i == o.OrderId).Contains(o.OrderId)).ToListAsync());
}

[Test]
public async Task UsingArrayMethodParameterWithTakeAsync()
{
using (var logSpy = new SqlLogSpy())
{
var results = await (db.Orders.Where(o => GetArrayParameters().Take(1).Contains(o)).ToListAsync());
Assert.That(results.Count, Is.EqualTo(1));
Assert.That(logSpy.Appender.GetEvents().Length, Is.EqualTo(2));
}
}

private Order[] GetArrayParameters()
{
return db.Orders.OrderBy(x => x.OrderId).Take(3).ToArray();
}

private Task AssertTotalParametersAsync<T>(IQueryable<T> query, int parameterNumber, Action<string> sqlAction, CancellationToken cancellationToken = default(CancellationToken))
{
return AssertTotalParametersAsync(query, parameterNumber, null, sqlAction, cancellationToken);
return AssertTotalParametersAsync(query, parameterNumber, null, sqlAction, cancellationToken: cancellationToken);
}

private async Task AssertTotalParametersAsync<T>(IQueryable<T> query, int parameterNumber, int? linqParameterNumber = null, Action<string> sqlAction = null, CancellationToken cancellationToken = default(CancellationToken))
private async Task AssertTotalParametersAsync<T>(IQueryable<T> query, int parameterNumber, int? linqParameterNumber = null, Action<string> sqlAction = null, int? countResults = null, CancellationToken cancellationToken = default(CancellationToken))
{
using (var sqlSpy = new SqlLogSpy())
{
// In case of arrays linqParameterNumber and parameterNumber will be different
Assert.That(
GetLinqExpression(query).ParameterValuesByName.Count,
Is.EqualTo(linqParameterNumber ?? parameterNumber),
"Linq expression has different number of parameters");

var queryPlanCacheType = typeof(QueryPlanCache);
var cache = (SoftLimitMRUCache)
queryPlanCacheType
.GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic)
.GetValue(Sfi.QueryPlanCache);
cache.Clear();

await (query.ToListAsync(cancellationToken));
var results = await (query.ToListAsync(cancellationToken));

// In case of arrays linqParameterNumber and parameterNumber will be different
Assert.That(
GetLinqExpression(query).ParameterValuesByName.Count,
Is.EqualTo(linqParameterNumber ?? parameterNumber),
"Linq expression has different number of parameters");
if(countResults != null)
Assert.That(results.Count, Is.EqualTo(countResults), "Unexpected results count");

sqlAction?.Invoke(sqlSpy.GetWholeLog());

Expand Down
72 changes: 64 additions & 8 deletions src/NHibernate.Test/Linq/ParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ public void UsingEntityParameterForCollection()
1);
}

[Test]
public void UsingParameterForCollectionWithWhere()
{
var item = db.OrderLines.First();
AssertTotalParameters(
db.Orders.Where(o => o.OrderLines.Select(ol => ol.Id).Where(id => id == item.Id).Contains(item.Id)),
1);
}

[Test]
public void UsingProxyParameterForCollection()
{
Expand Down Expand Up @@ -968,29 +977,76 @@ public void DMLDeleteShouldHaveSameCacheKeys()
Assert.That(expression1.Key, Is.EqualTo(expression2.Key));
}

[Test(Description = "GH-2872")]
public void UsingListParameterWithWhere()
{
var ids = db.Orders.OrderBy(x => x.OrderId).Take(2).Select(o => o.OrderId).ToList();
AssertTotalParameters(
db.Orders.Where(o => ids.Where(i => i == ids[0]).Contains(o.OrderId)),
1,
countResults: 1);
}

[Test(Description = "GH-2276")]
public void UsingArrayParameterWithWhereAndSelect()
{
var ids = db.Orders.OrderBy(x => x.OrderId).Take(2).ToArray();
var orderLines = new[] {ids[0].OrderLines.First(), ids[1].OrderLines.First()};
AssertTotalParameters(
db.Orders.Where(o => ids.Where(i => i == ids[0]).Contains(o) && orderLines.Select(ol => ol.Order).Where(i => i.OrderId == ids[0].OrderId).Contains(o)),
2,
countResults: 1);
}

[Test]
public void UsingArrayParameterWithNotEvaluatableWhere()
{
var ids = db.Orders.OrderBy(x => x.OrderId).Take(2).Select(x => x.OrderId).ToArray();
//ids.Where(i => i == o.OrderId) is not supported part of query. So just check it throws some exception and not silently ignored
Assert.Throws<HibernateException>(() => db.Orders.Where(o => ids.Where(i => i == o.OrderId).Contains(o.OrderId)).ToList());
}

[Test]
public void UsingArrayMethodParameterWithTake()
{
using (var logSpy = new SqlLogSpy())
{
var results = db.Orders.Where(o => GetArrayParameters().Take(1).Contains(o)).ToList();
Assert.That(results.Count, Is.EqualTo(1));
Assert.That(logSpy.Appender.GetEvents().Length, Is.EqualTo(2));
}
}

private Order[] GetArrayParameters()
{
return db.Orders.OrderBy(x => x.OrderId).Take(3).ToArray();
}

private void AssertTotalParameters<T>(IQueryable<T> query, int parameterNumber, Action<string> sqlAction)
{
AssertTotalParameters(query, parameterNumber, null, sqlAction);
}

private void AssertTotalParameters<T>(IQueryable<T> query, int parameterNumber, int? linqParameterNumber = null, Action<string> sqlAction = null)
private void AssertTotalParameters<T>(IQueryable<T> query, int parameterNumber, int? linqParameterNumber = null, Action<string> sqlAction = null, int? countResults = null)
{
using (var sqlSpy = new SqlLogSpy())
{
// In case of arrays linqParameterNumber and parameterNumber will be different
Assert.That(
GetLinqExpression(query).ParameterValuesByName.Count,
Is.EqualTo(linqParameterNumber ?? parameterNumber),
"Linq expression has different number of parameters");

var queryPlanCacheType = typeof(QueryPlanCache);
var cache = (SoftLimitMRUCache)
queryPlanCacheType
.GetField("planCache", BindingFlags.Instance | BindingFlags.NonPublic)
.GetValue(Sfi.QueryPlanCache);
cache.Clear();

query.ToList();
var results = query.ToList();

// In case of arrays linqParameterNumber and parameterNumber will be different
Assert.That(
GetLinqExpression(query).ParameterValuesByName.Count,
Is.EqualTo(linqParameterNumber ?? parameterNumber),
"Linq expression has different number of parameters");
if(countResults != null)
Assert.That(results.Count, Is.EqualTo(countResults), "Unexpected results count");

sqlAction?.Invoke(sqlSpy.GetWholeLog());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
//

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using NHibernate.Collection;
using NHibernate.Engine;
using NHibernate.Linq.Functions;
Expand Down Expand Up @@ -61,6 +63,7 @@ public static Expression EvaluateIndependentSubtrees(
// _partialEvaluationInfo contains a list of the expressions that are safe to be evaluated.
private readonly PartialEvaluationInfo _partialEvaluationInfo;
private readonly PreTransformationParameters _preTransformationParameters;
private static readonly MethodInfo ContainsMethodInfo = ReflectHelper.FastGetMethodDefinition(Enumerable.Contains, default(IEnumerable<object>), default(object));

private NhPartialEvaluatingExpressionVisitor(
PartialEvaluationInfo partialEvaluationInfo,
Expand Down Expand Up @@ -155,6 +158,42 @@ private Expression EvaluateSubtree(Expression subtree)

#region NH additions

protected override Expression VisitMethodCall(MethodCallExpression node)
{
DetectEvaluatableExpressionOnCollectionContains(node);
return base.VisitMethodCall(node);
}

private void DetectEvaluatableExpressionOnCollectionContains(MethodCallExpression expression)
{
if (!expression.Method.IsGenericMethod || ContainsMethodInfo != expression.Method.GetGenericMethodDefinition())
return;
var argument = expression.Arguments[0];
if (argument.NodeType != ExpressionType.Call)
return;

if(TryGetCollectionParameter((MethodCallExpression)argument))
_partialEvaluationInfo.AddEvaluatableExpression(argument);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you're modifying external dependency, which you should not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it. Why I shouldn't touch it? It's just a wrapper around set of expressions that needs evaluation. I can store it in separate field (as I did initially) - but to minimize change noise I used already present class for storing evaluatable expressions.

Copy link
Member

@hazzik hazzik Jul 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it. Why I shouldn't touch it?

Because this is the information provided by another concerning party. We already let this logic spread across solution and so we're trying to catch bugs doing 10th bugfix release. Let's avoid doing ad-hoc patches of logic here and there.

Copy link
Member Author

@bahusoid bahusoid Jul 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well.. Ok... It can be stored separately. But I still think it's more proper to store it in _partialEvaluationInfo - external or not, it stores list of expressions that needs evaluation. So it's better to fill it also with our expressions - so all other external parties can handle it appropriately as evaluatable expression.

But before making any further changes - do you agree with this fix on principle?

}

private bool TryGetCollectionParameter(MethodCallExpression expression)
{
if (expression.Method.DeclaringType != typeof(Enumerable))
return IsCollectionParameter(expression);

var arg = expression.Arguments[0];

if (IsCollectionParameter(arg))
return true;

return arg.NodeType == ExpressionType.Call && TryGetCollectionParameter((MethodCallExpression) arg);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to do that recursively?

Copy link
Member Author

@bahusoid bahusoid Jul 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To support chain of LINQ commands. And to make sure it doesn't start with mapped association (somethin like
db.Orders.Where( x => x.OrderLines.Select(...).Where(...).Contains(...)) should not be evaluated)

}

private bool IsCollectionParameter(Expression expression)
{
return _partialEvaluationInfo.IsEvaluatableExpression(expression);
}

private bool ContainsVariable(Expression expression)
{
if (!(expression is UnaryExpression unaryExpression) ||
Expand Down