//===----------------------------------------------------------------------===// // // This source file is part of the Swift.org open source project // // Copyright (c) 2022-2024 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors // //===----------------------------------------------------------------------===// import SwiftSyntax import SwiftSyntaxMacros internal import SwiftDiagnostics internal import SwiftSyntaxBuilder // A list of all functions supported by Predicate/Expression itself, any other functions called will be diagnosed as an error // This allows for checking the function name, the number of arguments, and the argument labels, but the types of the arguments will need to be validated by the post-expansion type checking pass // The closure specification is used to determine whether keypaths should be transformed/expanded into closures and whether dropping the final argument in favor of a trailing closure is allowed private let _knownSupportedFunctions: Set = [ FunctionStructure("contains", arguments: [.unlabeled]), FunctionStructure("contains", arguments: [.closure(labeled: "where")]), FunctionStructure("allSatisfy", arguments: [.closure(labeled: nil)]), FunctionStructure("flatMap", arguments: [.closure(labeled: nil)]), FunctionStructure("filter", arguments: [.closure(labeled: nil)]), FunctionStructure("subscript", arguments: [.unlabeled]), FunctionStructure("subscript", arguments: [.unlabeled, "default"]), FunctionStructure("starts", arguments: ["with"]), FunctionStructure("min", arguments: []), FunctionStructure("max", arguments: []), FunctionStructure("localizedStandardContains", arguments: [.unlabeled]), FunctionStructure("localizedCompare", arguments: [.unlabeled]), FunctionStructure("caseInsensitiveCompare", arguments: [.unlabeled]) ] private var 
knownSupportedFunctions: Set { #if FOUNDATION_FRAMEWORK var result = _knownSupportedFunctions result.insert(FunctionStructure("evaluate", arguments: [.pack(labeled: nil)])) return result #else _knownSupportedFunctions #endif } private let supportedFunctionSuggestions: [FunctionStructure : FunctionStructure] = [ FunctionStructure("hasPrefix", arguments: [.unlabeled]) : FunctionStructure("starts", arguments: ["with"]), FunctionStructure("localizedCaseInsensitiveContains", arguments: [.unlabeled]) : FunctionStructure("localizedStandardContains", arguments: [.unlabeled]), FunctionStructure("localizedCaseInsensitiveCompare", arguments: [.unlabeled]) : FunctionStructure("localizedCompare", arguments: [.unlabeled]), FunctionStructure("localizedStandardCompare", arguments: [.unlabeled]) : FunctionStructure("localizedCompare", arguments: [.unlabeled]) ] extension Array where Element == FunctionStructure.Argument { fileprivate func argumentsEqual(_ other: Self) -> Bool { let currentPackIndex = self.firstIndex { $0.kind == .pack } let otherPackIndex = other.firstIndex { $0.kind == .pack } var full: [FunctionStructure.Argument] var prefix: ArraySlice var suffix: ArraySlice switch (currentPackIndex, otherPackIndex) { // If neither contains a pack or both contain a pack, just compare arguments as-is case (nil, nil), (.some(_), .some(_)): return self == other // If one of them contains a pack, compare the prefix and suffix to allow the pack to lazily consume multiple arguments case (let .some(idx), nil): full = other prefix = self[.. 
Self { let countDifference = otherCount - self.count guard countDifference >= 0, let packIdx = self.firstIndex(where: { $0.kind == .pack }) else { return self } var copy = self copy[packIdx] = .init(label: copy[packIdx].label, kind: .standard) if countDifference > 0 { copy.insert(contentsOf: Array(repeating: .unlabeled, count: countDifference), at: packIdx + 1) } return copy } } private struct FunctionStructure: Hashable { struct Argument : Hashable, ExpressibleByStringLiteral { enum Kind : Hashable { case standard case closure case pack } let label: String? let kind: Kind init(stringLiteral: String) { label = stringLiteral kind = .standard } init(label: String?, kind: Kind) { self.label = label self.kind = kind } static func closure(labeled label: String?) -> Self { Self(label: label, kind: .closure) } static var unlabeled: Self { Self(label: nil, kind: .standard) } static func pack(labeled label: String?) -> Self { Self(label: label, kind: .pack) } static func ==(lhs: Self, rhs: Self) -> Bool { lhs.label == rhs.label } } let name: String let arguments: [Argument] let hasTrailingClosure: Bool var supportsTrailingClosure: Bool { hasTrailingClosure || arguments.last?.kind == .closure } var signature: String { let args = arguments.map { ($0.label ?? 
"_") + ":" }.joined() return "\(name)(\(args))" } init(_ name: String, arguments: [Argument], trailingClosure: Bool = false) { self.name = name self.arguments = arguments self.hasTrailingClosure = trailingClosure } func matches(_ other: FunctionStructure) -> Bool { guard self.name == other.name else { return false } switch (self.hasTrailingClosure, other.hasTrailingClosure) { case (true, true), (false, false): return self.arguments.argumentsEqual(other.arguments) case (true, false): guard let otherLast = other.arguments.last else { return false } return self.arguments.argumentsEqual(other.arguments.dropLast()) && otherLast.kind == .closure case (false, true): guard let last = self.arguments.last else { return false } return self.arguments.dropLast().argumentsEqual(other.arguments) && last.kind == .closure } } func fixItChanges(transformingFrom source: FunctionCallExprSyntax) -> [FixIt.Change]? { let sourceHasTrailingClosure = source.trailingClosure != nil if sourceHasTrailingClosure { guard supportsTrailingClosure else { return nil } } let sourceArgumentTotalCount = source.arguments.count + (sourceHasTrailingClosure ? 1 : 0) let argumentTotalCount = self.arguments.count + (hasTrailingClosure ? 
1 : 0) guard argumentTotalCount == sourceArgumentTotalCount, let calledExpr = source.calledExpression.as(MemberAccessExprSyntax.self) else { return nil } var newFunctionCall = source newFunctionCall.calledExpression = ExprSyntax(calledExpr.with(\.declName, DeclReferenceExprSyntax(baseName: .identifier(name)))) newFunctionCall.arguments = LabeledExprListSyntax(zip(source.arguments, arguments).map { if let newLabel = $1.label { return $0.with(\.label, .identifier(newLabel)).with(\.colon, .colonToken()).with(\.expression, $0.expression.with(\.leadingTrivia, [.spaces(1)])) } else { return $0.with(\.label, nil).with(\.colon, nil).with(\.trailingTrivia, []).with(\.expression, $0.expression.with(\.leadingTrivia, [])) } }) newFunctionCall.leadingTrivia = [] newFunctionCall.trailingTrivia = [] if self.hasTrailingClosure && source.trailingClosure == nil, let newTrailingClosure = source.arguments.last?.expression.as(ClosureExprSyntax.self) { newFunctionCall.trailingClosure = newTrailingClosure } return [.replace(oldNode: Syntax(source), newNode: Syntax(newFunctionCall))] } } private func _knownMatchingFunction(_ structure: FunctionStructure) -> FunctionStructure? { knownSupportedFunctions.first { $0.matches(structure) } } private func _suggestionForUnknownFunction(_ structure: FunctionStructure) -> FunctionStructure? { guard let key = supportedFunctionSuggestions.keys.first(where: { $0.matches(structure) }) else { return nil } return supportedFunctionSuggestions[key] } private class ShorthandArgumentIdentifierDetector: SyntaxVisitor { var found = false override func visit(_ node: DeclReferenceExprSyntax) -> SyntaxVisitorContinueKind { // Look for identifiers such as $0, $1, etc. 
if case let .dollarIdentifier(identifier) = node.baseName.tokenKind, identifier.dropFirst().allSatisfy(\.isNumber) { found = true return .skipChildren } else { return .visitChildren } } } extension SyntaxProtocol { var containsShorthandArgumentIdentifiers: Bool { let visitor = ShorthandArgumentIdentifierDetector(viewMode: .all) visitor.walk(self) return visitor.found } } private protocol PredicateSyntaxRewriter : SyntaxRewriter { var success: Bool { get } var ignorable: Bool { get } var diagnostics: [Diagnostic] { get } } extension PredicateSyntaxRewriter { var success: Bool { true } var ignorable: Bool { false } var diagnostics: [Diagnostic] { [] } } extension SyntaxProtocol { fileprivate func rewrite(with rewriter: some PredicateSyntaxRewriter) throws -> Syntax { let translated = rewriter.rewrite(self) guard rewriter.success else { throw DiagnosticsError(diagnostics: rewriter.diagnostics) } guard !rewriter.ignorable else { return Syntax(self) } return translated } } private class OptionalChainRewriter: SyntaxRewriter, PredicateSyntaxRewriter { var withinValidChainingTreeStart = true var withinChainingTree = false var optionalInput: ExprSyntax? 
= nil var ignorable = true private func _prePossibleTopOfTree() -> Bool { if !withinChainingTree && withinValidChainingTreeStart { withinChainingTree = true return true } return false } private func _postTopOfTree(_ node: ExprSyntax) -> ExprSyntax { assert(withinChainingTree) withinChainingTree = false if let input = optionalInput { optionalInput = nil ignorable = false let visited = self.visit(input) let closure = ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: CodeBlockItemSyntax.Item(node))]) let functionMember = MemberAccessExprSyntax(base: visited, name: "flatMap") let functionCall = FunctionCallExprSyntax(calledExpression: functionMember, arguments: [], trailingClosure: closure) return ExprSyntax(functionCall) } return node } override func visit(_ node: ClosureExprSyntax) -> ExprSyntax { guard withinChainingTree else { // If we're not already in a chaining tree, just keep progressing with our current rewriter return super.visit(node) } // We're in the middle of a potential tree, so rewrite the closure with a fresh state // This ensures potential chaining in the closure isn't rewritten outside of the closure let nestedRewriter = OptionalChainRewriter() guard let rewritten = (try? 
node.rewrite(with: nestedRewriter))?.as(ExprSyntax.self) else { // If rewriting the closure failed, just leave the closure as-is return ExprSyntax(node) } if ignorable { ignorable = nestedRewriter.ignorable } return rewritten } override func visit(_ node: FunctionCallExprSyntax) -> ExprSyntax { let priorValidTreeStart = withinValidChainingTreeStart defer { withinValidChainingTreeStart = priorValidTreeStart } if node.arguments.containsShorthandArgumentIdentifiers { withinValidChainingTreeStart = false } let topOfTree = _prePossibleTopOfTree() let visited = super.visit(node) if topOfTree { return _postTopOfTree(visited) } else { return visited } } override func visit(_ node: MemberAccessExprSyntax) -> ExprSyntax { let topOfTree = _prePossibleTopOfTree() let visited = super.visit(node) if topOfTree { return _postTopOfTree(visited) } else { return visited } } override func visit(_ node: OptionalChainingExprSyntax) -> ExprSyntax { guard withinChainingTree else { return super.visit(node) } // Capture the optional input, and replace it in the output expression with a "$0" optionalInput = node.expression return .init(DeclReferenceExprSyntax(baseName: .dollarIdentifier("$0"))) } } extension CodeBlockItemListSyntax.Element.Item { fileprivate var _expression: ExprSyntax? { switch self { case .expr(let expr): return expr case .stmt(let stmt): return stmt.as(ExpressionStmtSyntax.self)?.expression default: return nil } } } extension ConditionElementListSyntax { fileprivate var optionalBindings: [OptionalBindingConditionSyntax]? 
{ var result = [OptionalBindingConditionSyntax]() for element in self { switch element.condition { case let .optionalBinding(binding): result.append(binding) default: return nil } } return result } } extension ClosureParameterListSyntax { fileprivate var withVariableWrappedTypes: Self { return Self(self.map { if let type = $0.type { $0.with(\.type, "PredicateExpressions.Variable<\(type)>") } else { $0 } }) } } extension KeyPathExprSyntax { private enum KeyPathDirectExpressionRewritingError : Error { case unknownKeypathComponentType } fileprivate func asDirectExpression(on base: some ExprSyntaxProtocol) -> ExprSyntax? { var result = ExprSyntax(base) for item in components { switch item.component { case .property(let prop): result = ExprSyntax(MemberAccessExprSyntax(base: result, declName: prop.declName)) case .optional(let opt): if opt.questionOrExclamationMark.tokenKind == .exclamationMark { result = ExprSyntax(ForceUnwrapExprSyntax(expression: result)) } else { result = ExprSyntax(OptionalChainingExprSyntax(expression: result)) } case .subscript(let sub): result = ExprSyntax(SubscriptCallExprSyntax(calledExpression: result, arguments: sub.arguments)) #if FOUNDATION_FRAMEWORK default: return nil #endif } } return result } } private class PredicateQueryRewriter: SyntaxRewriter, PredicateSyntaxRewriter { private let indentWidth: Trivia = .spaces(4) private var indentLevel = 0 private var indent: Trivia { Trivia(pieces: Array(repeating: .spaces(4), count: indentLevel)) } var validOptionalChainingTree = true var diagnostics: [Diagnostic] = [] var success: Bool { diagnostics.isEmpty } let kind: ExpansionKind init(kind: ExpansionKind) { self.kind = kind } private func diagnose(node: SyntaxProtocol, message: PredicateExpansionDiagnostic, fixIts: [FixIt] = []) { diagnostics.append(.init(node: Syntax(node), message: message, fixIts: fixIts)) } private func makeArgument(label: String?, _ expression: ExprSyntax, shouldVisit: Bool = true, shouldIndent: Bool = true) -> 
LabeledExprSyntax { if shouldIndent { indentLevel += 1 } defer { if shouldIndent { indentLevel -= 1 } } let labelSyntax = label.map { TokenSyntax(.identifier($0), presence: .present) }?.with(\.leadingTrivia, indent) let colonSyntax = label.map { _ in TokenSyntax(.colon, presence: .present) } var argument = shouldVisit ? visit(expression) : expression if shouldVisit && argument == expression { argument = "PredicateExpressions.build_Arg(\(expression.with(\.leadingTrivia, []).with(\.trailingTrivia, [])))" } argument = argument.with(\.leadingTrivia, label == nil ? indent : .space) return .init(label: labelSyntax, colon: colonSyntax, expression: argument, trailingComma: nil) } override func visit(_ node: PrefixOperatorExprSyntax) -> ExprSyntax { switch node.operator.text { case "!": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Negation( \(makeArgument(label: nil, node.expression)) \(raw: indent)) """ return syntax case "-": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_UnaryMinus( \(makeArgument(label: nil, node.expression)) \(raw: indent)) """ return syntax default: diagnose(node: node.operator, message: "The '\(node.operator.text)' operator is not supported in this \(kind.keyword)") return ExprSyntax(node) } } override func visit(_ node: InfixOperatorExprSyntax) -> ExprSyntax { let lhsOp = node.leftOperand let rhsOp = node.rightOperand let opExpr = node.operator guard let opSyntax = opExpr.as(BinaryOperatorExprSyntax.self) else { diagnose(node: opExpr, message: "The '\(opExpr.description)' operator is not supported in this \(kind.keyword)") return ExprSyntax(node) } let (lhsLabel, rhsLabel) = switch opSyntax.operator.text { case "...", "..<": ("lower", "upper") default: ("lhs", "rhs") } let lhsArgument = makeArgument(label: lhsLabel, lhsOp).with(\.trailingTrivia, []) let rhsArgument = makeArgument(label: rhsLabel, rhsOp).with(\.trailingTrivia, []) switch (opSyntax.operator.text) { case "==": let syntax: ExprSyntax = """ 
\(raw: indent)PredicateExpressions.build_Equal( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "!=": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_NotEqual( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "<": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Comparison( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .lessThan \(raw: indent)) """ return syntax case "<=": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Comparison( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .lessThanOrEqual \(raw: indent)) """ return syntax case ">": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Comparison( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .greaterThan \(raw: indent)) """ return syntax case ">=": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Comparison( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .greaterThanOrEqual \(raw: indent)) """ return syntax case "||": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Disjunction( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "&&": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Conjunction( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "+": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Arithmetic( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .add \(raw: indent)) """ return syntax case "-": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Arithmetic( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .subtract \(raw: indent)) """ return syntax case "*": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Arithmetic( \(lhsArgument), \(rhsArgument), \(raw: indent + indentWidth)op: .multiply \(raw: indent)) """ return 
syntax case "/": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Division( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "%": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Remainder( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "??": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_NilCoalesce( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "...": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_ClosedRange( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax case "..<": let syntax: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Range( \(lhsArgument), \(rhsArgument) \(raw: indent)) """ return syntax default: diagnose(node: opSyntax, message: "The '\(opSyntax.operator.text)' operator is not supported in this \(kind.keyword)") return ExprSyntax(node) } } // We only hit this if our OptionalChainingRewriter was unable to rewrite them out of the expression tree override func visit(_ node: OptionalChainingExprSyntax) -> ExprSyntax { diagnose(node: node.questionMark, message: "Optional chaining is not supported here in this \(kind.keyword). 
Use the flatMap(_:) function explicitly instead.") return .init(node) } override func visit(_ node: ForceUnwrapExprSyntax) -> ExprSyntax { return """ \(raw: indent)PredicateExpressions.build_ForcedUnwrap( \(makeArgument(label: nil, node.expression)) \(raw: indent)) """ } override func visit(_ node: NilLiteralExprSyntax) -> ExprSyntax { "PredicateExpressions.build_NilLiteral()" } override func visit(_ node: MemberAccessExprSyntax) -> ExprSyntax { guard let base = node.base else { diagnose(node: node, message: "Member access without an explicit base is not supported in this \(kind.keyword)") return .init(node) } let newPropertyComponent = KeyPathPropertyComponentSyntax(declName: node.declName) let keyPath = KeyPathExprSyntax(components: [.init(period: TokenSyntax.periodToken(), component: .property(newPropertyComponent))]) return """ \(raw: indent)PredicateExpressions.build_KeyPath( \(makeArgument(label: "root", base)), \(makeArgument(label: "keyPath", .init(keyPath), shouldVisit: false).with(\.trailingTrivia, [])) \(raw: indent)) """ } override func visit(_ node: FunctionCallExprSyntax) -> ExprSyntax { let memberAccess = node.calledExpression.as(MemberAccessExprSyntax.self) let base = memberAccess?.base let funcName = memberAccess?.declName.baseName.with(\.leadingTrivia, []).with(\.trailingTrivia, []).text ?? node.calledExpression.as(DeclReferenceExprSyntax.self)!.baseName.text return _processFunction( base: base, functionName: funcName, argumentList: node.arguments, trailingClosure: node.trailingClosure, diagnosticPoint: .init(memberAccess?.declName) ?? .init(node), functionCallExpr: node) ?? .init(node) } override func visit(_ node: SubscriptCallExprSyntax) -> ExprSyntax { return _processFunction( base: node.calledExpression, functionName: "subscript", argumentList: node.arguments, trailingClosure: node.trailingClosure, diagnosticPoint: .init(node.leftSquare)) ?? 
.init(node) } private func _processFunction(base: ExprSyntax?, functionName: String, argumentList: LabeledExprListSyntax, trailingClosure: ClosureExprSyntax?, diagnosticPoint: Syntax, functionCallExpr: FunctionCallExprSyntax? = nil) -> ExprSyntax? { // The provided base is nil when calling global functions functions guard let base else { diagnose(node: diagnosticPoint, message: "Global functions are not supported in this \(kind.keyword)") return nil } // Check this function against our known list to provide rich diagnostics for functions we know we don't support let name = TokenSyntax(.identifier(functionName), presence: .present).with(\.leadingTrivia, []).with(\.trailingTrivia, []) let args = argumentList.map { let isClosure = $0.expression.is(ClosureExprSyntax.self) || $0.expression.is(KeyPathExprSyntax.self) return FunctionStructure.Argument(label: $0.label?.text, kind: isClosure ? .closure : .standard) } let structure = FunctionStructure(name.text, arguments: args, trailingClosure: trailingClosure != nil) guard let knownFunc = _knownMatchingFunction(structure) else { let diagnostic = PredicateExpansionDiagnostic("The \(structure.signature) function is not supported in this \(kind.keyword)") var fixIts = [FixIt]() if let functionCallExpr, let suggestion = _suggestionForUnknownFunction(structure), let changes = suggestion.fixItChanges(transformingFrom: functionCallExpr) { fixIts.append(FixIt(message: PredicateExpansionDiagnostic("Use \(suggestion.signature)", severity: .note), changes: changes)) } diagnose(node: diagnosticPoint, message: diagnostic, fixIts: fixIts) return nil } var arguments: [LabeledExprSyntax] = [] func addArgument(_ argument: ExprSyntax, label: String?, withComma: Bool) { arguments.append( makeArgument(label: label, argument) .with(\.trailingComma, withComma ? TokenSyntax(.comma, presence: .present) : nil) .with(\.trailingTrivia, withComma ? 
.newline : []) ) } // Function arguments can contain dollar sign identifiers that can't be nested inside of a new closure // Prevent this function call from being placed inside of a flatMap due to optionalChaining let oldValidOptionalChainingTree = validOptionalChainingTree validOptionalChainingTree = false addArgument(base, label: nil, withComma: !argumentList.isEmpty) validOptionalChainingTree = oldValidOptionalChainingTree for (sourceArg, knownArgStructure) in zip(argumentList, knownFunc.arguments.expandingPackToMatchCount(argumentList.count)) { var expression = sourceArg.expression if knownArgStructure.kind == .closure, let kpExpr = sourceArg.expression.as(KeyPathExprSyntax.self) { guard !kpExpr.containsShorthandArgumentIdentifiers, let memberAccess = kpExpr.asDirectExpression(on: DeclReferenceExprSyntax(baseName: .dollarIdentifier("$0"))), let preparedMemberAccess = try? memberAccess.rewrite(with: OptionalChainRewriter()) else { diagnose(node: kpExpr, message: "This key path is not supported here in this \(kind.keyword). 
Use an explicit closure instead.") return nil } expression = ExprSyntax(ClosureExprSyntax(statements: [CodeBlockItemSyntax(item: .expr(preparedMemberAccess.as(ExprSyntax.self)!))])) } addArgument(expression, label: sourceArg.label?.text, withComma: sourceArg.trailingComma != nil) } if let closure = trailingClosure { // Don't indent, because closures already get indented let closureArg = makeArgument(label: nil, ExprSyntax(closure), shouldIndent: false) return """ \(raw: indent)PredicateExpressions.build_\(name.with(\.leadingTrivia, []).with(\.trailingTrivia, []))( \(LabeledExprListSyntax(arguments)) \(raw: indent))\(raw: Trivia.space)\(closureArg.with(\.leadingTrivia, []).with(\.trailingTrivia, [])) """ } else { return """ \(raw: indent)PredicateExpressions.build_\(name.with(\.leadingTrivia, []).with(\.trailingTrivia, []))( \(LabeledExprListSyntax(arguments)) \(raw: indent)) """ } } override func visit(_ node: TupleExprSyntax) -> ExprSyntax { guard node.elements.count == 1, let element = node.elements.first else { diagnose(node: node, message: "Tuples are not supported in this \(kind.keyword)") return ExprSyntax(node) } // Support expressions like "(input as? Bool) == true" where parantheses used for grouping are treated like a single element tuple expression return visit(element.expression) } // Processes a code block and guarantees that the returned code block only contains one item func _processCodeBlock(_ statements: CodeBlockItemListSyntax, in node: Syntax, removeReturn: Bool = false) -> CodeBlockItemListSyntax? { guard statements.count == 1 else { diagnose(node: statements.isEmpty ? 
node : statements[statements.index(after: statements.startIndex)], message: "\(kind.capitalizedKeyword) body may only contain one expression") return nil } indentLevel += 1 var body = visit(statements) if success && body == statements { let wrapped: ExprSyntax = """ \(raw: indent)PredicateExpressions.build_Arg( \(raw: indent + indentWidth)\(body.with(\.leadingTrivia, []).with(\.trailingTrivia, [])) \(raw: indent)) """ body = [.init(item: .expr(wrapped))] } indentLevel -= 1 if removeReturn, let first = body.first, case .stmt(let statement) = first.item, let returnStmt = statement.as(ReturnStmtSyntax.self), let returnExpr = returnStmt.expression { body = [.init(item: .expr(returnExpr.with(\.leadingTrivia, returnStmt.leadingTrivia)))] } return body } override func visit(_ node: CodeBlockSyntax) -> CodeBlockSyntax { guard let body = _processCodeBlock(node.statements, in: .init(node)) else { return node } return node.with(\.statements, body) } override func visit(_ node: ClosureExprSyntax) -> ExprSyntax { guard let body = _processCodeBlock(node.statements, in: .init(node)) else { return .init(node) } var resultingSignature = node.signature if let signature = node.signature { var visited = signature visited.returnClause = nil if case .parameterClause(let paramClause) = signature.parameterClause { let newParamClause = paramClause.with(\.parameters, paramClause.parameters.withVariableWrappedTypes) visited.parameterClause = .parameterClause(newParamClause) } resultingSignature = visited } return ExprSyntax( node .with(\.statements, body) .with(\.leftBrace, node.leftBrace.with(\.trailingTrivia, node.signature == nil ? 
.newline : .space)) .with(\.signature, resultingSignature?.with(\.trailingTrivia, .newline)) .with(\.rightBrace, node.rightBrace.with(\.leadingTrivia, .newline + indent)) ) } override func visit(_ node: TernaryExprSyntax) -> ExprSyntax { let condition = node.condition let firstChoice = node.thenExpression let secondChoice = node.elseExpression return """ \(raw: indent)PredicateExpressions.build_Conditional( \(makeArgument(label: nil, condition).with(\.trailingTrivia, [])), \(makeArgument(label: nil, firstChoice).with(\.trailingTrivia, [])), \(makeArgument(label: nil, secondChoice).with(\.trailingTrivia, [])) \(raw: indent)) """ } override func visit(_ node: IsExprSyntax) -> ExprSyntax { return """ \(raw: indent)PredicateExpressions.TypeCheck<_, \(node.type)>( \(makeArgument(label: nil, node.expression).with(\.trailingTrivia, [])) \(raw: indent)) """ } override func visit(_ node: AsExprSyntax) -> ExprSyntax { let castType: String switch node.questionOrExclamationMark?.tokenKind { case .none: fallthrough case .some(.exclamationMark): castType = "Force" case .some(.postfixQuestionMark): castType = "Conditional" default: fatalError("Unexpected question/exclamation mark token kind") } return """ \(raw: indent)PredicateExpressions.\(raw: castType)Cast<_, \(node.type)>( \(makeArgument(label: nil, node.expression).with(\.trailingTrivia, [])) \(raw: indent)) """ } override func visit(_ node: ReturnStmtSyntax) -> StmtSyntax { guard let expression = node.expression else { // No expansion needed when returning Void return StmtSyntax(node) } let visited = visit(expression) guard visited == expression else { // No expansion needed when returning transformed expression return StmtSyntax(node.with(\.expression, visited.with(\.leadingTrivia, [])).with(\.leadingTrivia, indent)) } // Wrap constant return expressions in a build_Arg call let wrapped: ExprSyntax = """ PredicateExpressions.build_Arg( \(visited.with(\.leadingTrivia, indent + indentWidth)) \(raw: indent)) """ return 
StmtSyntax(node.with(\.expression, wrapped).with(\.leadingTrivia, indent)) } override func visit(_ node: SwitchExprSyntax) -> ExprSyntax { self.diagnose(node: node, message: "Switch expressions are not supported in this \(kind.keyword)") return .init(node) } private func _rewriteConditionsAsExpression>(_ collection: C, in expr: IfExprSyntax) -> ExprSyntax? { guard let last = collection.last else { self.diagnose(node: expr, message: "This list of conditionals is unsupported in this \(kind.keyword)") return nil } guard case .expression(let lastExpr) = last.condition else { let type: String switch last.condition { case .availability(_): type = "Availability conditions" case .matchingPattern(_): type = "Matching pattern conditions" case .optionalBinding(_): self.diagnose(node: last, message: "Mixing optional bindings with other conditions is not supported in this \(kind.keyword)") return nil default: type = "These types of conditions" } self.diagnose(node: last, message: "\(type) are not supported in this \(kind.keyword)") return nil } let rest = collection.dropLast() if rest.isEmpty { return lastExpr } else { guard let restRewritten = _rewriteConditionsAsExpression(rest, in: expr) else { return nil } return .init(InfixOperatorExprSyntax(leftOperand: restRewritten, operator: BinaryOperatorExprSyntax(operator: .binaryOperator("&&")), rightOperand: lastExpr)) } } private func _rewriteIfAsFlatMap(bindings: [OptionalBindingConditionSyntax], body: ExprSyntax, else: ExprSyntax) -> ExprSyntax? { indentLevel += bindings.count var prior: ExprSyntax = body for binding in bindings.reversed() { guard let identifier = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier else { self.diagnose(node: binding.pattern, message: "This optional binding condition is not supported in this \(kind.keyword)") return nil } let initializer = binding.initializer?.value ?? 
ExprSyntax(DeclReferenceExprSyntax(baseName: identifier))
            prior = """
            \(raw: indent)PredicateExpressions.build_flatMap(
            \(makeArgument(label: nil, initializer).with(\.trailingTrivia, []))
            \(raw: indent)) { \(identifier.with(\.trailingTrivia, []).with(\.leadingTrivia, [])) in
            \(makeArgument(label: nil, prior, shouldVisit: false).with(\.trailingTrivia, []))
            \(raw: indent)}
            """
            indentLevel -= 1
        }
        return """
        \(raw: indent)PredicateExpressions.build_NilCoalesce(
        \(makeArgument(label: "lhs", prior, shouldVisit: false)),
        \(makeArgument(label: "rhs", `else`, shouldVisit: false))
        \(raw: indent))
        """
    }

    /// Processes the `body` block of an `if` expression and extracts the single expression it yields.
    ///
    /// - Parameter node: The `if` expression whose `body` is processed.
    /// - Returns: The expression held by the first item of the processed body, or `nil` when
    ///   `_processCodeBlock` returns `nil` or when that first item is not an expression
    ///   (the latter case emits a diagnostic on the body).
    private func _processIfBody(_ node: IfExprSyntax) -> ExprSyntax? {
        guard let visitedBody = _processCodeBlock(node.body.statements, in: .init(node.body), removeReturn: true) else {
            return nil
        }
        guard let bodyExpression = visitedBody.first?.item._expression else {
            self.diagnose(node: node.body, message: "This if expression body is not supported in this \(kind.keyword)")
            return nil
        }
        return bodyExpression
    }

    /// Processes the `else` branch of an `if` expression into a single expression.
    ///
    /// A plain `else { ... }` block is processed the same way as the `if` body. An `else if`
    /// chain is handled by recursively visiting the nested `if` expression.
    ///
    /// - Parameter node: The `if` expression whose `elseBody` is processed.
    /// - Returns: The resulting expression, or `nil` when there is no `else` branch (diagnosed),
    ///   when `_processCodeBlock` returns `nil`, or when the else block's first item is not an
    ///   expression (diagnosed).
    private func _processElseBody(_ node: IfExprSyntax) -> ExprSyntax? {
        guard let elseBody = node.elseBody else {
            self.diagnose(node: node, message: "If expressions without an else expression are not supported in this \(kind.keyword)")
            return nil
        }

        let elseExpression: ExprSyntax
        switch elseBody {
        case .codeBlock(let codeBlock):
            guard let visitedElseBody = _processCodeBlock(codeBlock.statements, in: .init(codeBlock), removeReturn: true) else {
                return nil
            }
            guard let expr = visitedElseBody.first?.item._expression else {
                // Anchor the diagnostic on the else block itself, not on the `if` branch's body.
                self.diagnose(node: codeBlock, message: "This if expression else body is not supported in this \(kind.keyword)")
                return nil
            }
            elseExpression = expr
        case .ifExpr(let ifExpr):
            elseExpression = visit(ifExpr)
        #if FOUNDATION_FRAMEWORK
        @unknown default:
            self.diagnose(node: elseBody, message: "This if expression else body is not supported in this \(kind.keyword)")
            return nil
        #endif
        }
        return elseExpression
    }

    /// Rewrites an `if` expression into predicate-expression builder calls.
    ///
    /// When the conditions are optional bindings, the processed branches are handed to
    /// `_rewriteIfAsFlatMap`. Otherwise the conditions are converted with
    /// `_rewriteConditionsAsExpression` and the result is emitted as a
    /// `PredicateExpressions.build_Conditional(condition, body, else)` call.
    /// The original node is returned unchanged if any step fails.
    override func visit(_ node: IfExprSyntax) -> ExprSyntax {
        if let bindings = node.conditions.optionalBindings {
            // The body ends up nested inside the `build_flatMap` closures produced by
            // `_rewriteIfAsFlatMap`, so it is processed `bindings.count` levels deeper.
            indentLevel += bindings.count
            let bodyExpression = _processIfBody(node)
            // Restore the indent level before any early exit so a failure cannot leave it shifted.
            indentLevel -= bindings.count

            // `guard` evaluates left to right, so the else branch is only processed when the body succeeded.
            guard let bodyExpression,
                  let elseExpression = _processElseBody(node) else {
                return .init(node)
            }
            return _rewriteIfAsFlatMap(bindings: bindings, body: bodyExpression, else: elseExpression) ?? .init(node)
        }

        guard let ifExpression = _rewriteConditionsAsExpression(node.conditions, in: node),
              let bodyExpression = _processIfBody(node),
              let elseExpression = _processElseBody(node) else {
            return .init(node)
        }

        return """
        \(raw: indent)PredicateExpressions.build_Conditional(
        \(makeArgument(label: nil, ifExpression).with(\.trailingTrivia, [])),
        \(makeArgument(label: nil, bodyExpression, shouldVisit: false).with(\.trailingTrivia, [])),
        \(makeArgument(label: nil, elseExpression, shouldVisit: false).with(\.trailingTrivia, []))
        \(raw: indent))
        """
    }

    // The following statement kinds are rejected. Each visitor emits a diagnostic and returns the node unchanged.

    /// Rejects `while` loops.
    override func visit(_ node: WhileStmtSyntax) -> StmtSyntax {
        self.diagnose(node: node, message: "While loops are not supported in this \(kind.keyword)")
        return .init(node)
    }

    /// Rejects `for`-`in` loops.
    override func visit(_ node: ForStmtSyntax) -> StmtSyntax {
        self.diagnose(node: node, message: "For-in loops are not supported in this \(kind.keyword)")
        return .init(node)
    }

    /// Rejects `do` statements reached through normal statement traversal.
    override func visit(_ node: DoStmtSyntax) -> StmtSyntax {
        self.diagnose(node: node, message: "Do statements are not supported in this \(kind.keyword)")
        return .init(node)
    }

    /// Rejects `catch` clauses.
    override func visit(_ node: CatchClauseSyntax) -> CatchClauseSyntax {
        self.diagnose(node: node, message: "Catch clauses are not supported in this \(kind.keyword)")
        return node
    }

    /// Rejects `repeat`-`while` loops.
    override func visit(_ node: RepeatStmtSyntax) -> StmtSyntax {
        self.diagnose(node: node, message: "Repeat-while loops are not supported in this \(kind.keyword)")
        return .init(node)
    }

    /// Visits the single item of a code block.
    ///
    /// Declarations are diagnosed. A `do` statement without `catch` clauses is unwrapped and
    /// replaced by the first item of its visited body. All other items fall through to the
    /// default rewriter behaviour.
    override func visit(_ node: CodeBlockItemSyntax) -> CodeBlockItemSyntax {
        // At this point, we know we're the only item in the code block because predicates only support single-expression code blocks
        // Diagnose any declarations
        if case .decl(_) = node.item {
            diagnose(node: node.item, message: "Declarations are not supported in this \(kind.keyword)")
            return node
        }
        if case let .stmt(statement) = node.item {
            // Unwrap a do statement with valid expression bodies
            if let doStatement = statement.as(DoStmtSyntax.self) {
                if let catchClause = doStatement.catchClauses.first {
                    diagnose(node: catchClause, message: "Catch clauses are not supported in this \(kind.keyword)")
                    return node
                }
                // The do body's contents replace the do statement, so visit them one indent level shallower.
                indentLevel -= 1
                let visitedBody = self.visit(doStatement.body)
                indentLevel += 1
                // NOTE(review): `success` is presumably cleared when visiting the body emitted a diagnostic — confirm.
                guard success else {
                    return node
                }
                guard let innerExpr = visitedBody.statements.first else {
                    diagnose(node: doStatement, message: "Do statement is not supported here in this \(kind.keyword)")
                    return node
                }
                return innerExpr
            }
        }
        return super.visit(node)
    }
}

/// A diagnostic or fix-it message emitted by the `#Predicate` / `#Expression` macros.
///
/// All messages share the same `diagnosticID` (also used as the fix-it ID). Severity
/// defaults to `.error`.
private struct PredicateExpansionDiagnostic: DiagnosticMessage, FixItMessage, ExpressibleByStringLiteral, ExpressibleByStringInterpolation {
    let message: String
    let severity: DiagnosticSeverity
    let diagnosticID: MessageID = .init(domain: "FoundationMacros", id: "PredicateDiagnostic")

    var fixItID: MessageID { diagnosticID }

    init(_ message: String, severity: DiagnosticSeverity = .error) {
        self.message = message
        self.severity = severity
    }

    init(stringLiteral value: String) {
        self.init(value)
    }
}

/// Which macro is being expanded.
///
/// This controls the wording of diagnostics and the name of the type the expansion constructs.
private enum ExpansionKind {
    case predicate
    case expression

    /// Lowercase noun used in diagnostic messages.
    var keyword: String {
        switch self {
        case .predicate: "predicate"
        case .expression: "expression"
        }
    }

    /// `keyword` with its first character uppercased (e.g. "Predicate").
    var capitalizedKeyword: String {
        // `prefix(1)` avoids force-unwrapping `first` and yields "" for an empty keyword.
        "\(keyword.prefix(1).uppercased())\(keyword.dropFirst())"
    }

    /// The macro spelling as written by users, e.g. `#Predicate`.
    var macroKeyword: String {
        "#\(capitalizedKeyword)"
    }

    /// Module-qualified name of the type whose initializer the expansion calls.
    var qualifiedExpansionType: String {
        #if FOUNDATION_FRAMEWORK
        "Foundation.\(capitalizedKeyword)"
        #else
        "FoundationEssentials.\(capitalizedKeyword)"
        #endif
    }
}

/// Shared expansion for `#Predicate` and `#Expression`.
///
/// The macro must be given a trailing closure. The closure is rewritten with
/// `OptionalChainRewriter` and then `PredicateQueryRewriter`, and the result is wrapped in an
/// initializer call of `kind.qualifiedExpansionType`. Any explicit generic arguments are forwarded.
///
/// - Throws: `DiagnosticsError` when no trailing closure is present. If the closure was passed
///   as the first argument instead, the diagnostic carries a fix-it that moves it to trailing
///   position. Errors raised while rewriting the closure are propagated.
private func predicateExpansion(of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext, kind: ExpansionKind) throws -> ExprSyntax {
    guard let closure = node.trailingClosure else {
        let fixIts: [FixIt]
        if let argument = node.arguments.first?.expression.as(ClosureExprSyntax.self) {
            // Build `#Macro { ... }` from `#Macro({ ... })` by dropping the parens and arguments.
            var newNode = node.with(\.leftParen, nil)
                .with(\.rightParen, nil)
                .with(\.trailingClosure, argument.with(\.leadingTrivia, [.spaces(1)]).with(\.trailingTrivia, []))
            newNode.arguments = []
            fixIts = [
                FixIt(message: PredicateExpansionDiagnostic("Use a trailing closure instead of a function parameter", severity: .note), changes: [
                    .replace(oldNode: Syntax(node), newNode: Syntax(newNode))
                ])
            ]
        } else {
            fixIts = []
        }
        throw DiagnosticsError(diagnostics: [.init(
            node: Syntax(node),
            message: PredicateExpansionDiagnostic("\(kind.macroKeyword) macro expansion requires a trailing closure"),
            fixIts: fixIts
        )])
    }

    let translatedClosure = try closure
        .rewrite(with: OptionalChainRewriter())
        .rewrite(with: PredicateQueryRewriter(kind: kind))
        .with(\.leadingTrivia, [])
        .with(\.trailingTrivia, [])

    if let genericArgs = node.genericArgumentClause {
        let strippedGenericArgs = genericArgs
            .with(\.leadingTrivia, [])
            .with(\.trailingTrivia, [])
        return "\(raw: kind.qualifiedExpansionType)\(strippedGenericArgs)(\(translatedClosure))"
    } else {
        // When the macro is specified without generic args (ex. "#Predicate { ... }") initialize a Predicate without generic args so they can be inferred from context
        return "\(raw: kind.qualifiedExpansionType)(\(translatedClosure))"
    }
}

/// Implementation of the `#Predicate` macro. Expansion formatting is disabled.
public struct PredicateMacro: SwiftSyntaxMacros.ExpressionMacro, Sendable {
    public static var formatMode: FormatMode { .disabled }

    public static func expansion(of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext) throws -> ExprSyntax {
        try predicateExpansion(of: node, in: context, kind: .predicate)
    }
}

/// Implementation of the `#Expression` macro. Expansion formatting is disabled.
public struct ExpressionMacro: SwiftSyntaxMacros.ExpressionMacro, Sendable {
    public static var formatMode: FormatMode { .disabled }

    public static func expansion(of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext) throws -> ExprSyntax {
        try predicateExpansion(of: node, in: context, kind: .expression)
    }
}