|
2 | 2 |
|
3 | 3 | using System; |
4 | 4 | using System.Collections; |
5 | | -using System.Collections.Generic; |
6 | | -using System.Diagnostics; |
7 | | -using System.Diagnostics.CodeAnalysis; |
8 | 5 | using System.Linq; |
9 | 6 | using System.Linq.Expressions; |
10 | 7 | using System.Text; |
|
13 | 10 |
|
14 | 11 | namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; |
15 | 12 |
|
16 | | -internal class AzureAISearchFilterTranslator |
17 | | -{ |
18 | | - private CollectionModel _model = null!; |
19 | | - private ParameterExpression _recordParameter = null!; |
| 13 | +#pragma warning disable MEVD9001 // Experimental: filter translation base types |
20 | 14 |
|
| 15 | +internal class AzureAISearchFilterTranslator : FilterTranslatorBase |
| 16 | +{ |
21 | 17 | private readonly StringBuilder _filter = new(); |
22 | 18 |
|
23 | 19 | private static readonly char[] s_searchInDefaultDelimiter = [' ', ',']; |
24 | 20 |
|
25 | 21 | internal string Translate(LambdaExpression lambdaExpression, CollectionModel model) |
26 | 22 | { |
27 | | - Debug.Assert(this._filter.Length == 0); |
28 | | - |
29 | | - this._model = model; |
30 | | - |
31 | | - Debug.Assert(lambdaExpression.Parameters.Count == 1); |
32 | | - this._recordParameter = lambdaExpression.Parameters[0]; |
33 | | - |
34 | | - var preprocessor = new FilterTranslationPreprocessor { SupportsParameterization = false }; |
35 | | - var preprocessedExpression = preprocessor.Preprocess(lambdaExpression.Body); |
| 23 | + var preprocessedExpression = this.PreprocessFilter(lambdaExpression, model, new FilterPreprocessingOptions()); |
36 | 24 |
|
37 | 25 | this.Translate(preprocessedExpression); |
38 | 26 |
|
@@ -161,52 +149,25 @@ private void TranslateMember(MemberExpression memberExpression) |
161 | 149 |
|
162 | 150 | private void TranslateMethodCall(MethodCallExpression methodCall) |
163 | 151 | { |
164 | | - switch (methodCall) |
| 152 | + // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") |
| 153 | + if (this.TryBindProperty(methodCall, out var property)) |
165 | 154 | { |
166 | | - // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") |
167 | | - case MethodCallExpression when this.TryBindProperty(methodCall, out var property): |
168 | | - // OData identifiers cannot be escaped; storage names are validated during model building. |
169 | | - this._filter.Append(property.StorageName); |
170 | | - return; |
171 | | - |
172 | | - // Enumerable.Contains() |
173 | | - case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains |
174 | | - when contains.Method.DeclaringType == typeof(Enumerable): |
175 | | - this.TranslateContains(source, item); |
176 | | - return; |
177 | | - |
178 | | - // List.Contains() |
179 | | - case |
180 | | - { |
181 | | - Method: |
182 | | - { |
183 | | - Name: nameof(Enumerable.Contains), |
184 | | - DeclaringType: { IsGenericType: true } declaringType |
185 | | - }, |
186 | | - Object: Expression source, |
187 | | - Arguments: [var item] |
188 | | - } when declaringType.GetGenericTypeDefinition() == typeof(List<>): |
189 | | - this.TranslateContains(source, item); |
190 | | - return; |
| 155 | + // OData identifiers cannot be escaped; storage names are validated during model building. |
| 156 | + this._filter.Append(property.StorageName); |
| 157 | + return; |
| 158 | + } |
191 | 159 |
|
192 | | - // C# 14 made changes to overload resolution to prefer Span-based overloads when those exist ("first-class spans"); |
193 | | - // this makes MemoryExtensions.Contains() be resolved rather than Enumerable.Contains() (see above). |
194 | | - // MemoryExtensions.Contains() also accepts a Span argument for the source, adding an implicit cast we need to remove. |
195 | | - // See https://github.com/dotnet/runtime/issues/109757 for more context. |
196 | | - // Note that MemoryExtensions.Contains has an optional 3rd ComparisonType parameter; we only match when |
197 | | - // it's null. |
198 | | - case { Method.Name: nameof(MemoryExtensions.Contains), Arguments: [var spanArg, var item, ..] } contains |
199 | | - when contains.Method.DeclaringType == typeof(MemoryExtensions) |
200 | | - && (contains.Arguments.Count is 2 |
201 | | - || (contains.Arguments.Count is 3 && contains.Arguments[2] is ConstantExpression { Value: null })) |
202 | | - && TryUnwrapSpanImplicitCast(spanArg, out var source): |
| 160 | + switch (methodCall) |
| 161 | + { |
| 162 | + // Enumerable.Contains(), List.Contains(), MemoryExtensions.Contains() |
| 163 | + case var _ when TryMatchContains(methodCall, out var source, out var item): |
203 | 164 | this.TranslateContains(source, item); |
204 | 165 | return; |
205 | 166 |
|
206 | 167 | // Enumerable.Any() with a Contains predicate (r => r.Strings.Any(s => array.Contains(s))) |
207 | | - case { Method.Name: nameof(Enumerable.Any), Arguments: [var source, LambdaExpression lambda] } any |
| 168 | + case { Method.Name: nameof(Enumerable.Any), Arguments: [var anySource, LambdaExpression lambda] } any |
208 | 169 | when any.Method.DeclaringType == typeof(Enumerable): |
209 | | - this.TranslateAny(source, lambda); |
| 170 | + this.TranslateAny(anySource, lambda); |
210 | 171 | return; |
211 | 172 |
|
212 | 173 | default: |
@@ -254,35 +215,12 @@ private void TranslateAny(Expression source, LambdaExpression lambda) |
254 | 215 | // We only support the pattern: r.ArrayField.Any(x => values.Contains(x)) |
255 | 216 | // Translates to: Field/any(t: search.in(t, 'value1, value2, value3')) |
256 | 217 | if (!this.TryBindProperty(source, out var property) |
257 | | - || lambda.Body is not MethodCallExpression { Method.Name: "Contains" } containsCall) |
| 218 | + || lambda.Body is not MethodCallExpression { Method.Name: "Contains" } containsCall |
| 219 | + || !TryMatchContains(containsCall, out var valuesExpression, out var itemExpression)) |
258 | 220 | { |
259 | 221 | throw new NotSupportedException("Unsupported method call: Enumerable.Any"); |
260 | 222 | } |
261 | 223 |
|
262 | | - // Match Enumerable.Contains(source, item), List<T>.Contains(item), or MemoryExtensions.Contains |
263 | | - var (valuesExpression, itemExpression) = containsCall switch |
264 | | - { |
265 | | - // Enumerable.Contains(source, item) |
266 | | - { Method.Name: nameof(Enumerable.Contains), Arguments: [var src, var item] } |
267 | | - when containsCall.Method.DeclaringType == typeof(Enumerable) |
268 | | - => (src, item), |
269 | | - |
270 | | - // List<T>.Contains(item) |
271 | | - { Method: { Name: nameof(Enumerable.Contains), DeclaringType: { IsGenericType: true } declaringType }, Object: Expression src, Arguments: [var item] } |
272 | | - when declaringType.GetGenericTypeDefinition() == typeof(List<>) |
273 | | - => (src, item), |
274 | | - |
275 | | - // MemoryExtensions.Contains (C# 14 first-class spans) |
276 | | - { Method.Name: nameof(MemoryExtensions.Contains), Arguments: [var spanArg, var item, ..] } |
277 | | - when containsCall.Method.DeclaringType == typeof(MemoryExtensions) |
278 | | - && (containsCall.Arguments.Count is 2 |
279 | | - || (containsCall.Arguments.Count is 3 && containsCall.Arguments[2] is ConstantExpression { Value: null })) |
280 | | - && TryUnwrapSpanImplicitCast(spanArg, out var unwrappedSource) |
281 | | - => (unwrappedSource, item), |
282 | | - |
283 | | - _ => throw new NotSupportedException("Unsupported method call: Enumerable.Any"), |
284 | | - }; |
285 | | - |
286 | 224 | // Verify that the item is the lambda parameter |
287 | 225 | if (itemExpression != lambda.Parameters[0]) |
288 | 226 | { |
@@ -390,65 +328,6 @@ private void GenerateSearchInValues(IEnumerable values) |
390 | 328 | return result; |
391 | 329 | } |
392 | 330 |
|
393 | | - private static bool TryUnwrapSpanImplicitCast(Expression expression, [NotNullWhen(true)] out Expression? result) |
394 | | - { |
395 | | - // Different versions of the compiler seem to generate slightly different expression tree representations for this |
396 | | - // implicit cast: |
397 | | - var (unwrapped, castDeclaringType) = expression switch |
398 | | - { |
399 | | - UnaryExpression |
400 | | - { |
401 | | - NodeType: ExpressionType.Convert, |
402 | | - Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType }, |
403 | | - Operand: var operand |
404 | | - } => (operand, implicitCastDeclaringType), |
405 | | - |
406 | | - MethodCallExpression |
407 | | - { |
408 | | - Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType }, |
409 | | - Arguments: [var firstArgument] |
410 | | - } => (firstArgument, implicitCastDeclaringType), |
411 | | - |
412 | | - // After the preprocessor runs, the Convert node may have Method: null because the visitor |
413 | | - // recreates the UnaryExpression with a different operand type (QueryParameterExpression). |
414 | | - // Handle this case by checking if the target type is Span<T> or ReadOnlySpan<T>. |
415 | | - UnaryExpression |
416 | | - { |
417 | | - NodeType: ExpressionType.Convert, |
418 | | - Method: null, |
419 | | - Type: { IsGenericType: true } targetType, |
420 | | - Operand: var operand |
421 | | - } when targetType.GetGenericTypeDefinition() is var gtd |
422 | | - && (gtd == typeof(Span<>) || gtd == typeof(ReadOnlySpan<>)) |
423 | | - => (operand, targetType), |
424 | | - |
425 | | - _ => (null, null) |
426 | | - }; |
427 | | - |
428 | | - // For the dynamic case, there's a Convert node representing an up-cast to object[]; unwrap that too. |
429 | | - // Also handle cases where the preprocessor adds a Convert node back to the array type. |
430 | | - while (unwrapped is UnaryExpression |
431 | | - { |
432 | | - NodeType: ExpressionType.Convert, |
433 | | - Method: null, |
434 | | - Operand: var innerOperand |
435 | | - }) |
436 | | - { |
437 | | - unwrapped = innerOperand; |
438 | | - } |
439 | | - |
440 | | - if (unwrapped is not null |
441 | | - && castDeclaringType?.GetGenericTypeDefinition() is var genericTypeDefinition |
442 | | - && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) |
443 | | - { |
444 | | - result = unwrapped; |
445 | | - return true; |
446 | | - } |
447 | | - |
448 | | - result = null; |
449 | | - return false; |
450 | | - } |
451 | | - |
452 | 331 | private void TranslateUnary(UnaryExpression unary) |
453 | 332 | { |
454 | 333 | switch (unary.NodeType) |
@@ -485,57 +364,4 @@ private void TranslateUnary(UnaryExpression unary) |
485 | 364 | throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); |
486 | 365 | } |
487 | 366 | } |
488 | | - |
489 | | - private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out PropertyModel? property) |
490 | | - { |
491 | | - var unwrappedExpression = expression; |
492 | | - while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) |
493 | | - { |
494 | | - unwrappedExpression = convert.Operand; |
495 | | - } |
496 | | - |
497 | | - var modelName = unwrappedExpression switch |
498 | | - { |
499 | | - // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8) |
500 | | - MemberExpression memberExpression when memberExpression.Expression == this._recordParameter |
501 | | - => memberExpression.Member.Name, |
502 | | - |
503 | | - // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8) |
504 | | - MethodCallExpression |
505 | | - { |
506 | | - Method: { Name: "get_Item", DeclaringType: var declaringType }, |
507 | | - Arguments: [ConstantExpression { Value: string keyName }] |
508 | | - } methodCall when methodCall.Object == this._recordParameter && declaringType == typeof(Dictionary<string, object?>) |
509 | | - => keyName, |
510 | | - |
511 | | - _ => null |
512 | | - }; |
513 | | - |
514 | | - if (modelName is null) |
515 | | - { |
516 | | - property = null; |
517 | | - return false; |
518 | | - } |
519 | | - |
520 | | - if (!this._model.PropertyMap.TryGetValue(modelName, out property)) |
521 | | - { |
522 | | - throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name."); |
523 | | - } |
524 | | - |
525 | | - // Now that we have the property, go over all wrapping Convert nodes again to ensure that they're compatible with the property type |
526 | | - var unwrappedPropertyType = Nullable.GetUnderlyingType(property.Type) ?? property.Type; |
527 | | - unwrappedExpression = expression; |
528 | | - while (unwrappedExpression is UnaryExpression { NodeType: ExpressionType.Convert } convert) |
529 | | - { |
530 | | - var convertType = Nullable.GetUnderlyingType(convert.Type) ?? convert.Type; |
531 | | - if (convertType != unwrappedPropertyType && convertType != typeof(object)) |
532 | | - { |
533 | | - throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{convert.Type.Name}', but its configured type is '{property.Type.Name}'."); |
534 | | - } |
535 | | - |
536 | | - unwrappedExpression = convert.Operand; |
537 | | - } |
538 | | - |
539 | | - return true; |
540 | | - } |
541 | 367 | } |
0 commit comments