1236 lines
49 KiB
Swift

//===----------------------------------------------------------------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2022-2024 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
import SwiftSyntax
import SwiftSyntaxMacros
internal import SwiftDiagnostics
internal import SwiftSyntaxBuilder
// Every function that a Predicate/Expression body may call. A call to any function not listed here is diagnosed as an error.
// Each entry pins down the function name, the number of arguments and their labels. Argument *types* are not checked here;
// that is left to type checking of the expanded code.
// An argument declared as `.closure` may be written as a key path (which gets expanded into a closure), and when such an
// argument comes last it may instead be supplied as a trailing closure.
private let _knownSupportedFunctions: Set<FunctionStructure> = [
    // Sequence queries taking a closure
    FunctionStructure("contains", arguments: [.closure(labeled: "where")]),
    FunctionStructure("allSatisfy", arguments: [.closure(labeled: nil)]),
    FunctionStructure("flatMap", arguments: [.closure(labeled: nil)]),
    FunctionStructure("filter", arguments: [.closure(labeled: nil)]),
    // Sequence queries taking values
    FunctionStructure("contains", arguments: [.unlabeled]),
    FunctionStructure("starts", arguments: ["with"]),
    FunctionStructure("min", arguments: []),
    FunctionStructure("max", arguments: []),
    // Subscripts (plain and dictionary-with-default)
    FunctionStructure("subscript", arguments: [.unlabeled]),
    FunctionStructure("subscript", arguments: [.unlabeled, "default"]),
    // String comparison helpers
    FunctionStructure("localizedStandardContains", arguments: [.unlabeled]),
    FunctionStructure("localizedCompare", arguments: [.unlabeled]),
    FunctionStructure("caseInsensitiveCompare", arguments: [.unlabeled]),
]
/// The full set of supported functions for this build configuration.
///
/// Framework builds additionally support `evaluate(_:...)`, whose arguments form a pack of any length.
/// This was previously a computed property that copied the base set (and inserted into it) on every
/// lookup; it is now computed once, on first use, like any other global `let`.
private let knownSupportedFunctions: Set<FunctionStructure> = {
    #if FOUNDATION_FRAMEWORK
    var result = _knownSupportedFunctions
    result.insert(FunctionStructure("evaluate", arguments: [.pack(labeled: nil)]))
    return result
    #else
    return _knownSupportedFunctions
    #endif
}()
// Unsupported functions that people commonly reach for, each mapped to the supported function that is offered as a
// fix-it replacement when the unsupported one is diagnosed.
private let supportedFunctionSuggestions: [FunctionStructure : FunctionStructure] = {
    let localizedCompare = FunctionStructure("localizedCompare", arguments: [.unlabeled])
    return [
        FunctionStructure("hasPrefix", arguments: [.unlabeled]):
            FunctionStructure("starts", arguments: ["with"]),
        FunctionStructure("localizedCaseInsensitiveContains", arguments: [.unlabeled]):
            FunctionStructure("localizedStandardContains", arguments: [.unlabeled]),
        FunctionStructure("localizedCaseInsensitiveCompare", arguments: [.unlabeled]):
            localizedCompare,
        FunctionStructure("localizedStandardCompare", arguments: [.unlabeled]):
            localizedCompare,
    ]
}()
extension Array where Element == FunctionStructure.Argument {
    /// Compares two argument lists by label, letting a pack in exactly one of them stand for any number of arguments.
    ///
    /// When neither list (or both lists) contain a pack, the lists are compared element-wise; argument `==` compares
    /// labels only. When exactly one list contains a pack, the other list must start with the arguments before the pack,
    /// end with the arguments after it, and be long enough that those two runs do not overlap.
    fileprivate func argumentsEqual(_ other: Self) -> Bool {
        let currentPackIndex = self.firstIndex { $0.kind == .pack }
        let otherPackIndex = other.firstIndex { $0.kind == .pack }
        let full: [FunctionStructure.Argument]
        let prefix: ArraySlice<FunctionStructure.Argument>
        let suffix: ArraySlice<FunctionStructure.Argument>
        switch (currentPackIndex, otherPackIndex) {
        // If neither contains a pack or both contain a pack, just compare arguments as-is
        case (nil, nil), (.some(_), .some(_)):
            return self == other
        // If one of them contains a pack, compare the prefix and suffix to allow the pack to lazily consume multiple arguments
        case (let .some(idx), nil):
            full = other
            prefix = self[..<idx]
            suffix = self[self.index(after: idx)...]
        case (nil, let .some(idx)):
            full = self
            prefix = other[..<idx]
            suffix = other[other.index(after: idx)...]
        }
        // Without this check the prefix and suffix comparisons below could overlap, so a list shorter than the
        // pack's fixed neighbours would still be accepted (e.g. [a] against [a, pack, a]).
        guard full.count >= prefix.count + suffix.count else {
            return false
        }
        return full.starts(with: prefix) && full.reversed().starts(with: suffix.reversed())
    }

    /// Rewrites the first pack in this (known) argument list so the list has exactly `otherCount` elements.
    ///
    /// The pack's own slot keeps its label and becomes a standard argument, and any extra slots after it are unlabeled.
    /// A pack that absorbs no arguments (`otherCount == count - 1`) is removed, so that the remaining positions line up
    /// with the call's arguments. Lists without a pack, or that cannot be made to match `otherCount`, are returned unchanged.
    fileprivate func expandingPackToMatchCount(_ otherCount: Int) -> Self {
        guard let packIdx = self.firstIndex(where: { $0.kind == .pack }) else {
            return self
        }
        let countDifference = otherCount - self.count
        var copy = self
        if countDifference == -1 {
            // Empty pack: previously it stayed in place and shifted every following argument by one position
            copy.remove(at: packIdx)
            return copy
        }
        guard countDifference >= 0 else {
            return self
        }
        copy[packIdx] = .init(label: copy[packIdx].label, kind: .standard)
        if countDifference > 0 {
            copy.insert(contentsOf: Array(repeating: .unlabeled, count: countDifference), at: packIdx + 1)
        }
        return copy
    }
}
/// The shape of a function call: its name, its arguments (label and kind), and whether it uses a trailing closure.
///
/// Used both for the functions a predicate is allowed to call and for the calls actually written in the predicate
/// body; `matches(_:)` decides whether the two line up.
private struct FunctionStructure: Hashable {
    /// One argument position of a function call.
    struct Argument : Hashable, ExpressibleByStringLiteral {
        /// How an argument position is treated when matching and expanding calls.
        enum Kind : Hashable {
            /// An ordinary argument.
            case standard
            /// A closure argument. A key path passed here is expanded into a closure, and when it is the last
            /// argument it may instead be written as a trailing closure.
            case closure
            /// A pack that may absorb any number of consecutive arguments.
            case pack
        }
        let label: String?
        let kind: Kind
        /// A string literal produces a standard argument with that label, e.g. `"with"`.
        init(stringLiteral: String) {
            label = stringLiteral
            kind = .standard
        }
        init(label: String?, kind: Kind) {
            self.label = label
            self.kind = kind
        }
        static func closure(labeled label: String?) -> Self {
            Self(label: label, kind: .closure)
        }
        static var unlabeled: Self {
            Self(label: nil, kind: .standard)
        }
        static func pack(labeled label: String?) -> Self {
            Self(label: label, kind: .pack)
        }
        /// Arguments are equal when their labels match. `kind` is deliberately ignored, because the kind of a call-site
        /// argument is only guessed from its expression and must still match a known signature's argument.
        static func ==(lhs: Self, rhs: Self) -> Bool {
            lhs.label == rhs.label
        }
        /// Hashes exactly what `==` compares. The synthesized implementation also hashed `kind`, so two arguments equal
        /// under `==` could hash differently, violating `Hashable` (and transitively `FunctionStructure`'s conformance).
        func hash(into hasher: inout Hasher) {
            hasher.combine(label)
        }
    }
    let name: String
    let arguments: [Argument]
    let hasTrailingClosure: Bool
    /// Whether a call may supply its final closure as a trailing closure.
    var supportsTrailingClosure: Bool {
        hasTrailingClosure || arguments.last?.kind == .closure
    }
    /// A human-readable signature such as `starts(with:)`, used in diagnostics.
    var signature: String {
        let args = arguments.map { ($0.label ?? "_") + ":" }.joined()
        return "\(name)(\(args))"
    }
    init(_ name: String, arguments: [Argument], trailingClosure: Bool = false) {
        self.name = name
        self.arguments = arguments
        self.hasTrailingClosure = trailingClosure
    }
    /// Whether `other` (typically a call written in a predicate) is a call of this function.
    ///
    /// A trailing closure on one side is allowed to match a final `.closure` argument on the other side.
    func matches(_ other: FunctionStructure) -> Bool {
        guard self.name == other.name else { return false }
        switch (self.hasTrailingClosure, other.hasTrailingClosure) {
        case (true, true), (false, false):
            return self.arguments.argumentsEqual(other.arguments)
        case (true, false):
            guard let otherLast = other.arguments.last else { return false }
            return self.arguments.argumentsEqual(other.arguments.dropLast()) && otherLast.kind == .closure
        case (false, true):
            guard let last = self.arguments.last else { return false }
            return self.arguments.dropLast().argumentsEqual(other.arguments) && last.kind == .closure
        }
    }
    /// Builds fix-it edits that rewrite the call `source` into a call of this function.
    ///
    /// The called member is renamed to `name`, and each source argument is relabeled to the corresponding label in
    /// `arguments`. Returns nil when the call cannot be rewritten mechanically:
    /// - the source uses a trailing closure this function does not accept;
    /// - the total argument counts differ;
    /// - the called expression is not a member access;
    /// - this function takes a trailing closure but the source's final argument is not a closure literal.
    func fixItChanges(transformingFrom source: FunctionCallExprSyntax) -> [FixIt.Change]? {
        let sourceHasTrailingClosure = source.trailingClosure != nil
        if sourceHasTrailingClosure && !supportsTrailingClosure {
            return nil
        }
        let sourceArgumentTotalCount = source.arguments.count + (sourceHasTrailingClosure ? 1 : 0)
        let argumentTotalCount = self.arguments.count + (hasTrailingClosure ? 1 : 0)
        guard argumentTotalCount == sourceArgumentTotalCount,
              let calledExpr = source.calledExpression.as(MemberAccessExprSyntax.self) else {
            return nil
        }
        // If this function takes a trailing closure but the source passed the closure as its final argument, that
        // argument moves into trailing-closure position. It must really be a closure literal; previously any other
        // expression was silently dropped from the rewritten call.
        var movedTrailingClosure: ClosureExprSyntax? = nil
        if hasTrailingClosure && !sourceHasTrailingClosure {
            guard let closure = source.arguments.last?.expression.as(ClosureExprSyntax.self) else {
                return nil
            }
            movedTrailingClosure = closure
        }
        // zip stops at the shorter list, which leaves out the moved closure argument (if any)
        var rewrittenArguments: [LabeledExprSyntax] = []
        for (sourceArgument, targetArgument) in zip(source.arguments, arguments) {
            if let newLabel = targetArgument.label {
                rewrittenArguments.append(
                    sourceArgument
                        .with(\.label, .identifier(newLabel))
                        .with(\.colon, .colonToken())
                        .with(\.expression, sourceArgument.expression.with(\.leadingTrivia, [.spaces(1)]))
                )
            } else {
                rewrittenArguments.append(
                    sourceArgument
                        .with(\.label, nil)
                        .with(\.colon, nil)
                        .with(\.trailingTrivia, [])
                        .with(\.expression, sourceArgument.expression.with(\.leadingTrivia, []))
                )
            }
        }
        if movedTrailingClosure != nil, let lastIndex = rewrittenArguments.indices.last {
            // The argument that used to precede the moved closure must not keep its trailing comma
            rewrittenArguments[lastIndex] = rewrittenArguments[lastIndex].with(\.trailingComma, nil)
        }
        var newFunctionCall = source
        newFunctionCall.calledExpression = ExprSyntax(calledExpr.with(\.declName, DeclReferenceExprSyntax(baseName: .identifier(name))))
        newFunctionCall.arguments = LabeledExprListSyntax(rewrittenArguments)
        newFunctionCall.leadingTrivia = []
        newFunctionCall.trailingTrivia = []
        if let movedTrailingClosure {
            newFunctionCall.trailingClosure = movedTrailingClosure
        }
        return [.replace(oldNode: Syntax(source), newNode: Syntax(newFunctionCall))]
    }
}
/// Returns the supported function whose signature `structure` matches, or nil when the call is not supported.
private func _knownMatchingFunction(_ structure: FunctionStructure) -> FunctionStructure? {
    return knownSupportedFunctions.first(where: { candidate in
        candidate.matches(structure)
    })
}
/// Returns the supported replacement for an unsupported call, if `structure` matches one of the known suggestions.
private func _suggestionForUnknownFunction(_ structure: FunctionStructure) -> FunctionStructure? {
    let entry = supportedFunctionSuggestions.first { unsupported, _ in
        unsupported.matches(structure)
    }
    return entry?.value
}
/// Walks a syntax tree and records whether it references a closure shorthand argument such as `$0` or `$1`.
private class ShorthandArgumentIdentifierDetector: SyntaxVisitor {
    /// Set once a `$` identifier whose remaining characters are all numeric has been visited.
    var found = false
    override func visit(_ node: DeclReferenceExprSyntax) -> SyntaxVisitorContinueKind {
        guard case let .dollarIdentifier(name) = node.baseName.tokenKind,
              name.dropFirst().allSatisfy(\.isNumber) else {
            return .visitChildren
        }
        found = true
        return .skipChildren
    }
}
extension SyntaxProtocol {
    /// Whether this node or any descendant references a shorthand closure argument (`$0`, `$1`, ...).
    var containsShorthandArgumentIdentifiers: Bool {
        let detector = ShorthandArgumentIdentifierDetector(viewMode: .all)
        detector.walk(self)
        return detector.found
    }
}
/// A syntax rewriter whose outcome can be inspected after it has run; see `rewrite(with:)`.
private protocol PredicateSyntaxRewriter : SyntaxRewriter {
    /// Whether the rewrite succeeded. When false, `diagnostics` explains why.
    var success: Bool { get }
    /// Whether the rewritten output may be discarded in favour of the original syntax.
    var ignorable: Bool { get }
    /// Diagnostics produced while rewriting.
    var diagnostics: [Diagnostic] { get }
}
/// Defaults: a rewriter succeeds, its output is kept, and it reports nothing, unless the conforming type says otherwise.
extension PredicateSyntaxRewriter {
    var success: Bool {
        return true
    }
    var ignorable: Bool {
        return false
    }
    var diagnostics: [Diagnostic] {
        return []
    }
}
extension SyntaxProtocol {
    /// Runs `rewriter` over this node.
    ///
    /// - Throws: `DiagnosticsError` carrying the rewriter's diagnostics when it reports failure.
    /// - Returns: The original node when the rewriter reports its output as ignorable; otherwise the rewritten node.
    fileprivate func rewrite(with rewriter: some PredicateSyntaxRewriter) throws -> Syntax {
        let rewritten = rewriter.rewrite(self)
        if !rewriter.success {
            throw DiagnosticsError(diagnostics: rewriter.diagnostics)
        }
        return rewriter.ignorable ? Syntax(self) : rewritten
    }
}
// Rewrites optional chaining (`a?.b`) into explicit `flatMap` calls (`a.flatMap { $0.b }`).
// A "chaining tree" starts at the outermost FunctionCallExpr/MemberAccessExpr visited while no tree is active. Any `?`
// found inside that tree is replaced by `$0`, and the tree's top node is then wrapped as `<input>.flatMap { <tree> }`.
private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
// When false, the node being visited may not start a new chaining tree (cleared for calls whose arguments use `$N`).
var withinValidChainingTreeStart = true
// True while visiting beneath the node that started the current chaining tree.
var withinChainingTree = false
// Operand of the `?` captured in the current tree; it becomes the receiver of the generated flatMap.
var optionalInput: ExprSyntax? = nil
// Stays true until a flatMap is actually generated; rewrite(with:) then returns the original syntax unchanged.
var ignorable = true
// Marks the start of a chaining tree if none is active and starting one is allowed; returns whether this node is the tree's top.
private func _prePossibleTopOfTree() -> Bool {
if !withinChainingTree && withinValidChainingTreeStart {
withinChainingTree = true
return true
}
return false
}
// Ends the current tree. If a `?` operand was captured, returns `<visited input>.flatMap { node }`; otherwise `node` unchanged.
private func _postTopOfTree(_ node: ExprSyntax) -> ExprSyntax {
assert(withinChainingTree)
withinChainingTree = false
if let input = optionalInput {
optionalInput = nil
ignorable = false
// The captured input is itself rewritten, so an inner chain (`a?.b` in `a?.b?.c`) yields a further nested flatMap
let visited = self.visit(input)
let closure = ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: CodeBlockItemSyntax.Item(node))])
let functionMember = MemberAccessExprSyntax(base: visited, name: "flatMap")
let functionCall = FunctionCallExprSyntax(calledExpression: functionMember, arguments: [], trailingClosure: closure)
return ExprSyntax(functionCall)
}
return node
}
override func visit(_ node: ClosureExprSyntax) -> ExprSyntax {
guard withinChainingTree else {
// If we're not already in a chaining tree, just keep progressing with our current rewriter
return super.visit(node)
}
// We're in the middle of a potential tree, so rewrite the closure with a fresh state
// This ensures potential chaining in the closure isn't rewritten outside of the closure
let nestedRewriter = OptionalChainRewriter()
guard let rewritten = (try? node.rewrite(with: nestedRewriter))?.as(ExprSyntax.self) else {
// If rewriting the closure failed, just leave the closure as-is
return ExprSyntax(node)
}
// Once the nested rewriter has produced a flatMap, this rewriter's output can no longer be ignored
if ignorable {
ignorable = nestedRewriter.ignorable
}
return rewritten
}
override func visit(_ node: FunctionCallExprSyntax) -> ExprSyntax {
// A call whose arguments use `$0`/`$1` must not become the body of a generated `{ $0... }` closure, which would
// rebind those names, so such a call may not start a tree. The previous setting is restored on exit.
let priorValidTreeStart = withinValidChainingTreeStart
defer { withinValidChainingTreeStart = priorValidTreeStart }
if node.arguments.containsShorthandArgumentIdentifiers {
withinValidChainingTreeStart = false
}
let topOfTree = _prePossibleTopOfTree()
let visited = super.visit(node)
if topOfTree {
return _postTopOfTree(visited)
} else {
return visited
}
}
override func visit(_ node: MemberAccessExprSyntax) -> ExprSyntax {
let topOfTree = _prePossibleTopOfTree()
let visited = super.visit(node)
if topOfTree {
return _postTopOfTree(visited)
} else {
return visited
}
}
override func visit(_ node: OptionalChainingExprSyntax) -> ExprSyntax {
guard withinChainingTree else {
return super.visit(node)
}
// Capture the optional input, and replace it in the output expression with a "$0"
// The operand is not visited here; it is rewritten later, in _postTopOfTree.
// NOTE(review): the assignment is unconditional, so a second `?` in a sibling position of the same tree
// (e.g. `a?.f(b?.c)`) appears to overwrite the first captured input. Confirm whether that can be reached.
optionalInput = node.expression
return .init(DeclReferenceExprSyntax(baseName: .dollarIdentifier("$0")))
}
}
extension CodeBlockItemListSyntax.Element.Item {
    /// The expression held by this code block item: either a bare expression, or the expression wrapped by an
    /// expression statement. Nil for declarations and all other statements.
    fileprivate var _expression: ExprSyntax? {
        if case .expr(let expression) = self {
            return expression
        }
        if case .stmt(let statement) = self {
            return statement.as(ExpressionStmtSyntax.self)?.expression
        }
        return nil
    }
}
extension ConditionElementListSyntax {
    /// The conditions as optional bindings, when every condition is one (e.g. `if let a, let b = c`).
    ///
    /// Returns nil when any condition is not an optional binding, and also when the list is empty. An empty list used
    /// to yield `[]`, which callers treated as "all optional bindings" and expanded into a meaningless nil-coalesce.
    fileprivate var optionalBindings: [OptionalBindingConditionSyntax]? {
        guard !isEmpty else {
            return nil
        }
        var result = [OptionalBindingConditionSyntax]()
        result.reserveCapacity(count)
        for element in self {
            guard case let .optionalBinding(binding) = element.condition else {
                return nil
            }
            result.append(binding)
        }
        return result
    }
}
extension ClosureParameterListSyntax {
    /// A copy of the parameter list in which every explicitly typed parameter `x: T` becomes
    /// `x: PredicateExpressions.Variable<T>`. Untyped parameters are left as they are.
    fileprivate var withVariableWrappedTypes: Self {
        let wrappedParameters = self.map { parameter -> ClosureParameterSyntax in
            guard let type = parameter.type else {
                return parameter
            }
            return parameter.with(\.type, "PredicateExpressions.Variable<\(type)>")
        }
        return Self(wrappedParameters)
    }
}
extension KeyPathExprSyntax {
    private enum KeyPathDirectExpressionRewritingError : Error {
        case unknownKeypathComponentType
    }
    /// Spells this key path as a direct expression applied to `base`.
    ///
    /// Property components become member accesses, `!` becomes a force unwrap, `?` becomes optional chaining, and
    /// subscript components become subscript calls. For example, `\.a?.b` on `$0` yields `$0.a?.b`. In framework
    /// builds, any other component kind makes this return nil.
    fileprivate func asDirectExpression(on base: some ExprSyntaxProtocol) -> ExprSyntax? {
        var expression = ExprSyntax(base)
        for component in components.map(\.component) {
            switch component {
            case .property(let property):
                expression = ExprSyntax(MemberAccessExprSyntax(base: expression, declName: property.declName))
            case .optional(let optionalComponent) where optionalComponent.questionOrExclamationMark.tokenKind == .exclamationMark:
                expression = ExprSyntax(ForceUnwrapExprSyntax(expression: expression))
            case .optional:
                expression = ExprSyntax(OptionalChainingExprSyntax(expression: expression))
            case .subscript(let subscriptComponent):
                expression = ExprSyntax(SubscriptCallExprSyntax(calledExpression: expression, arguments: subscriptComponent.arguments))
            #if FOUNDATION_FRAMEWORK
            default:
                return nil
            #endif
            }
        }
        return expression
    }
}
private class PredicateQueryRewriter: SyntaxRewriter, PredicateSyntaxRewriter {
/// One level of indentation in the generated expansion.
private let indentWidth: Trivia = .spaces(4)
/// Current nesting depth of the expansion; raised while nested arguments and code blocks are generated.
private var indentLevel = 0
/// Leading trivia for the current depth: `indentWidth` repeated `indentLevel` times.
private var indent: Trivia {
    // Derived from `indentWidth` instead of repeating the `.spaces(4)` literal, so the two widths cannot drift apart
    Trivia(pieces: Array(repeating: indentWidth.pieces, count: indentLevel).flatMap { $0 })
}
/// Cleared while `_processFunction` emits a call's receiver, and restored afterwards.
var validOptionalChainingTree = true
/// Every diagnostic recorded during the rewrite. Any entry means the rewrite failed.
var diagnostics: [Diagnostic] = []
/// True as long as no diagnostic has been recorded.
var success: Bool {
    return diagnostics.isEmpty
}
/// The expansion being performed; its `keyword`/`capitalizedKeyword` are interpolated into diagnostic messages.
let kind: ExpansionKind
init(kind expansionKind: ExpansionKind) {
    kind = expansionKind
}
/// Records a diagnostic anchored at `node`, optionally with fix-its. Any recorded diagnostic marks the rewrite as failed.
private func diagnose(node: SyntaxProtocol, message: PredicateExpansionDiagnostic, fixIts: [FixIt] = []) {
    let diagnostic = Diagnostic(node: Syntax(node), message: message, fixIts: fixIts)
    diagnostics.append(diagnostic)
}
// Builds one argument of a generated `PredicateExpressions.build_*` call.
// - label: argument label to emit, or nil for an unlabeled argument.
// - expression: the argument value; rewritten by this rewriter when `shouldVisit` is true.
// - shouldVisit: when true, an expression the rewriter leaves unchanged is wrapped in `PredicateExpressions.build_Arg(...)`.
// - shouldIndent: whether to generate the argument one indentation level deeper than the current level.
// Returns a LabeledExprSyntax with no trailing comma; callers add commas/trivia as needed.
private func makeArgument(label: String?, _ expression: ExprSyntax, shouldVisit: Bool = true, shouldIndent: Bool = true) -> LabeledExprSyntax {
if shouldIndent {
indentLevel += 1
}
// Undo the temporary indentation increase on every exit path
defer {
if shouldIndent {
indentLevel -= 1
}
}
// The label (when present) starts the argument's line, so it carries the indentation
let labelSyntax = label.map {
TokenSyntax(.identifier($0), presence: .present)
}?.with(\.leadingTrivia, indent)
let colonSyntax = label.map { _ in
TokenSyntax(.colon, presence: .present)
}
var argument = shouldVisit ? visit(expression) : expression
// Identity comparison: an unchanged node means nothing was rewritten (e.g. a literal or a captured value)
if shouldVisit && argument == expression {
argument = "PredicateExpressions.build_Arg(\(expression.with(\.leadingTrivia, []).with(\.trailingTrivia, [])))"
}
// Unlabeled arguments are indented themselves; labeled ones get a single space after the colon
argument = argument.with(\.leadingTrivia, label == nil ? indent : .space)
return .init(label: labelSyntax,
colon: colonSyntax,
expression: argument,
trailingComma: nil)
}
/// Rewrites `!x` into `build_Negation` and `-x` into `build_UnaryMinus`. Any other prefix operator is diagnosed.
override func visit(_ node: PrefixOperatorExprSyntax) -> ExprSyntax {
    let builderName: String
    switch node.operator.text {
    case "!":
        builderName = "build_Negation"
    case "-":
        builderName = "build_UnaryMinus"
    default:
        diagnose(node: node.operator, message: "The '\(node.operator.text)' operator is not supported in this \(kind.keyword)")
        return ExprSyntax(node)
    }
    return """
        \(raw: indent)PredicateExpressions.\(raw: builderName)(
        \(makeArgument(label: nil, node.expression))
        \(raw: indent))
        """
}
/// Rewrites a binary operator expression into the matching `PredicateExpressions.build_*` call.
///
/// Range operators (`...`, `..<`) label their operands `lower`/`upper`; every other operator uses `lhs`/`rhs`.
/// Comparison and arithmetic operators share one builder and pass the concrete operation as `op:`.
/// Operators outside the table, and operator expressions that are not plain binary operators, are diagnosed.
/// The sixteen per-operator templates are collapsed into one dispatch table and two templates; the generated text
/// is unchanged.
override func visit(_ node: InfixOperatorExprSyntax) -> ExprSyntax {
    let opExpr = node.operator
    guard let opSyntax = opExpr.as(BinaryOperatorExprSyntax.self) else {
        diagnose(node: opExpr, message: "The '\(opExpr.description)' operator is not supported in this \(kind.keyword)")
        return ExprSyntax(node)
    }
    let operatorText = opSyntax.operator.text
    let (lhsLabel, rhsLabel) = switch operatorText {
    case "...", "..<": ("lower", "upper")
    default: ("lhs", "rhs")
    }
    // Both operands are rewritten before the operator is checked (as before), so diagnostics found inside the
    // operands are still reported alongside an unsupported operator.
    let lhsArgument = makeArgument(label: lhsLabel, node.leftOperand).with(\.trailingTrivia, [])
    let rhsArgument = makeArgument(label: rhsLabel, node.rightOperand).with(\.trailingTrivia, [])
    // Builder name suffix (after `build_`) and, for shared builders, the `op:` case to pass
    let builder: (name: String, op: String?)
    switch operatorText {
    case "==": builder = ("Equal", nil)
    case "!=": builder = ("NotEqual", nil)
    case "<": builder = ("Comparison", "lessThan")
    case "<=": builder = ("Comparison", "lessThanOrEqual")
    case ">": builder = ("Comparison", "greaterThan")
    case ">=": builder = ("Comparison", "greaterThanOrEqual")
    case "||": builder = ("Disjunction", nil)
    case "&&": builder = ("Conjunction", nil)
    case "+": builder = ("Arithmetic", "add")
    case "-": builder = ("Arithmetic", "subtract")
    case "*": builder = ("Arithmetic", "multiply")
    case "/": builder = ("Division", nil)
    case "%": builder = ("Remainder", nil)
    case "??": builder = ("NilCoalesce", nil)
    case "...": builder = ("ClosedRange", nil)
    case "..<": builder = ("Range", nil)
    default:
        diagnose(node: opSyntax, message: "The '\(operatorText)' operator is not supported in this \(kind.keyword)")
        return ExprSyntax(node)
    }
    if let op = builder.op {
        return """
            \(raw: indent)PredicateExpressions.build_\(raw: builder.name)(
            \(lhsArgument),
            \(rhsArgument),
            \(raw: indent + indentWidth)op: .\(raw: op)
            \(raw: indent))
            """
    }
    return """
        \(raw: indent)PredicateExpressions.build_\(raw: builder.name)(
        \(lhsArgument),
        \(rhsArgument)
        \(raw: indent))
        """
}
// Only reached for optional chains that OptionalChainRewriter could not turn into explicit flatMap calls.
// Such chains are reported as errors and left unchanged.
override func visit(_ node: OptionalChainingExprSyntax) -> ExprSyntax {
    let questionMark = node.questionMark
    diagnose(node: questionMark, message: "Optional chaining is not supported here in this \(kind.keyword). Use the flatMap(_:) function explicitly instead.")
    return ExprSyntax(node)
}
/// Rewrites `x!` into `PredicateExpressions.build_ForcedUnwrap(x)`.
override func visit(_ node: ForceUnwrapExprSyntax) -> ExprSyntax {
    // makeArgument restores the indentation level before returning, so `indent` is the same before and after it
    let openingIndent = indent
    let operand = makeArgument(label: nil, node.expression)
    return """
        \(raw: openingIndent)PredicateExpressions.build_ForcedUnwrap(
        \(operand)
        \(raw: indent))
        """
}
/// Rewrites `nil` into `PredicateExpressions.build_NilLiteral()`.
override func visit(_ node: NilLiteralExprSyntax) -> ExprSyntax {
    let nilLiteral: ExprSyntax = "PredicateExpressions.build_NilLiteral()"
    return nilLiteral
}
/// Rewrites `base.member` into `PredicateExpressions.build_KeyPath(root: base, keyPath: \.member)`.
/// Member access without an explicit base (e.g. `.member`) is diagnosed.
override func visit(_ node: MemberAccessExprSyntax) -> ExprSyntax {
    guard let base = node.base else {
        diagnose(node: node, message: "Member access without an explicit base is not supported in this \(kind.keyword)")
        return ExprSyntax(node)
    }
    let memberComponent = KeyPathComponentSyntax(
        period: TokenSyntax.periodToken(),
        component: .property(KeyPathPropertyComponentSyntax(declName: node.declName))
    )
    let memberKeyPath = ExprSyntax(KeyPathExprSyntax(components: [memberComponent]))
    return """
        \(raw: indent)PredicateExpressions.build_KeyPath(
        \(makeArgument(label: "root", base)),
        \(makeArgument(label: "keyPath", memberKeyPath, shouldVisit: false).with(\.trailingTrivia, []))
        \(raw: indent))
        """
}
/// Rewrites a function call via `_processFunction`, or leaves it unchanged (with a diagnostic) when unsupported.
///
/// Only method calls (`base.name(...)`) can be supported. Calls without a member-access receiver reach
/// `_processFunction` with a nil base and are diagnosed there.
override func visit(_ node: FunctionCallExprSyntax) -> ExprSyntax {
    let memberAccess = node.calledExpression.as(MemberAccessExprSyntax.self)
    let functionName: String
    if let memberAccess {
        // `.text` already excludes trivia
        functionName = memberAccess.declName.baseName.text
    } else if let reference = node.calledExpression.as(DeclReferenceExprSyntax.self) {
        functionName = reference.baseName.text
    } else {
        // e.g. `xs[0](y)`, `f()(x)` or `{ ... }(x)`: this used to be a force-unwrap that crashed the macro plugin.
        // With no member-access base, the name is never used in generated code, because _processFunction reports
        // the call as unsupported first.
        functionName = node.calledExpression.trimmedDescription
    }
    return _processFunction(
        base: memberAccess?.base,
        functionName: functionName,
        argumentList: node.arguments,
        trailingClosure: node.trailingClosure,
        diagnosticPoint: memberAccess.map { Syntax($0.declName) } ?? Syntax(node),
        functionCallExpr: node)
        ?? ExprSyntax(node)
}
/// Rewrites `base[args]` as a call of the known `subscript` function, or leaves it unchanged when unsupported.
override func visit(_ node: SubscriptCallExprSyntax) -> ExprSyntax {
    let processed = _processFunction(
        base: node.calledExpression,
        functionName: "subscript",
        argumentList: node.arguments,
        trailingClosure: node.trailingClosure,
        diagnosticPoint: Syntax(node.leftSquare)
    )
    return processed ?? ExprSyntax(node)
}
// Shared rewriting for method calls and subscripts: `base.functionName(args) { trailing }` becomes
// `PredicateExpressions.build_functionName(base, args) { trailing }`.
// - base: the receiver; nil means the call has no member-access receiver, which is diagnosed.
// - diagnosticPoint: where "not supported" diagnostics are anchored.
// - functionCallExpr: the original call, when there is one; used to build a fix-it for known replacements.
// Returns nil (after recording a diagnostic) when the call cannot be expanded.
private func _processFunction(base: ExprSyntax?, functionName: String, argumentList: LabeledExprListSyntax, trailingClosure: ClosureExprSyntax?, diagnosticPoint: Syntax, functionCallExpr: FunctionCallExprSyntax? = nil) -> ExprSyntax? {
// The provided base is nil when calling global functions
guard let base else {
diagnose(node: diagnosticPoint, message: "Global functions are not supported in this \(kind.keyword)")
return nil
}
// Check this function against our known list to provide rich diagnostics for functions we know we don't support
let name = TokenSyntax(.identifier(functionName), presence: .present).with(\.leadingTrivia, []).with(\.trailingTrivia, [])
// Closure literals and key paths both count as closure arguments, since key paths can be expanded into closures below
let args = argumentList.map {
let isClosure = $0.expression.is(ClosureExprSyntax.self) || $0.expression.is(KeyPathExprSyntax.self)
return FunctionStructure.Argument(label: $0.label?.text, kind: isClosure ? .closure : .standard)
}
let structure = FunctionStructure(name.text, arguments: args, trailingClosure: trailingClosure != nil)
guard let knownFunc = _knownMatchingFunction(structure) else {
let diagnostic = PredicateExpansionDiagnostic("The \(structure.signature) function is not supported in this \(kind.keyword)")
var fixIts = [FixIt]()
// Offer a rewrite to a supported equivalent when one is known and the call can be transformed mechanically
if let functionCallExpr,
let suggestion = _suggestionForUnknownFunction(structure),
let changes = suggestion.fixItChanges(transformingFrom: functionCallExpr) {
fixIts.append(FixIt(message: PredicateExpansionDiagnostic("Use \(suggestion.signature)", severity: .note), changes: changes))
}
diagnose(node: diagnosticPoint, message: diagnostic, fixIts: fixIts)
return nil
}
// Generated arguments; each argument except the last gets a comma and a newline
var arguments: [LabeledExprSyntax] = []
func addArgument(_ argument: ExprSyntax, label: String?, withComma: Bool) {
arguments.append(
makeArgument(label: label, argument)
.with(\.trailingComma, withComma ? TokenSyntax(.comma, presence: .present) : nil)
.with(\.trailingTrivia, withComma ? .newline : [])
)
}
// Function arguments can contain dollar sign identifiers that can't be nested inside of a new closure
// Prevent this function call from being placed inside of a flatMap due to optionalChaining
// The receiver is always emitted as the first, unlabeled argument of the builder
let oldValidOptionalChainingTree = validOptionalChainingTree
validOptionalChainingTree = false
addArgument(base, label: nil, withComma: !argumentList.isEmpty)
validOptionalChainingTree = oldValidOptionalChainingTree
// Pair each written argument with the known signature's argument at the same position (with any pack expanded)
for (sourceArg, knownArgStructure) in zip(argumentList, knownFunc.arguments.expandingPackToMatchCount(argumentList.count)) {
var expression = sourceArg.expression
// A key path in a closure position becomes `{ $0.<path> }`, with any optional chaining turned into flatMap calls
if knownArgStructure.kind == .closure, let kpExpr = sourceArg.expression.as(KeyPathExprSyntax.self) {
guard !kpExpr.containsShorthandArgumentIdentifiers,
let memberAccess = kpExpr.asDirectExpression(on: DeclReferenceExprSyntax(baseName: .dollarIdentifier("$0"))),
let preparedMemberAccess = try? memberAccess.rewrite(with: OptionalChainRewriter()) else {
diagnose(node: kpExpr, message: "This key path is not supported here in this \(kind.keyword). Use an explicit closure instead.")
return nil
}
expression = ExprSyntax(ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: .expr(preparedMemberAccess.as(ExprSyntax.self)!))]))
}
addArgument(expression, label: sourceArg.label?.text, withComma: sourceArg.trailingComma != nil)
}
if let closure = trailingClosure {
// Don't indent, because closures already get indented
let closureArg = makeArgument(label: nil, ExprSyntax(closure), shouldIndent: false)
return """
\(raw: indent)PredicateExpressions.build_\(name.with(\.leadingTrivia, []).with(\.trailingTrivia, []))(
\(LabeledExprListSyntax(arguments))
\(raw: indent))\(raw: Trivia.space)\(closureArg.with(\.leadingTrivia, []).with(\.trailingTrivia, []))
"""
} else {
return """
\(raw: indent)PredicateExpressions.build_\(name.with(\.leadingTrivia, []).with(\.trailingTrivia, []))(
\(LabeledExprListSyntax(arguments))
\(raw: indent))
"""
}
}
/// Rewrites a parenthesized expression by rewriting its contents. Real tuples (other than one element) are diagnosed.
override func visit(_ node: TupleExprSyntax) -> ExprSyntax {
    // Parentheses used for grouping, as in "(input as? Bool) == true", parse as a single-element tuple expression
    if node.elements.count == 1, let onlyElement = node.elements.first {
        return visit(onlyElement.expression)
    }
    diagnose(node: node, message: "Tuples are not supported in this \(kind.keyword)")
    return ExprSyntax(node)
}
// Rewrites the statements of a code block and guarantees that the returned list contains exactly one item.
// - statements: the block's statements; anything other than exactly one statement is diagnosed (anchored at the
//   second statement, or at `node` when the block is empty) and nil is returned.
// - node: the enclosing syntax, used as the diagnostic anchor for an empty block.
// - removeReturn: when true, a single `return <expr>` is replaced by `<expr>` (keeping the return's leading trivia).
// The statement is rewritten one indentation level deeper. If it comes back unchanged (nothing to rewrite, e.g. a
// constant) and no diagnostics exist so far, it is wrapped in `PredicateExpressions.build_Arg(...)`.
func _processCodeBlock(_ statements: CodeBlockItemListSyntax, in node: Syntax, removeReturn: Bool = false) -> CodeBlockItemListSyntax? {
guard statements.count == 1 else {
diagnose(node: statements.isEmpty ? node : statements[statements.index(after: statements.startIndex)], message: "\(kind.capitalizedKeyword) body may only contain one expression")
return nil
}
indentLevel += 1
var body = visit(statements)
// `==` on syntax nodes compares node identity, so equality means the visitor returned the original, unrewritten list
if success && body == statements {
let wrapped: ExprSyntax =
"""
\(raw: indent)PredicateExpressions.build_Arg(
\(raw: indent + indentWidth)\(body.with(\.leadingTrivia, []).with(\.trailingTrivia, []))
\(raw: indent))
"""
body = [.init(item: .expr(wrapped))]
}
indentLevel -= 1
if removeReturn, let first = body.first, case .stmt(let statement) = first.item, let returnStmt = statement.as(ReturnStmtSyntax.self), let returnExpr = returnStmt.expression {
body = [.init(item: .expr(returnExpr.with(\.leadingTrivia, returnStmt.leadingTrivia)))]
}
return body
}
/// Rewrites a code block's single statement; on failure the block is returned unchanged (diagnostics already recorded).
override func visit(_ node: CodeBlockSyntax) -> CodeBlockSyntax {
    if let processedStatements = _processCodeBlock(node.statements, in: Syntax(node)) {
        return node.with(\.statements, processedStatements)
    }
    return node
}
/// Rewrites a closure's single-expression body and reformats the closure.
///
/// Any explicit return type is dropped, and explicitly typed parameters are wrapped in
/// `PredicateExpressions.Variable<...>`. The opening brace is followed by a newline (or by a space when a signature
/// follows), and the closing brace is placed on its own line at the current indentation.
override func visit(_ node: ClosureExprSyntax) -> ExprSyntax {
    guard let processedStatements = _processCodeBlock(node.statements, in: Syntax(node)) else {
        return ExprSyntax(node)
    }
    let rewrittenSignature = node.signature.map { original -> ClosureSignatureSyntax in
        var updated = original
        updated.returnClause = nil
        if case .parameterClause(let parameterClause) = original.parameterClause {
            let wrappedParameters = parameterClause.parameters.withVariableWrappedTypes
            updated.parameterClause = .parameterClause(parameterClause.with(\.parameters, wrappedParameters))
        }
        return updated
    }
    let leftBraceTrivia: Trivia = node.signature == nil ? .newline : .space
    return ExprSyntax(
        node
            .with(\.statements, processedStatements)
            .with(\.leftBrace, node.leftBrace.with(\.trailingTrivia, leftBraceTrivia))
            .with(\.signature, rewrittenSignature?.with(\.trailingTrivia, .newline))
            .with(\.rightBrace, node.rightBrace.with(\.leadingTrivia, .newline + indent))
    )
}
/// Rewrites `c ? a : b` into `PredicateExpressions.build_Conditional(c, a, b)`.
override func visit(_ node: TernaryExprSyntax) -> ExprSyntax {
    // makeArgument restores the indentation level before returning, so `indent` is the same before and after these calls
    let openingIndent = indent
    let operands = [node.condition, node.thenExpression, node.elseExpression].map { operand in
        makeArgument(label: nil, operand).with(\.trailingTrivia, [])
    }
    return """
        \(raw: openingIndent)PredicateExpressions.build_Conditional(
        \(operands[0]),
        \(operands[1]),
        \(operands[2])
        \(raw: indent))
        """
}
/// Rewrites `x is T` into `PredicateExpressions.TypeCheck<_, T>(x)`.
override func visit(_ node: IsExprSyntax) -> ExprSyntax {
    let checkedType = node.type
    let operand = makeArgument(label: nil, node.expression).with(\.trailingTrivia, [])
    return """
        \(raw: indent)PredicateExpressions.TypeCheck<_, \(checkedType)>(
        \(operand)
        \(raw: indent))
        """
}
/// Rewrites `x as! T` / `x as T` into `PredicateExpressions.ForceCast<_, T>(x)` and `x as? T` into
/// `PredicateExpressions.ConditionalCast<_, T>(x)`.
///
/// Any other cast modifier token is now diagnosed. Previously it hit a `fatalError`, which crashed the macro plugin
/// instead of reporting an error to the user.
override func visit(_ node: AsExprSyntax) -> ExprSyntax {
    let castType: String
    switch node.questionOrExclamationMark?.tokenKind {
    case .none, .some(.exclamationMark):
        castType = "Force"
    case .some(.postfixQuestionMark):
        castType = "Conditional"
    default:
        diagnose(node: node, message: "This type cast is not supported in this \(kind.keyword)")
        return ExprSyntax(node)
    }
    return """
        \(raw: indent)PredicateExpressions.\(raw: castType)Cast<_, \(node.type)>(
        \(makeArgument(label: nil, node.expression).with(\.trailingTrivia, []))
        \(raw: indent))
        """
}
/// Rewrites the value of a `return` statement.
///
/// A bare `return` is left alone. A value the rewriter leaves unchanged (e.g. a constant) is wrapped in
/// `PredicateExpressions.build_Arg(...)`. The statement is re-indented to the current level.
override func visit(_ node: ReturnStmtSyntax) -> StmtSyntax {
    guard let returnedExpression = node.expression else {
        return StmtSyntax(node)
    }
    let rewritten = visit(returnedExpression)
    let finalExpression: ExprSyntax
    if rewritten == returnedExpression {
        finalExpression = """
            PredicateExpressions.build_Arg(
            \(rewritten.with(\.leadingTrivia, indent + indentWidth))
            \(raw: indent))
            """
    } else {
        finalExpression = rewritten.with(\.leadingTrivia, [])
    }
    return StmtSyntax(node.with(\.expression, finalExpression).with(\.leadingTrivia, indent))
}
/// `switch` expressions cannot be expanded; they are diagnosed and left unchanged.
override func visit(_ node: SwitchExprSyntax) -> ExprSyntax {
    diagnose(node: node, message: "Switch expressions are not supported in this \(kind.keyword)")
    return ExprSyntax(node)
}
/// Combines a list of boolean `if` conditions into one left-nested `&&` expression: `c1, c2, c3` becomes
/// `(c1 && c2) && c3`.
///
/// Conditions are checked from last to first. The first non-boolean condition found that way (availability,
/// pattern match, optional binding, or another kind) is diagnosed and nil is returned. An empty list is also diagnosed.
private func _rewriteConditionsAsExpression<C: BidirectionalCollection<ConditionElementListSyntax.Element>>(_ collection: C, in expr: IfExprSyntax) -> ExprSyntax? {
    guard !collection.isEmpty else {
        self.diagnose(node: expr, message: "This list of conditionals is unsupported in this \(kind.keyword)")
        return nil
    }
    // Collected in reverse source order: [cN, ..., c1]
    var reversedConditions: [ExprSyntax] = []
    for element in collection.reversed() {
        if case .expression(let conditionExpression) = element.condition {
            reversedConditions.append(conditionExpression)
            continue
        }
        let type: String
        switch element.condition {
        case .availability(_):
            type = "Availability conditions"
        case .matchingPattern(_):
            type = "Matching pattern conditions"
        case .optionalBinding(_):
            self.diagnose(node: element, message: "Mixing optional bindings with other conditions is not supported in this \(kind.keyword)")
            return nil
        default:
            type = "These types of conditions"
        }
        self.diagnose(node: element, message: "\(type) are not supported in this \(kind.keyword)")
        return nil
    }
    var combined = reversedConditions.removeLast()
    for conditionExpression in reversedConditions.reversed() {
        combined = ExprSyntax(InfixOperatorExprSyntax(
            leftOperand: combined,
            operator: BinaryOperatorExprSyntax(operator: .binaryOperator("&&")),
            rightOperand: conditionExpression
        ))
    }
    return combined
}
/// Expands `if let a = x, let b = y { body } else { other }` into nested `build_flatMap` calls, coalesced with the
/// else branch:
/// `build_NilCoalesce(lhs: build_flatMap(x) { a in build_flatMap(y) { b in body } }, rhs: other)`.
///
/// A binding without an initializer (`if let a`) uses the bound name itself as the optional value. A binding whose
/// pattern is not a plain identifier is diagnosed and nil is returned.
/// `indentLevel` is now restored on every exit. Previously an early return left it raised by the number of bindings
/// not yet processed.
private func _rewriteIfAsFlatMap(bindings: [OptionalBindingConditionSyntax], body: ExprSyntax, else: ExprSyntax) -> ExprSyntax? {
    let baseIndentLevel = indentLevel
    defer { indentLevel = baseIndentLevel }
    // Start at the deepest level; each iteration wraps the previous result one level shallower
    indentLevel += bindings.count
    var prior: ExprSyntax = body
    for binding in bindings.reversed() {
        guard let identifier = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier else {
            self.diagnose(node: binding.pattern, message: "This optional binding condition is not supported in this \(kind.keyword)")
            return nil
        }
        let initializer = binding.initializer?.value ?? ExprSyntax(DeclReferenceExprSyntax(baseName: identifier))
        prior = """
            \(raw: indent)PredicateExpressions.build_flatMap(
            \(makeArgument(label: nil, initializer).with(\.trailingTrivia, []))
            \(raw: indent)) { \(identifier.with(\.trailingTrivia, []).with(\.leadingTrivia, [])) in
            \(makeArgument(label: nil, prior, shouldVisit: false).with(\.trailingTrivia, []))
            \(raw: indent)}
            """
        indentLevel -= 1
    }
    return """
        \(raw: indent)PredicateExpressions.build_NilCoalesce(
        \(makeArgument(label: "lhs", prior, shouldVisit: false)),
        \(makeArgument(label: "rhs", `else`, shouldVisit: false))
        \(raw: indent))
        """
}
/// Reduces the body of `node` to a single expression.
///
/// The body's statements are processed with `removeReturn: true`. The first resulting item
/// must be an expression. Returns `nil` when processing fails, and emits a diagnostic on
/// the body when the first item is not an expression.
private func _processIfBody(_ node: IfExprSyntax) -> ExprSyntax? {
    let ifBody = node.body
    guard let processedItems = _processCodeBlock(ifBody.statements, in: .init(ifBody), removeReturn: true) else {
        return nil
    }
    if let expression = processedItems.first?.item._expression {
        return expression
    }
    self.diagnose(node: ifBody, message: "This if expression body is not supported in this \(kind.keyword)")
    return nil
}
/// Reduces the `else` branch of `node` to a single expression.
///
/// - A code-block `else` is processed with `removeReturn: true`, and its first item must be
///   an expression.
/// - An `else if` is rewritten recursively through `visit(_:)`.
///
/// - Returns: The else expression, or `nil` when the branch is missing or cannot be
///   reduced to an expression. Those cases are diagnosed here, except when
///   `_processCodeBlock` itself fails.
private func _processElseBody(_ node: IfExprSyntax) -> ExprSyntax? {
    guard let elseBody = node.elseBody else {
        self.diagnose(node: node, message: "If expressions without an else expression are not supported in this \(kind.keyword)")
        return nil
    }
    switch elseBody {
    case .codeBlock(let codeBlock):
        guard let visitedElseBody = _processCodeBlock(codeBlock.statements, in: .init(codeBlock), removeReturn: true) else {
            return nil
        }
        guard let expr = visitedElseBody.first?.item._expression else {
            // Attach the diagnostic to the else block itself; reporting it on `node.body`
            // would point the user at the `if` branch instead of the offending else branch.
            self.diagnose(node: codeBlock, message: "This if expression else body is not supported in this \(kind.keyword)")
            return nil
        }
        return expr
    case .ifExpr(let ifExpr):
        return visit(ifExpr)
    #if FOUNDATION_FRAMEWORK
    @unknown default:
        self.diagnose(node: elseBody, message: "This if expression else body is not supported in this \(kind.keyword)")
        return nil
    #endif
    }
}
/// Rewrites an `if` expression into predicate expression builders.
///
/// - When `node.conditions.optionalBindings` yields bindings (presumably only when every
///   condition is an optional binding — confirm in that helper), the `if` becomes nested
///   `build_flatMap` calls with a nil-coalescing fallback; see `_rewriteIfAsFlatMap`.
/// - Otherwise the conditions are joined with `&&` and the result is
///   `PredicateExpressions.build_Conditional(condition, body, else)`.
///
/// On failure the original node is returned unchanged. The helper that failed is expected
/// to have emitted the diagnostic.
override func visit(_ node: IfExprSyntax) -> ExprSyntax {
    if let bindings = node.conditions.optionalBindings {
        // The body is emitted inside one flatMap closure per binding, so process it at that depth.
        indentLevel += bindings.count
        let processedBody = _processIfBody(node)
        // Undo the extra depth before any early exit. Otherwise a failed body would leave
        // `indentLevel` skewed for every node visited afterwards.
        indentLevel -= bindings.count
        guard let bodyExpression = processedBody,
              let elseExpression = _processElseBody(node) else {
            return .init(node)
        }
        return _rewriteIfAsFlatMap(bindings: bindings, body: bodyExpression, else: elseExpression) ?? .init(node)
    }
    guard let ifExpression = _rewriteConditionsAsExpression(node.conditions, in: node),
          let bodyExpression = _processIfBody(node),
          let elseExpression = _processElseBody(node) else {
        return .init(node)
    }
    return """
        \(raw: indent)PredicateExpressions.build_Conditional(
        \(makeArgument(label: nil, ifExpression).with(\.trailingTrivia, [])),
        \(makeArgument(label: nil, bodyExpression, shouldVisit: false).with(\.trailingTrivia, [])),
        \(makeArgument(label: nil, elseExpression, shouldVisit: false).with(\.trailingTrivia, []))
        \(raw: indent))
        """
}
/// `while` loops cannot be expressed here: diagnose and leave the node as written.
override func visit(_ node: WhileStmtSyntax) -> StmtSyntax {
    diagnose(node: node, message: "While loops are not supported in this \(kind.keyword)")
    return StmtSyntax(node)
}
/// `for`-`in` loops cannot be expressed here: diagnose and leave the node as written.
override func visit(_ node: ForStmtSyntax) -> StmtSyntax {
    diagnose(node: node, message: "For-in loops are not supported in this \(kind.keyword)")
    return StmtSyntax(node)
}
/// Diagnoses a `do` statement reached through normal traversal.
///
/// A `do` that is a code block item is handled earlier by `visit(_: CodeBlockItemSyntax)`.
/// This one is diagnosed and left as written.
override func visit(_ node: DoStmtSyntax) -> StmtSyntax {
    diagnose(node: node, message: "Do statements are not supported in this \(kind.keyword)")
    return StmtSyntax(node)
}
/// `catch` clauses cannot be expressed here: diagnose and return the clause untouched.
override func visit(_ node: CatchClauseSyntax) -> CatchClauseSyntax {
    diagnose(node: node, message: "Catch clauses are not supported in this \(kind.keyword)")
    return node
}
/// `repeat`-`while` loops cannot be expressed here: diagnose and leave the node as written.
override func visit(_ node: RepeatStmtSyntax) -> StmtSyntax {
    diagnose(node: node, message: "Repeat-while loops are not supported in this \(kind.keyword)")
    return StmtSyntax(node)
}
/// Visits the single item of a code block.
///
/// - Declarations are diagnosed and returned unchanged.
/// - A `do` statement without `catch` clauses is unwrapped: its body is visited and the
///   body's first item replaces the `do`.
/// - Everything else is handled by the superclass.
override func visit(_ node: CodeBlockItemSyntax) -> CodeBlockItemSyntax {
    // Predicates only support single-expression code blocks, so this item is the only one in its block.
    switch node.item {
    case .decl:
        diagnose(node: node.item, message: "Declarations are not supported in this \(kind.keyword)")
        return node
    case .stmt(let statement):
        guard let doStatement = statement.as(DoStmtSyntax.self) else {
            break
        }
        if let firstCatch = doStatement.catchClauses.first {
            diagnose(node: firstCatch, message: "Catch clauses are not supported in this \(kind.keyword)")
            return node
        }
        // The body's contents take the place of the enclosing `do`, so visit them one indent shallower.
        indentLevel -= 1
        let visitedBody = self.visit(doStatement.body)
        indentLevel += 1
        guard success else {
            return node
        }
        guard let hoistedItem = visitedBody.statements.first else {
            diagnose(node: doStatement, message: "Do statement is not supported here in this \(kind.keyword)")
            return node
        }
        return hoistedItem
    default:
        break
    }
    return super.visit(node)
}
}
/// The one diagnostic and fix-it message type used by the `#Predicate` / `#Expression` macros.
///
/// Every instance shares the same `diagnosticID`, which also serves as its `fixItID`.
/// Severity defaults to `.error`. String literals create an error message directly.
private struct PredicateExpansionDiagnostic: DiagnosticMessage, FixItMessage, ExpressibleByStringLiteral, ExpressibleByStringInterpolation {
    let message: String
    let severity: DiagnosticSeverity
    let diagnosticID = MessageID(domain: "FoundationMacros", id: "PredicateDiagnostic")

    var fixItID: MessageID {
        return diagnosticID
    }

    init(_ message: String, severity: DiagnosticSeverity = .error) {
        self.severity = severity
        self.message = message
    }

    /// Creates an error-severity message from a string literal.
    init(stringLiteral value: String) {
        self.init(value, severity: .error)
    }
}
/// Which freestanding macro is being expanded. Drives diagnostic wording and the expansion type.
private enum ExpansionKind {
    case predicate
    case expression

    /// Lowercase name used in diagnostic text, e.g. "... in this predicate".
    var keyword: String {
        switch self {
        case .predicate: return "predicate"
        case .expression: return "expression"
        }
    }

    /// `keyword` with its first character uppercased ("Predicate" / "Expression").
    var capitalizedKeyword: String {
        keyword.prefix(1).uppercased() + keyword.dropFirst()
    }

    /// The macro as spelled in source, e.g. "#Predicate".
    var macroKeyword: String {
        "#" + capitalizedKeyword
    }

    /// Module-qualified name of the type the macro expands into.
    var qualifiedExpansionType: String {
        #if FOUNDATION_FRAMEWORK
        let moduleName = "Foundation"
        #else
        let moduleName = "FoundationEssentials"
        #endif
        return "\(moduleName).\(capitalizedKeyword)"
    }
}
/// Shared expansion for `#Predicate` and `#Expression`.
///
/// Requires a trailing closure. If the closure was passed as a parenthesized argument
/// instead, the thrown diagnostic carries a fix-it that moves it into trailing position.
/// The closure is rewritten first by `OptionalChainRewriter`, then by
/// `PredicateQueryRewriter`. It is then passed to the qualified expansion type's
/// initializer, forwarding any explicit generic arguments.
///
/// - Throws: `DiagnosticsError` when no trailing closure is present, plus anything
///   thrown by the rewrite step.
private func predicateExpansion(of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext, kind: ExpansionKind) throws -> ExprSyntax {
    guard let trailingClosure = node.trailingClosure else {
        var suggestions: [FixIt] = []
        if let closureArgument = node.arguments.first?.expression.as(ClosureExprSyntax.self) {
            let movedClosure = closureArgument
                .with(\.leadingTrivia, [.spaces(1)])
                .with(\.trailingTrivia, [])
            var rewrittenNode = node
                .with(\.leftParen, nil)
                .with(\.rightParen, nil)
                .with(\.trailingClosure, movedClosure)
            rewrittenNode.arguments = []
            let suggestion = FixIt(
                message: PredicateExpansionDiagnostic("Use a trailing closure instead of a function parameter", severity: .note),
                changes: [.replace(oldNode: Syntax(node), newNode: Syntax(rewrittenNode))]
            )
            suggestions.append(suggestion)
        }
        let missingClosure = Diagnostic(
            node: Syntax(node),
            message: PredicateExpansionDiagnostic("\(kind.macroKeyword) macro expansion requires a trailing closure"),
            fixIts: suggestions
        )
        throw DiagnosticsError(diagnostics: [missingClosure])
    }

    let rewrittenClosure = try trailingClosure
        .rewrite(with: OptionalChainRewriter())
        .rewrite(with: PredicateQueryRewriter(kind: kind))
        .with(\.leadingTrivia, [])
        .with(\.trailingTrivia, [])
    let expansionType = kind.qualifiedExpansionType

    guard let genericArgs = node.genericArgumentClause else {
        // No explicit generic arguments (e.g. "#Predicate { ... }"): emit the type without them so they are inferred from context
        return "\(raw: expansionType)(\(rewrittenClosure))"
    }
    let trimmedGenericArgs = genericArgs
        .with(\.leadingTrivia, [])
        .with(\.trailingTrivia, [])
    return "\(raw: expansionType)\(trimmedGenericArgs)(\(rewrittenClosure))"
}
/// Implementation of the `#Predicate` freestanding expression macro.
public struct PredicateMacro: SwiftSyntaxMacros.ExpressionMacro, Sendable {
    /// Opts out of automatic formatting of the expansion. The rewriters lay out the
    /// generated code themselves using their own indent tracking.
    public static var formatMode: FormatMode {
        return .disabled
    }

    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        return try predicateExpansion(of: node, in: context, kind: .predicate)
    }
}
/// Implementation of the `#Expression` freestanding expression macro.
public struct ExpressionMacro: SwiftSyntaxMacros.ExpressionMacro, Sendable {
    /// Opts out of automatic formatting of the expansion. The rewriters lay out the
    /// generated code themselves using their own indent tracking.
    public static var formatMode: FormatMode {
        return .disabled
    }

    public static func expansion(
        of node: some FreestandingMacroExpansionSyntax,
        in context: some MacroExpansionContext
    ) throws -> ExprSyntax {
        return try predicateExpansion(of: node, in: context, kind: .expression)
    }
}